Skip to content

Commit 26657ff

Browse files
yinghsienwu and copybara-github
authored and committed
feat: Automatically populate parents for full resource name in Vertex RAG SDK
PiperOrigin-RevId: 629849569
1 parent 2d19137 commit 26657ff

File tree

3 files changed

+112
-15
lines changed

3 files changed

+112
-15
lines changed

tests/unit/vertex_rag/test_rag_data.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import importlib
1818
from google.api_core import operation as ga_operation
1919
from vertexai.preview import rag
20-
from vertexai.preview.rag.utils._gapic_utils import prepare_import_files_request
20+
from vertexai.preview.rag.utils._gapic_utils import (
21+
prepare_import_files_request,
22+
)
2123
from google.cloud.aiplatform_v1beta1 import (
2224
VertexRagDataServiceAsyncClient,
2325
VertexRagDataServiceClient,
@@ -184,6 +186,11 @@ def test_get_corpus_success(self):
184186
rag_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)
185187
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)
186188

189+
@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_corpus_id_success(self):
    """``rag.get_corpus`` accepts a bare corpus ID instead of a full name."""
    fetched_corpus = rag.get_corpus(tc.TEST_RAG_CORPUS_ID)
    rag_corpus_eq(fetched_corpus, tc.TEST_RAG_CORPUS)
193+
187194
@pytest.mark.usefixtures("rag_data_client_mock_exception")
188195
def test_get_corpus_failure(self):
189196
with pytest.raises(RuntimeError) as e:
@@ -208,7 +215,11 @@ def test_list_corpora_failure(self):
208215

209216
def test_delete_corpus_success(self, rag_data_client_mock):
210217
rag.delete_corpus(tc.TEST_RAG_CORPUS_RESOURCE_NAME)
211-
rag_data_client_mock.assert_called_once()
218+
assert rag_data_client_mock.call_count == 2
219+
220+
def test_delete_corpus_id_success(self, rag_data_client_mock):
    """Deleting by bare corpus ID records two calls on the client mock."""
    expected_calls = 2
    rag.delete_corpus(tc.TEST_RAG_CORPUS_ID)
    assert rag_data_client_mock.call_count == expected_calls
212223

213224
@pytest.mark.usefixtures("rag_data_client_mock_exception")
214225
def test_delete_corpus_failure(self):
@@ -311,6 +322,13 @@ def test_get_file_success(self):
311322
rag_file = rag.get_file(tc.TEST_RAG_FILE_RESOURCE_NAME)
312323
rag_file_eq(rag_file, tc.TEST_RAG_FILE)
313324

325+
@pytest.mark.usefixtures("rag_data_client_mock")
def test_get_file_id_success(self):
    """``rag.get_file`` accepts a bare file ID together with a bare corpus ID."""
    fetched_file = rag.get_file(
        corpus_name=tc.TEST_RAG_CORPUS_ID,
        name=tc.TEST_RAG_FILE_ID,
    )
    rag_file_eq(fetched_file, tc.TEST_RAG_FILE)
331+
314332
@pytest.mark.usefixtures("rag_data_client_mock_exception")
315333
def test_get_file_failure(self):
316334
with pytest.raises(RuntimeError) as e:
@@ -333,7 +351,12 @@ def test_list_files_failure(self):
333351

334352
def test_delete_file_success(self, rag_data_client_mock):
335353
rag.delete_file(tc.TEST_RAG_FILE_RESOURCE_NAME)
336-
rag_data_client_mock.assert_called_once()
354+
assert rag_data_client_mock.call_count == 2
355+
356+
def test_delete_file_id_success(self, rag_data_client_mock):
    """Deleting by bare file ID plus corpus ID records three client-mock calls."""
    rag.delete_file(
        corpus_name=tc.TEST_RAG_CORPUS_ID,
        name=tc.TEST_RAG_FILE_ID,
    )
    # Passing corpus_name will result in 3 calls to rag_data_client
    expected_calls = 3
    assert rag_data_client_mock.call_count == expected_calls
337360

338361
@pytest.mark.usefixtures("rag_data_client_mock_exception")
339362
def test_delete_file_failure(self):

vertexai/preview/rag/rag_data.py

+37-12
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import (
3838
ListRagCorporaPager,
3939
ListRagFilesPager,
40-
4140
)
4241
from vertexai.preview.rag.utils import (
4342
_gapic_utils,
@@ -100,10 +99,12 @@ def get_corpus(name: str) -> RagCorpus:
10099
Args:
101100
name: An existing RagCorpus resource name. Format:
102101
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
102+
or ``{rag_corpus}``.
103103
Returns:
104104
RagCorpus.
105105
"""
106-
request = GetRagCorpusRequest(name=name)
106+
corpus_name = _gapic_utils.get_corpus_name(name)
107+
request = GetRagCorpusRequest(name=corpus_name)
107108
client = _gapic_utils.create_rag_data_service_client()
108109
try:
109110
response = client.get_rag_corpus(request=request)
@@ -163,8 +164,10 @@ def delete_corpus(name: str) -> None:
163164
Args:
164165
name: An existing RagCorpus resource name. Format:
165166
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
167+
or ``{rag_corpus}``.
166168
"""
167-
request = DeleteRagCorpusRequest(name=name)
169+
corpus_name = _gapic_utils.get_corpus_name(name)
170+
request = DeleteRagCorpusRequest(name=corpus_name)
168171

169172
client = _gapic_utils.create_rag_data_service_client()
170173
try:
@@ -200,7 +203,8 @@ def upload_file(
200203
201204
Args:
202205
corpus_name: The name of the RagCorpus resource into which to upload the file.
203-
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
206+
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
207+
or ``{rag_corpus}``.
204208
path: A local file path. For example,
205209
"usr/home/my_file.txt".
206210
display_name: The display name of the data file.
@@ -212,6 +216,7 @@ def upload_file(
212216
ValueError: RagCorpus is not found.
213217
RuntimeError: Failed in indexing the RagFile.
214218
"""
219+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
215220
location = initializer.global_config.location
216221
# GAPIC doesn't expose a path (scotty). Use requests API instead
217222
if display_name is None:
@@ -286,6 +291,7 @@ def import_files(
286291
Args:
287292
corpus_name: The name of the RagCorpus resource into which to import files.
288293
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
294+
or ``{rag_corpus}``.
289295
paths: A list of uris. Eligible uris will be Google Cloud Storage
290296
directory ("gs://my-bucket/my_dir") or a Google Drive url for file
291297
(https://drive.google.com/file/... or folder
@@ -296,7 +302,7 @@ def import_files(
296302
Returns:
297303
ImportRagFilesResponse.
298304
"""
299-
305+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
300306
request = _gapic_utils.prepare_import_files_request(
301307
corpus_name=corpus_name,
302308
paths=paths,
@@ -347,6 +353,7 @@ async def import_files_async(
347353
Args:
348354
corpus_name: The name of the RagCorpus resource into which to import files.
349355
Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
356+
or ``{rag_corpus}``.
350357
paths: A list of uris. Eligible uris will be Google Cloud Storage
351358
directory ("gs://my-bucket/my_dir") or a Google Drive url for file
352359
(https://drive.google.com/file/... or folder
@@ -356,7 +363,7 @@ async def import_files_async(
356363
Returns:
357364
operation_async.AsyncOperation.
358365
"""
359-
366+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
360367
request = _gapic_utils.prepare_import_files_request(
361368
corpus_name=corpus_name,
362369
paths=paths,
@@ -371,16 +378,24 @@ async def import_files_async(
371378
return response
372379

373380

374-
def get_file(name: str) -> RagFile:
381+
def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile:
375382
"""
376383
Get an existing RagFile.
377384
378385
Args:
379-
name: A RagFile resource name. Format:
380-
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
386+
name: Either a full RagFile resource name must be provided, or a RagCorpus
387+
name and a RagFile name must be provided. Format:
388+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
389+
or ``{rag_file}``.
390+
corpus_name: If `name` is not a full resource name, an existing RagCorpus
391+
name must be provided. Format:
392+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
393+
or ``{rag_corpus}``.
381394
Returns:
382395
RagFile.
383396
"""
397+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
398+
name = _gapic_utils.get_file_name(name, corpus_name)
384399
request = GetRagFileRequest(name=name)
385400
client = _gapic_utils.create_rag_data_service_client()
386401
try:
@@ -423,13 +438,15 @@ def list_files(
423438
424439
Args:
425440
corpus_name: An existing RagCorpus name. Format:
426-
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
441+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
442+
or ``{rag_corpus}``.
427443
page_size: The standard list page size. Leaving out the page_size
428444
causes all of the results to be returned.
429445
page_token: The standard list page token.
430446
Returns:
431447
ListRagFilesPager.
432448
"""
449+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
433450
request = ListRagFilesRequest(
434451
parent=corpus_name,
435452
page_size=page_size,
@@ -444,14 +461,22 @@ def list_files(
444461
return pager
445462

446463

447-
def delete_file(name: str) -> None:
464+
def delete_file(name: str, corpus_name: Optional[str] = None) -> None:
448465
"""
449466
Delete RagFile from an existing RagCorpus.
450467
451468
Args:
452-
name: A RagFile resource name. Format:
469+
name: Either a full RagFile resource name must be provided, or a RagCorpus
470+
name and a RagFile name must be provided. Format:
453471
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``
472+
or ``{rag_file}``.
473+
corpus_name: If `name` is not a full resource name, an existing RagCorpus
474+
name must be provided. Format:
475+
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
476+
or ``{rag_corpus}``.
454477
"""
478+
corpus_name = _gapic_utils.get_corpus_name(corpus_name)
479+
name = _gapic_utils.get_file_name(name, corpus_name)
455480
request = DeleteRagFileRequest(name=name)
456481

457482
client = _gapic_utils.create_rag_data_service_client()

vertexai/preview/rag/utils/_gapic_utils.py

+49
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
import re
1718
from typing import Any, Dict, Sequence, Union
1819
from google.cloud.aiplatform_v1beta1 import (
1920
GoogleDriveSource,
@@ -35,6 +36,9 @@
3536
)
3637

3738

39+
# Pattern for a bare (short) resource ID: one lowercase ASCII letter followed
# by up to 127 letters, digits, ".", "_" or "-". The pattern is unanchored;
# get_corpus_name and get_file_name anchor it before matching to tell a short
# ID apart from a full resource name.
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
40+
41+
3842
def create_rag_data_service_client():
3943
return initializer.global_config.create_client(
4044
client_class=VertexRagDataClientWithOverride,
@@ -153,3 +157,48 @@ def prepare_import_files_request(
153157
parent=corpus_name, import_rag_files_config=import_rag_files_config
154158
)
155159
return request
160+
161+
162+
def get_corpus_name(
    name: str,
) -> str:
    """Resolve a RagCorpus identifier to a full resource name.

    Args:
        name: Either a full resource name
            (``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``)
            or a bare ``{rag_corpus}`` ID. A falsy value (``None`` or ``""``)
            is returned unchanged, so callers can forward an optional corpus
            name without checking it first.

    Returns:
        ``name`` itself when it already parses as a RagCorpus path; otherwise
        the path built from the bare ID plus the project and location held in
        ``initializer.global_config``.

    Raises:
        ValueError: ``name`` is neither a full RagCorpus resource name nor a
            valid bare ID.
    """
    if not name:
        return name

    client = create_rag_data_service_client()
    if client.parse_rag_corpus_path(name):
        return name

    # fullmatch rather than match("^...$"): `$` also matches just before a
    # trailing newline, which would let an ID such as "my-corpus\n" through
    # and embed the newline in the resource path.
    if re.fullmatch(_VALID_RESOURCE_NAME_REGEX, name):
        return client.rag_corpus_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=name,
        )

    raise ValueError(
        "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
    )
180+
181+
182+
def get_file_name(
    name: str,
    corpus_name: str,
) -> str:
    """Resolve a RagFile identifier to a full resource name.

    Args:
        name: Either a full resource name
            (``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}``)
            or a bare ``{rag_file}`` ID.
        corpus_name: The parent RagCorpus, as a full resource name or a bare
            ``{rag_corpus}`` ID. Required only when ``name`` is a bare ID.

    Returns:
        ``name`` itself when it already parses as a RagFile path; otherwise
        the resolved corpus resource name with ``/ragFiles/{name}`` appended.

    Raises:
        ValueError: ``name`` is neither a full RagFile resource name nor a
            valid bare ID, or ``name`` is a bare ID and ``corpus_name`` is
            missing or invalid.
    """
    client = create_rag_data_service_client()
    if client.parse_rag_file_path(name):
        return name

    # fullmatch rather than match("^...$"): `$` also matches just before a
    # trailing newline, which would accept an ID such as "my-file\n".
    if not re.fullmatch(_VALID_RESOURCE_NAME_REGEX, name):
        raise ValueError(
            "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
        )

    if not corpus_name:
        raise ValueError(
            "corpus_name must be provided if name is a `{rag_file}`, not a "
            "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
        )

    # get_corpus_name returns the corpus' *full* resource name, so it cannot
    # be passed as the `rag_corpus` component of a path template: that would
    # produce "projects/.../ragCorpora/projects/.../ragCorpora/<id>/ragFiles/...".
    # Appending the file segment to the resolved corpus name yields the
    # well-formed path and keeps the corpus' own project and location.
    return "{}/ragFiles/{}".format(get_corpus_name(corpus_name), name)

0 commit comments

Comments
 (0)