Skip to content

Commit 09353cf

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: add update_corpus method for vertex rag
PiperOrigin-RevId: 683781828
1 parent b1d5007 commit 09353cf

File tree

5 files changed

+238
-2
lines changed

5 files changed

+238
-2
lines changed

tests/unit/vertex_rag/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def rag_data_client_mock_exception():
7979
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
8080
# create_rag_corpus
8181
api_client_mock.create_rag_corpus.side_effect = Exception
82+
# update_rag_corpus
83+
api_client_mock.update_rag_corpus.side_effect = Exception
8284
# get_rag_corpus
8385
api_client_mock.get_rag_corpus.side_effect = Exception
8486
# list_rag_corpora

tests/unit/vertex_rag/test_rag_data.py

+141
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,70 @@ def create_rag_corpus_mock_pinecone():
111111
yield create_rag_corpus_mock_pinecone
112112

113113

114+
@pytest.fixture
def update_rag_corpus_mock_weaviate():
    """Patch ``VertexRagDataServiceClient.update_rag_corpus`` for Weaviate tests.

    The patched method returns a mocked long-running operation that reports
    itself as done and whose ``result()`` is
    ``tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE``.

    Yields:
        The mock that replaces ``update_rag_corpus``.
    """
    finished_lro = mock.Mock(spec=ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE

    with mock.patch.object(
        VertexRagDataServiceClient,
        "update_rag_corpus",
        return_value=finished_lro,
    ) as patched_update:
        yield patched_update
127+
128+
129+
@pytest.fixture
def update_rag_corpus_mock_vertex_feature_store():
    """Patch ``VertexRagDataServiceClient.update_rag_corpus`` for Feature Store tests.

    The patched method returns a mocked long-running operation that reports
    itself as done and whose ``result()`` is
    ``tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE``.

    Yields:
        The mock that replaces ``update_rag_corpus``.
    """
    finished_lro = mock.Mock(spec=ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE

    with mock.patch.object(
        VertexRagDataServiceClient,
        "update_rag_corpus",
        return_value=finished_lro,
    ) as patched_update:
        yield patched_update
144+
145+
146+
@pytest.fixture
def update_rag_corpus_mock_vertex_vector_search():
    """Patch ``VertexRagDataServiceClient.update_rag_corpus`` for Vector Search tests.

    The patched method returns a mocked long-running operation that reports
    itself as done and whose ``result()`` is
    ``tc.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH``.

    Yields:
        The mock that replaces ``update_rag_corpus``.
    """
    finished_lro = mock.Mock(spec=ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = tc.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH

    with mock.patch.object(
        VertexRagDataServiceClient,
        "update_rag_corpus",
        return_value=finished_lro,
    ) as patched_update:
        yield patched_update
161+
162+
163+
@pytest.fixture
def update_rag_corpus_mock_pinecone():
    """Patch ``VertexRagDataServiceClient.update_rag_corpus`` for Pinecone tests.

    The patched method returns a mocked long-running operation that reports
    itself as done and whose ``result()`` is
    ``tc.TEST_GAPIC_RAG_CORPUS_PINECONE``.

    Yields:
        The mock that replaces ``update_rag_corpus``.
    """
    finished_lro = mock.Mock(spec=ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = tc.TEST_GAPIC_RAG_CORPUS_PINECONE

    with mock.patch.object(
        VertexRagDataServiceClient,
        "update_rag_corpus",
        return_value=finished_lro,
    ) as patched_update:
        yield patched_update
176+
177+
114178
@pytest.fixture
115179
def list_rag_corpora_pager_mock():
116180
with mock.patch.object(
@@ -298,6 +362,83 @@ def test_create_corpus_failure(self):
298362
rag.create_corpus(display_name=tc.TEST_CORPUS_DISPLAY_NAME)
299363
e.match("Failed in RagCorpus creation due to")
300364

365+
@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_success(self):
    """Update with a display name and Weaviate config yields the Weaviate corpus."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_WEAVIATE_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_WEAVIATE
    rag_corpus_eq(updated_corpus, expected_corpus)
374+
375+
@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_no_display_name_success(self):
    """Update with only a Weaviate config (no display name) still succeeds."""
    updated_corpus = rag.update_corpus(
        vector_db=tc.TEST_WEAVIATE_CONFIG,
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_WEAVIATE
    rag_corpus_eq(updated_corpus, expected_corpus)
383+
384+
@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_with_description_success(self):
    """Update with a description and Weaviate config yields the Weaviate corpus."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_WEAVIATE_CONFIG,
        description=tc.TEST_CORPUS_DISCRIPTION,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_WEAVIATE
    rag_corpus_eq(updated_corpus, expected_corpus)
393+
394+
@pytest.mark.usefixtures("update_rag_corpus_mock_weaviate")
def test_update_corpus_weaviate_with_description_and_display_name_success(self):
    """Update with description, display name and Weaviate config succeeds."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_WEAVIATE_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
        description=tc.TEST_CORPUS_DISCRIPTION,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_WEAVIATE
    rag_corpus_eq(updated_corpus, expected_corpus)
404+
405+
@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_feature_store")
def test_update_corpus_vertex_feature_store_success(self):
    """Update with a Vertex Feature Store config yields the Feature Store corpus."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_VERTEX_FEATURE_STORE_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE
    rag_corpus_eq(updated_corpus, expected_corpus)
414+
415+
@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_vector_search")
def test_update_corpus_vertex_vector_search_success(self):
    """Update with a Vertex Vector Search config yields the Vector Search corpus."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH
    rag_corpus_eq(updated_corpus, expected_corpus)
423+
424+
@pytest.mark.usefixtures("update_rag_corpus_mock_pinecone")
def test_update_corpus_pinecone_success(self):
    """Update with a Pinecone config yields the Pinecone corpus."""
    updated_corpus = rag.update_corpus(
        corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        vector_db=tc.TEST_PINECONE_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
    )

    expected_corpus = tc.TEST_RAG_CORPUS_PINECONE
    rag_corpus_eq(updated_corpus, expected_corpus)
432+
433+
@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_update_corpus_failure(self):
    """A failing client call surfaces as RuntimeError with the update message."""
    with pytest.raises(RuntimeError) as raised:
        rag.update_corpus(
            display_name=tc.TEST_CORPUS_DISPLAY_NAME,
            corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
        )
    # Checked after the context exits, once the exception info is populated.
    raised.match("Failed in RagCorpus update due to")
441+
301442
@pytest.mark.usefixtures("rag_data_client_mock")
302443
def test_get_corpus_success(self):
303444
rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)

vertexai/preview/rag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from vertexai.preview.rag.rag_data import (
1919
create_corpus,
20+
update_corpus,
2021
list_corpora,
2122
get_corpus,
2223
delete_corpus,
@@ -84,4 +85,5 @@
8485
"list_files",
8586
"retrieval_query",
8687
"upload_file",
88+
"update_corpus",
8789
)

vertexai/preview/rag/rag_data.py

+80-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717
"""RAG data management SDK."""
1818

19-
from typing import Optional, Union, Sequence
19+
from typing import Optional, Sequence, Union
2020
from google import auth
2121
from google.api_core import operation_async
2222
from google.auth.transport import requests as google_auth_requests
@@ -33,8 +33,8 @@
3333
ListRagCorporaRequest,
3434
ListRagFilesRequest,
3535
RagCorpus as GapicRagCorpus,
36+
UpdateRagCorpusRequest,
3637
)
37-
3838
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import (
3939
ListRagCorporaPager,
4040
ListRagFilesPager,
@@ -121,6 +121,84 @@ def create_corpus(
121121
return _gapic_utils.convert_gapic_to_rag_corpus(response.result(timeout=600))
122122

123123

124+
def update_corpus(
    corpus_name: str,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    vector_db: Optional[
        Union[
            Weaviate,
            VertexFeatureStore,
            VertexVectorSearch,
            Pinecone,
            RagManagedDb,
        ]
    ] = None,
) -> RagCorpus:
    """Updates a RagCorpus resource.

    Example usage:
    ```
    import vertexai
    from vertexai.preview import rag

    vertexai.init(project="my-project")

    rag_corpus = rag.update_corpus(
        corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
        display_name="my-corpus-1",
    )
    ```

    Args:
        corpus_name: The name of the RagCorpus resource to update. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or
            ``{rag_corpus}``.
        display_name: The display name of the RagCorpus. The name can be up to
            128 characters long and can consist of any UTF-8 characters. If
            not provided (or empty), the display name is not set on the
            update request.
        description: The description of the RagCorpus. If not provided (or
            empty), the description is not set on the update request.
        vector_db: The vector db config of the RagCorpus. If not provided, the
            vector db will not be updated.

    Returns:
        RagCorpus.
    Raises:
        RuntimeError: Failed in RagCorpus update because the
            ``update_rag_corpus`` client call raised. The original exception
            is chained as ``__cause__``.

        Errors raised while waiting for the long-running operation's result
        (up to 600 seconds) are not wrapped and propagate unchanged.
    """
    corpus_name = _gapic_utils.get_corpus_name(corpus_name)

    # Only populate the optional fields the caller actually supplied; falsy
    # values (None or "") are left off the message entirely.
    corpus_fields = {"name": corpus_name}
    if display_name:
        corpus_fields["display_name"] = display_name
    if description:
        corpus_fields["description"] = description
    rag_corpus = GapicRagCorpus(**corpus_fields)

    _gapic_utils.set_vector_db(
        vector_db=vector_db,
        rag_corpus=rag_corpus,
    )

    request = UpdateRagCorpusRequest(
        rag_corpus=rag_corpus,
    )
    client = _gapic_utils.create_rag_data_service_client()

    try:
        response = client.update_rag_corpus(request=request)
    except Exception as e:
        # Format the cause into a single message string; passing it as a
        # second positional argument makes str(error) render as a tuple repr.
        raise RuntimeError(f"Failed in RagCorpus update due to: {e}") from e
    return _gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config(
        response.result(timeout=600)
    )
200+
201+
124202
def get_corpus(name: str) -> RagCorpus:
125203
"""
126204
Get an existing RagCorpus.

vertexai/preview/rag/utils/_gapic_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
180180
return rag_corpus
181181

182182

183+
def convert_gapic_to_rag_corpus_no_embedding_model_config(
    gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
    """Build a RagCorpus from a GapicRagCorpus, leaving out the embedding model config.

    Intended for UpdateRagCorpus results. Only ``name``, ``display_name``,
    ``description`` and the vector db (converted from
    ``rag_vector_db_config``) are carried over.

    Args:
        gapic_rag_corpus: The GAPIC corpus message to convert.

    Returns:
        RagCorpus populated with the fields listed above.
    """
    vector_db = convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config)
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vector_db=vector_db,
    )
194+
195+
183196
def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
184197
"""Convert GapicRagFile to RagFile."""
185198
rag_file = RagFile(

0 commit comments

Comments
 (0)