Skip to content

Commit 8b3beb6

Browse files
yinghsienwu authored and copybara-github committed
feat: Add support for user-configurable 1P embedding models and quota for RAG
PiperOrigin-RevId: 642414350
1 parent cf8bc3d commit 8b3beb6

File tree

6 files changed

+223
-3
lines changed

6 files changed

+223
-3
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+12
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
from vertexai.preview.rag.utils.resources import (
19+
EmbeddingModelConfig,
1920
RagCorpus,
2021
RagFile,
2122
RagResource,
@@ -49,10 +50,19 @@
4950
display_name=TEST_CORPUS_DISPLAY_NAME,
5051
description=TEST_CORPUS_DISCRIPTION,
5152
)
53+
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
54+
"projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
55+
TEST_PROJECT, TEST_REGION
56+
)
57+
)
58+
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
59+
publisher_model="publishers/google/models/textembedding-gecko",
60+
)
5261
TEST_RAG_CORPUS = RagCorpus(
5362
name=TEST_RAG_CORPUS_RESOURCE_NAME,
5463
display_name=TEST_CORPUS_DISPLAY_NAME,
5564
description=TEST_CORPUS_DISCRIPTION,
65+
embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
5666
)
5767
TEST_PAGE_TOKEN = "test-page-token"
5868

@@ -114,6 +124,8 @@
114124
chunk_overlap=TEST_CHUNK_OVERLAP,
115125
)
116126
)
127+
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
128+
117129
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
118130
GoogleDriveSource.ResourceId(
119131
resource_id=TEST_DRIVE_FILE_ID,

tests/unit/vertex_rag/test_rag_data.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vertexai.preview import rag
2020
from vertexai.preview.rag.utils._gapic_utils import (
2121
prepare_import_files_request,
22+
set_embedding_model_config,
2223
)
2324
from google.cloud.aiplatform_v1beta1 import (
2425
VertexRagDataServiceAsyncClient,
@@ -171,7 +172,10 @@ def teardown_method(self):
171172

172173
@pytest.mark.usefixtures("create_rag_corpus_mock")
173174
def test_create_corpus_success(self):
174-
rag_corpus = rag.create_corpus(display_name=tc.TEST_CORPUS_DISPLAY_NAME)
175+
rag_corpus = rag.create_corpus(
176+
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
177+
embedding_model_config=tc.TEST_EMBEDDING_MODEL_CONFIG,
178+
)
175179

176180
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)
177181

@@ -391,6 +395,7 @@ def test_prepare_import_files_request_drive_files(self):
391395
paths=paths,
392396
chunk_size=tc.TEST_CHUNK_SIZE,
393397
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
398+
max_embedding_requests_per_min=800,
394399
)
395400
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FILE)
396401

@@ -415,3 +420,42 @@ def test_prepare_import_files_request_invalid_path(self):
415420
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
416421
)
417422
e.match("path must be a Google Cloud Storage uri or a Google Drive url")
423+
424+
def test_set_embedding_model_config_set_both_error(self):
    """Setting both publisher_model and endpoint must raise ValueError."""
    conflicting_config = rag.EmbeddingModelConfig(
        publisher_model="whatever",
        endpoint="whatever",
    )
    # The expected message is checked by pytest.raises itself via `match`.
    with pytest.raises(
        ValueError,
        match="publisher_model and endpoint cannot be set at the same time",
    ):
        set_embedding_model_config(conflicting_config, tc.TEST_GAPIC_RAG_CORPUS)
435+
436+
def test_set_embedding_model_config_not_set_error(self):
    """An EmbeddingModelConfig with neither field set must raise ValueError."""
    empty_config = rag.EmbeddingModelConfig()
    # The expected message is checked by pytest.raises itself via `match`.
    with pytest.raises(
        ValueError,
        match="At least one of publisher_model and endpoint must be set",
    ):
        set_embedding_model_config(empty_config, tc.TEST_GAPIC_RAG_CORPUS)
444+
445+
def test_set_embedding_model_config_wrong_publisher_model_format_error(self):
    """A publisher_model in neither accepted format must raise ValueError."""
    malformed_config = rag.EmbeddingModelConfig(publisher_model="whatever")
    # The expected message is checked by pytest.raises itself via `match`.
    with pytest.raises(
        ValueError,
        match="publisher_model must be of the format ",
    ):
        set_embedding_model_config(malformed_config, tc.TEST_GAPIC_RAG_CORPUS)
453+
454+
def test_set_embedding_model_config_wrong_endpoint_format_error(self):
    """An endpoint in neither accepted format must raise ValueError."""
    malformed_config = rag.EmbeddingModelConfig(endpoint="whatever")
    # The expected message is checked by pytest.raises itself via `match`.
    with pytest.raises(
        ValueError,
        match="endpoint must be of the format ",
    ):
        set_embedding_model_config(malformed_config, tc.TEST_GAPIC_RAG_CORPUS)

vertexai/preview/rag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
VertexRagStore,
3838
)
3939
from vertexai.preview.rag.utils.resources import (
40+
EmbeddingModelConfig,
4041
RagResource,
4142
)
4243

@@ -53,6 +54,7 @@
5354
"list_files",
5455
"delete_file",
5556
"retrieval_query",
57+
"EmbeddingModelConfig",
5658
"Retrieval",
5759
"VertexRagStore",
5860
"RagResource",

vertexai/preview/rag/rag_data.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@
4343
_gapic_utils,
4444
)
4545
from vertexai.preview.rag.utils.resources import (
46+
EmbeddingModelConfig,
4647
RagCorpus,
4748
RagFile,
4849
)
4950

5051

5152
def create_corpus(
52-
display_name: Optional[str] = None, description: Optional[str] = None
53+
display_name: Optional[str] = None,
54+
description: Optional[str] = None,
55+
embedding_model_config: Optional[EmbeddingModelConfig] = None,
5356
) -> RagCorpus:
5457
"""Creates a new RagCorpus resource.
5558
@@ -69,6 +72,7 @@ def create_corpus(
6972
the RagCorpus. The name can be up to 128 characters long and can
7073
consist of any UTF-8 characters.
7174
description: The description of the RagCorpus.
75+
embedding_model_config: The embedding model config.
7276
Returns:
7377
RagCorpus.
7478
Raises:
@@ -80,6 +84,12 @@ def create_corpus(
8084
parent = initializer.global_config.common_location_path(project=None, location=None)
8185

8286
rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
87+
if embedding_model_config:
88+
rag_corpus = _gapic_utils.set_embedding_model_config(
89+
embedding_model_config,
90+
rag_corpus,
91+
)
92+
8393
request = CreateRagCorpusRequest(
8494
parent=parent,
8595
rag_corpus=rag_corpus,
@@ -264,6 +274,7 @@ def import_files(
264274
chunk_size: int = 1024,
265275
chunk_overlap: int = 200,
266276
timeout: int = 600,
277+
max_embedding_requests_per_min: int = 1000,
267278
) -> ImportRagFilesResponse:
268279
"""
269280
Import files to an existing RagCorpus, wait until completion.
@@ -299,6 +310,15 @@ def import_files(
299310
"https://drive.google.com/corp/drive/folders/...").
300311
chunk_size: The size of the chunks.
301312
chunk_overlap: The overlap between chunks.
313+
max_embedding_requests_per_min:
314+
Optional. The max number of queries per
315+
minute that this job is allowed to make to the
316+
embedding model specified on the corpus. This
317+
value is specific to this job and not shared
318+
across other import jobs. Consult the Quotas
319+
page on the project to set an appropriate value
320+
here. If unspecified, a default value of 1,000
321+
QPM would be used.
302322
timeout: Default is 600 seconds.
303323
Returns:
304324
ImportRagFilesResponse.
@@ -309,6 +329,7 @@ def import_files(
309329
paths=paths,
310330
chunk_size=chunk_size,
311331
chunk_overlap=chunk_overlap,
332+
max_embedding_requests_per_min=max_embedding_requests_per_min,
312333
)
313334
client = _gapic_utils.create_rag_data_service_client()
314335
try:
@@ -324,6 +345,7 @@ async def import_files_async(
324345
paths: Sequence[str],
325346
chunk_size: int = 1024,
326347
chunk_overlap: int = 200,
348+
max_embedding_requests_per_min: int = 1000,
327349
) -> operation_async.AsyncOperation:
328350
"""
329351
Import files to an existing RagCorpus asynchronously.
@@ -361,6 +383,15 @@ async def import_files_async(
361383
"https://drive.google.com/corp/drive/folders/...").
362384
chunk_size: The size of the chunks.
363385
chunk_overlap: The overlap between chunks.
386+
max_embedding_requests_per_min:
387+
Optional. The max number of queries per
388+
minute that this job is allowed to make to the
389+
embedding model specified on the corpus. This
390+
value is specific to this job and not shared
391+
across other import jobs. Consult the Quotas
392+
page on the project to set an appropriate value
393+
here. If unspecified, a default value of 1,000
394+
QPM would be used.
364395
Returns:
365396
operation_async.AsyncOperation.
366397
"""
@@ -370,6 +401,7 @@ async def import_files_async(
370401
paths=paths,
371402
chunk_size=chunk_size,
372403
chunk_overlap=chunk_overlap,
404+
max_embedding_requests_per_min=max_embedding_requests_per_min,
373405
)
374406
async_client = _gapic_utils.create_rag_data_service_async_client()
375407
try:

vertexai/preview/rag/utils/_gapic_utils.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import re
1818
from typing import Any, Dict, Sequence, Union
1919
from google.cloud.aiplatform_v1beta1 import (
20+
RagEmbeddingModelConfig,
2021
GoogleDriveSource,
2122
ImportRagFilesConfig,
2223
ImportRagFilesRequest,
@@ -31,6 +32,7 @@
3132
VertexRagClientWithOverride,
3233
)
3334
from vertexai.preview.rag.utils.resources import (
35+
EmbeddingModelConfig,
3436
RagCorpus,
3537
RagFile,
3638
)
@@ -57,12 +59,43 @@ def create_rag_service_client():
5759
)
5860

5961

62+
def convert_gapic_to_embedding_model_config(
    gapic_embedding_model_config: RagEmbeddingModelConfig,
) -> EmbeddingModelConfig:
    """Build an EmbeddingModelConfig from a GAPIC RagEmbeddingModelConfig.

    The ``vertex_prediction_endpoint.endpoint`` path is classified by its
    shape:

    * a ``projects/.../locations/.../publishers/google/models/...`` path is
      copied to ``publisher_model``;
    * a ``projects/.../locations/.../endpoints/...`` path is copied to
      ``endpoint``, together with the ``model`` and ``model_version_id``
      fields of the prediction endpoint.

    A path matching neither shape leaves the result at its defaults.

    Args:
        gapic_embedding_model_config: The GAPIC message to convert.

    Returns:
        The populated EmbeddingModelConfig.
    """
    prediction_endpoint = gapic_embedding_model_config.vertex_prediction_endpoint
    resource_path = prediction_endpoint.endpoint
    converted = EmbeddingModelConfig()

    is_publisher_model = (
        re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
            resource_path,
        )
        is not None
    )
    is_endpoint = (
        re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
            resource_path,
        )
        is not None
    )

    # The two checks are deliberately independent (not if/elif): the lazy
    # `.+?` groups can span slashes, so one path may satisfy both patterns.
    if is_publisher_model:
        converted.publisher_model = resource_path
    if is_endpoint:
        converted.endpoint = resource_path
        converted.model = prediction_endpoint.model
        converted.model_version_id = prediction_endpoint.model_version_id

    return converted
88+
89+
6090
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Convert GapicRagCorpus to RagCorpus.

    Args:
        gapic_rag_corpus: The GAPIC corpus message to convert.

    Returns:
        A RagCorpus carrying the name, display name, description and the
        converted embedding model config of ``gapic_rag_corpus``.
    """
    converted_embedding_config = convert_gapic_to_embedding_model_config(
        gapic_rag_corpus.rag_embedding_model_config
    )
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        embedding_model_config=converted_embedding_config,
    )
68101

@@ -124,6 +157,7 @@ def prepare_import_files_request(
124157
paths: Sequence[str],
125158
chunk_size: int = 1024,
126159
chunk_overlap: int = 200,
160+
max_embedding_requests_per_min: int = 1000,
127161
) -> ImportRagFilesRequest:
128162
if len(corpus_name.split("/")) != 6:
129163
raise ValueError(
@@ -135,7 +169,8 @@ def prepare_import_files_request(
135169
chunk_overlap=chunk_overlap,
136170
)
137171
import_rag_files_config = ImportRagFilesConfig(
138-
rag_file_chunking_config=rag_file_chunking_config
172+
rag_file_chunking_config=rag_file_chunking_config,
173+
max_embedding_requests_per_min=max_embedding_requests_per_min,
139174
)
140175

141176
uris = []
@@ -204,3 +239,65 @@ def get_file_name(
204239
raise ValueError(
205240
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
206241
)
242+
243+
244+
def _qualify_location_resource(
    resource: str,
    relative_pattern: str,
    format_error: str,
) -> str:
    """Return ``resource`` as a ``projects/.../locations/...`` resource name.

    Args:
        resource: Resource name, either already fully qualified or relative
            to a location.
        relative_pattern: Regex body (without anchors) describing the
            location-relative form of the resource.
        format_error: Message of the ValueError raised when ``resource``
            matches neither form.

    Returns:
        ``resource`` unchanged when it is already fully qualified, otherwise
        ``resource`` prefixed with the parent path returned by
        ``initializer.global_config.common_location_path``.

    Raises:
        ValueError: If ``resource`` matches neither accepted form.
    """
    if re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/"
        + relative_pattern
        + "$",
        resource,
    ):
        return resource
    if re.match("^" + relative_pattern + "$", resource):
        # Resolve the parent only here: fully qualified names and malformed
        # inputs must not depend on the global project/location lookup.
        parent = initializer.global_config.common_location_path(
            project=None, location=None
        )
        return parent + "/" + resource
    raise ValueError(format_error)


def set_embedding_model_config(
    embedding_model_config: EmbeddingModelConfig,
    rag_corpus: GapicRagCorpus,
) -> GapicRagCorpus:
    """Write the embedding model selection onto a GAPIC RagCorpus.

    Exactly one of ``publisher_model`` and ``endpoint`` must be set on
    ``embedding_model_config``. The chosen value is validated, expanded to a
    full ``projects/.../locations/...`` resource name when given in its
    location-relative form, and stored in
    ``rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint``.

    Args:
        embedding_model_config: The user-facing embedding model config.
        rag_corpus: The GAPIC corpus to update; it is modified in place.

    Returns:
        The same ``rag_corpus`` object, for call-chaining convenience.

    Raises:
        ValueError: If both or neither of ``publisher_model`` and
            ``endpoint`` are set, or if the provided value is not in one of
            the accepted formats.
    """
    publisher_model = embedding_model_config.publisher_model
    endpoint = embedding_model_config.endpoint

    if publisher_model and endpoint:
        raise ValueError("publisher_model and endpoint cannot be set at the same time.")
    if not publisher_model and not endpoint:
        raise ValueError("At least one of publisher_model and endpoint must be set.")

    if publisher_model:
        resolved_name = _qualify_location_resource(
            publisher_model,
            r"publishers/google/models/(?P<model_id>.+?)",
            "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`",
        )
    else:
        resolved_name = _qualify_location_resource(
            endpoint,
            r"endpoints/(?P<endpoint>.+?)",
            "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`",
        )

    rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
        resolved_name
    )
    return rag_corpus

0 commit comments

Comments
 (0)