Skip to content

Commit a7453da

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Adding Vertex AI Search Config for RAG corpuses to SDK
PiperOrigin-RevId: 741675260
1 parent f090ca1 commit a7453da

File tree

6 files changed

+277
-8
lines changed

6 files changed

+277
-8
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+41
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
JiraQuery,
3939
VertexVectorSearch,
4040
RagEmbeddingModelConfig,
41+
VertexAiSearchConfig,
4142
VertexPredictionEndpoint,
4243
)
4344

@@ -57,6 +58,7 @@
5758
RagContexts,
5859
RetrieveContextsResponse,
5960
RagVectorDbConfig as GapicRagVectorDbConfig,
61+
VertexAiSearchConfig as GapicVertexAiSearchConfig,
6062
)
6163
from google.cloud.aiplatform_v1.types import api_auth
6264
from google.protobuf import timestamp_pb2
@@ -162,6 +164,45 @@
162164
)
163165
TEST_PAGE_TOKEN = "test-page-token"
164166

167+
# Vertex AI Search Config
168+
TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/engines/test-engine/servingConfigs/test-serving-config"
169+
TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/dataStores/test-datastore/servingConfigs/test-serving-config"
170+
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = GapicRagCorpus(
171+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
172+
display_name=TEST_CORPUS_DISPLAY_NAME,
173+
vertex_ai_search_config=GapicVertexAiSearchConfig(
174+
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
175+
),
176+
)
177+
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = GapicRagCorpus(
178+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
179+
display_name=TEST_CORPUS_DISPLAY_NAME,
180+
vertex_ai_search_config=GapicVertexAiSearchConfig(
181+
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
182+
),
183+
)
184+
TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE = VertexAiSearchConfig(
185+
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
186+
)
187+
TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE = VertexAiSearchConfig(
188+
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
189+
)
190+
TEST_VERTEX_AI_SEARCH_CONFIG_INVALID = VertexAiSearchConfig(
191+
serving_config="invalid-serving-config",
192+
)
193+
TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY = VertexAiSearchConfig()
194+
195+
TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = RagCorpus(
196+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
197+
display_name=TEST_CORPUS_DISPLAY_NAME,
198+
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
199+
)
200+
TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = RagCorpus(
201+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
202+
display_name=TEST_CORPUS_DISPLAY_NAME,
203+
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
204+
)
205+
165206
# RagFiles
166207
TEST_PATH = "usr/home/my_file.txt"
167208
TEST_GCS_PATH = "gs://usr/home/data_dir/"

tests/unit/vertex_rag/test_rag_data.py

+132
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,57 @@ def create_rag_corpus_mock_pinecone():
8585
yield create_rag_corpus_mock_pinecone
8686

8787

88+
@pytest.fixture
89+
def create_rag_corpus_mock_vertex_ai_engine_search_config():
90+
with mock.patch.object(
91+
VertexRagDataServiceClient,
92+
"create_rag_corpus",
93+
) as create_rag_corpus_mock_vertex_ai_engine_search_config:
94+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
95+
create_rag_corpus_lro_mock.done.return_value = True
96+
create_rag_corpus_lro_mock.result.return_value = (
97+
test_rag_constants.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
98+
)
99+
create_rag_corpus_mock_vertex_ai_engine_search_config.return_value = (
100+
create_rag_corpus_lro_mock
101+
)
102+
yield create_rag_corpus_mock_vertex_ai_engine_search_config
103+
104+
105+
@pytest.fixture
106+
def create_rag_corpus_mock_vertex_ai_datastore_search_config():
107+
with mock.patch.object(
108+
VertexRagDataServiceClient,
109+
"create_rag_corpus",
110+
) as create_rag_corpus_mock_vertex_ai_datastore_search_config:
111+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
112+
create_rag_corpus_lro_mock.done.return_value = True
113+
create_rag_corpus_lro_mock.result.return_value = (
114+
test_rag_constants.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG
115+
)
116+
create_rag_corpus_mock_vertex_ai_datastore_search_config.return_value = (
117+
create_rag_corpus_lro_mock
118+
)
119+
yield create_rag_corpus_mock_vertex_ai_datastore_search_config
120+
121+
122+
@pytest.fixture
123+
def update_rag_corpus_mock_vertex_ai_engine_search_config():
124+
with mock.patch.object(
125+
VertexRagDataServiceClient,
126+
"update_rag_corpus",
127+
) as update_rag_corpus_mock_vertex_ai_engine_search_config:
128+
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
129+
update_rag_corpus_lro_mock.done.return_value = True
130+
update_rag_corpus_lro_mock.result.return_value = (
131+
test_rag_constants.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
132+
)
133+
update_rag_corpus_mock_vertex_ai_engine_search_config.return_value = (
134+
update_rag_corpus_lro_mock
135+
)
136+
yield update_rag_corpus_mock_vertex_ai_engine_search_config
137+
138+
88139
@pytest.fixture
89140
def update_rag_corpus_mock_vertex_vector_search():
90141
with mock.patch.object(
@@ -247,6 +298,9 @@ def rag_corpus_eq(returned_corpus, expected_corpus):
247298
assert returned_corpus.name == expected_corpus.name
248299
assert returned_corpus.display_name == expected_corpus.display_name
249300
assert returned_corpus.backend_config.__eq__(expected_corpus.backend_config)
301+
assert returned_corpus.vertex_ai_search_config.__eq__(
302+
expected_corpus.vertex_ai_search_config
303+
)
250304

251305

252306
def rag_file_eq(returned_file, expected_file):
@@ -328,12 +382,90 @@ def test_create_corpus_pinecone_success(self):
328382

329383
rag_corpus_eq(rag_corpus, test_rag_constants.TEST_RAG_CORPUS_PINECONE)
330384

385+
@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_engine_search_config")
386+
def test_create_corpus_vais_engine_search_config_success(self):
387+
rag_corpus = rag.create_corpus(
388+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
389+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
390+
)
391+
392+
rag_corpus_eq(
393+
rag_corpus,
394+
test_rag_constants.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG,
395+
)
396+
397+
@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_datastore_search_config")
398+
def test_create_corpus_vais_datastore_search_config_success(self):
399+
rag_corpus = rag.create_corpus(
400+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
401+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
402+
)
403+
404+
rag_corpus_eq(
405+
rag_corpus,
406+
test_rag_constants.TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG,
407+
)
408+
409+
def test_create_corpus_vais_datastore_search_config_with_backend_config_failure(
410+
self,
411+
):
412+
with pytest.raises(ValueError) as e:
413+
rag.create_corpus(
414+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
415+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
416+
backend_config=test_rag_constants.TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG,
417+
)
418+
e.match("Only one of vertex_ai_search_config or backend_config can be set.")
419+
420+
def test_set_vertex_ai_search_config_with_invalid_serving_config_failure(self):
421+
with pytest.raises(ValueError) as e:
422+
rag.create_corpus(
423+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
424+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_INVALID,
425+
)
426+
e.match(
427+
"serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
428+
)
429+
430+
def test_set_vertex_ai_search_config_with_empty_serving_config_failure(self):
431+
with pytest.raises(ValueError) as e:
432+
rag.create_corpus(
433+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
434+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY,
435+
)
436+
e.match("serving_config must be set.")
437+
331438
@pytest.mark.usefixtures("rag_data_client_mock_exception")
332439
def test_create_corpus_failure(self):
333440
with pytest.raises(RuntimeError) as e:
334441
rag.create_corpus(display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME)
335442
e.match("Failed in RagCorpus creation due to")
336443

444+
@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_ai_engine_search_config")
445+
def test_update_corpus_vais_engine_search_config_success(self):
446+
rag_corpus = rag.update_corpus(
447+
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
448+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
449+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
450+
)
451+
452+
rag_corpus_eq(
453+
rag_corpus,
454+
test_rag_constants.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG,
455+
)
456+
457+
def test_update_corpus_vais_datastore_search_config_with_backend_config_failure(
458+
self,
459+
):
460+
with pytest.raises(ValueError) as e:
461+
rag.update_corpus(
462+
corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
463+
display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME,
464+
vertex_ai_search_config=test_rag_constants.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
465+
backend_config=test_rag_constants.TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG,
466+
)
467+
e.match("Only one of vertex_ai_search_config or backend_config can be set.")
468+
337469
@pytest.mark.usefixtures("update_rag_corpus_mock_pinecone")
338470
def test_update_corpus_pinecone_success(self):
339471
rag_corpus = rag.update_corpus(

vertexai/rag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
SlackChannel,
6060
SlackChannelsSource,
6161
TransformationConfig,
62+
VertexAiSearchConfig,
6263
VertexPredictionEndpoint,
6364
VertexVectorSearch,
6465
)
@@ -87,6 +88,7 @@
8788
"SlackChannel",
8889
"SlackChannelsSource",
8990
"TransformationConfig",
91+
"VertexAiSearchConfig",
9092
"VertexRagStore",
9193
"VertexPredictionEndpoint",
9294
"VertexVectorSearch",

vertexai/rag/rag_data.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@
5050
RagVectorDbConfig,
5151
SharePointSources,
5252
SlackChannelsSource,
53+
VertexAiSearchConfig,
5354
TransformationConfig,
5455
)
5556

5657

5758
def create_corpus(
5859
display_name: Optional[str] = None,
5960
description: Optional[str] = None,
61+
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
6062
backend_config: Optional[
6163
Union[
6264
RagVectorDbConfig,
@@ -83,6 +85,9 @@ def create_corpus(
8385
the RagCorpus. The name can be up to 128 characters long and can
8486
consist of any UTF-8 characters.
8587
description: The description of the RagCorpus.
88+
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
89+
Note: backend_config cannot be set if vertex_ai_search_config is
90+
specified.
8691
backend_config: The backend config of the RagCorpus, specifying a
8792
data store and/or embedding model.
8893
Returns:
@@ -91,15 +96,27 @@ def create_corpus(
9196
RuntimeError: Failed in RagCorpus creation due to exception.
9297
RuntimeError: Failed in RagCorpus creation due to operation error.
9398
"""
99+
if vertex_ai_search_config and backend_config:
100+
raise ValueError(
101+
"Only one of vertex_ai_search_config or backend_config can be set."
102+
)
103+
94104
if not display_name:
95105
display_name = "vertex-" + utils.timestamped_unique_name()
96106
parent = initializer.global_config.common_location_path(project=None, location=None)
97107

98108
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
99-
_gapic_utils.set_backend_config(
100-
backend_config=backend_config,
101-
rag_corpus=rag_corpus,
102-
)
109+
110+
if backend_config:
111+
_gapic_utils.set_backend_config(
112+
backend_config=backend_config,
113+
rag_corpus=rag_corpus,
114+
)
115+
elif vertex_ai_search_config:
116+
_gapic_utils.set_vertex_ai_search_config(
117+
vertex_ai_search_config=vertex_ai_search_config,
118+
rag_corpus=rag_corpus,
119+
)
103120

104121
request = CreateRagCorpusRequest(
105122
parent=parent,
@@ -118,6 +135,7 @@ def update_corpus(
118135
corpus_name: str,
119136
display_name: Optional[str] = None,
120137
description: Optional[str] = None,
138+
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
121139
backend_config: Optional[
122140
Union[
123141
RagVectorDbConfig,
@@ -149,6 +167,10 @@ def update_corpus(
149167
and can consist of any UTF-8 characters.
150168
description: The description of the RagCorpus. If not provided, the
151169
description will not be updated.
170+
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
171+
If not provided, the Vertex AI Search config will not be updated.
172+
Note: backend_config cannot be set if vertex_ai_search_config is
173+
specified.
152174
backend_config: The backend config of the RagCorpus, specifying a
153175
data store and/or embedding model.
154176
@@ -158,6 +180,11 @@ def update_corpus(
158180
RuntimeError: Failed in RagCorpus update due to exception.
159181
RuntimeError: Failed in RagCorpus update due to operation error.
160182
"""
183+
if vertex_ai_search_config and backend_config:
184+
raise ValueError(
185+
"Only one of vertex_ai_search_config or backend_config can be set."
186+
)
187+
161188
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
162189
if display_name and description:
163190
rag_corpus = GapicRagCorpus(
@@ -170,10 +197,17 @@ def update_corpus(
170197
else:
171198
rag_corpus = GapicRagCorpus(name=corpus_name)
172199

173-
_gapic_utils.set_backend_config(
174-
backend_config=backend_config,
175-
rag_corpus=rag_corpus,
176-
)
200+
if backend_config:
201+
_gapic_utils.set_backend_config(
202+
backend_config=backend_config,
203+
rag_corpus=rag_corpus,
204+
)
205+
206+
if vertex_ai_search_config:
207+
_gapic_utils.set_vertex_ai_search_config(
208+
vertex_ai_search_config=vertex_ai_search_config,
209+
rag_corpus=rag_corpus,
210+
)
177211

178212
request = UpdateRagCorpusRequest(
179213
rag_corpus=rag_corpus,

0 commit comments

Comments
 (0)