17
17
import re
18
18
from typing import Any , Dict , Sequence , Union
19
19
from google .cloud .aiplatform_v1beta1 import (
20
+ RagEmbeddingModelConfig ,
20
21
GoogleDriveSource ,
21
22
ImportRagFilesConfig ,
22
23
ImportRagFilesRequest ,
31
32
VertexRagClientWithOverride ,
32
33
)
33
34
from vertexai .preview .rag .utils .resources import (
35
+ EmbeddingModelConfig ,
34
36
RagCorpus ,
35
37
RagFile ,
36
38
)
@@ -57,12 +59,43 @@ def create_rag_service_client():
57
59
)
58
60
59
61
62
def convert_gapic_to_embedding_model_config(
    gapic_embedding_model_config: RagEmbeddingModelConfig,
) -> EmbeddingModelConfig:
    """Convert a GAPIC RagEmbeddingModelConfig to an EmbeddingModelConfig.

    The resource path stored in ``vertex_prediction_endpoint.endpoint`` is
    tested against two fully-qualified patterns: a Google publisher model
    and a project endpoint.

    Args:
        gapic_embedding_model_config: The GAPIC config whose
            ``vertex_prediction_endpoint`` is read.

    Returns:
        An EmbeddingModelConfig. ``publisher_model`` is set to the path when
        it matches the publisher-model pattern. ``endpoint``, ``model`` and
        ``model_version_id`` are set when it matches the endpoint pattern.
        When neither pattern matches, the config is returned as constructed.
    """
    publisher_model_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$"
    endpoint_pattern = r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$"

    prediction_endpoint = gapic_embedding_model_config.vertex_prediction_endpoint
    resource_path = prediction_endpoint.endpoint
    converted = EmbeddingModelConfig()

    # The two checks are deliberately independent (not if/elif): each
    # pattern is evaluated against the path on its own.
    if re.match(publisher_model_pattern, resource_path):
        converted.publisher_model = resource_path
    if re.match(endpoint_pattern, resource_path):
        converted.endpoint = resource_path
        converted.model = prediction_endpoint.model
        converted.model_version_id = prediction_endpoint.model_version_id

    return converted
88
+
89
+
60
90
def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Convert a GapicRagCorpus to a RagCorpus.

    Args:
        gapic_rag_corpus: The GAPIC corpus to convert.

    Returns:
        A RagCorpus carrying the GAPIC corpus's ``name``, ``display_name``
        and ``description``, plus an ``embedding_model_config`` produced by
        ``convert_gapic_to_embedding_model_config`` from the corpus's
        ``rag_embedding_model_config``.
    """
    # Fields copied across unchanged, keyed by their shared attribute name.
    copied_fields = {
        field_name: getattr(gapic_rag_corpus, field_name)
        for field_name in ("name", "display_name", "description")
    }
    converted_embedding_config = convert_gapic_to_embedding_model_config(
        gapic_rag_corpus.rag_embedding_model_config
    )
    return RagCorpus(
        embedding_model_config=converted_embedding_config,
        **copied_fields,
    )
68
101
@@ -124,6 +157,7 @@ def prepare_import_files_request(
124
157
paths : Sequence [str ],
125
158
chunk_size : int = 1024 ,
126
159
chunk_overlap : int = 200 ,
160
+ max_embedding_requests_per_min : int = 1000 ,
127
161
) -> ImportRagFilesRequest :
128
162
if len (corpus_name .split ("/" )) != 6 :
129
163
raise ValueError (
@@ -135,7 +169,8 @@ def prepare_import_files_request(
135
169
chunk_overlap = chunk_overlap ,
136
170
)
137
171
import_rag_files_config = ImportRagFilesConfig (
138
- rag_file_chunking_config = rag_file_chunking_config
172
+ rag_file_chunking_config = rag_file_chunking_config ,
173
+ max_embedding_requests_per_min = max_embedding_requests_per_min ,
139
174
)
140
175
141
176
uris = []
@@ -204,3 +239,65 @@ def get_file_name(
204
239
raise ValueError (
205
240
"name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
206
241
)
242
+
243
+
244
def set_embedding_model_config(
    embedding_model_config: EmbeddingModelConfig,
    rag_corpus: GapicRagCorpus,
) -> GapicRagCorpus:
    """Write an embedding model selection onto a GAPIC RAG corpus.

    Exactly one of ``publisher_model`` / ``endpoint`` must be set on
    ``embedding_model_config``. The chosen value may be a fully-qualified
    resource name, which is stored unchanged, or a short name, which is
    prefixed with the location path obtained from
    ``initializer.global_config.common_location_path`` and a ``/``.

    Args:
        embedding_model_config: Source of ``publisher_model`` or ``endpoint``.
        rag_corpus: The GAPIC corpus to update. It is mutated in place: its
            ``rag_embedding_model_config.vertex_prediction_endpoint.endpoint``
            is assigned.

    Returns:
        The same ``rag_corpus`` object that was passed in.

    Raises:
        ValueError: If both or neither of ``publisher_model`` and
            ``endpoint`` are set, or if the value that is set matches
            neither accepted format.
    """
    wants_publisher_model = bool(embedding_model_config.publisher_model)
    wants_endpoint = bool(embedding_model_config.endpoint)

    if wants_publisher_model and wants_endpoint:
        raise ValueError("publisher_model and endpoint cannot be set at the same time.")
    if not (wants_publisher_model or wants_endpoint):
        raise ValueError("At least one of publisher_model and endpoint must be set.")

    # Computed only after the arguments are validated, before any format check.
    # NOTE(review): presumably `projects/{project}/locations/{location}` for
    # the globally configured project and location -- confirm.
    location_prefix = initializer.global_config.common_location_path(
        project=None, location=None
    )

    if embedding_model_config.publisher_model:
        model_name = embedding_model_config.publisher_model
        if re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
            model_name,
        ):
            # Already fully qualified: store unchanged.
            resolved_model = model_name
        elif re.match(
            r"^publishers/google/models/(?P<model_id>.+?)$",
            model_name,
        ):
            # Short form: qualify with the location prefix.
            resolved_model = location_prefix + "/" + model_name
        else:
            raise ValueError(
                "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
            )
        rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
            resolved_model
        )

    if embedding_model_config.endpoint:
        endpoint_name = embedding_model_config.endpoint
        if re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
            endpoint_name,
        ):
            # Already fully qualified: store unchanged.
            resolved_endpoint = endpoint_name
        elif re.match(
            r"^endpoints/(?P<endpoint>.+?)$",
            endpoint_name,
        ):
            # Short form: qualify with the location prefix.
            resolved_endpoint = location_prefix + "/" + endpoint_name
        else:
            raise ValueError(
                "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
            )
        rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
            resolved_endpoint
        )

    return rag_corpus
0 commit comments