Skip to content

Commit cfc3421

Browse files
speedstorm1copybara-github
authored andcommitted
feat: Adding Feature Store Vector DB option for RAG corpuses to SDK
PiperOrigin-RevId: 673571692
1 parent 73490b2 commit cfc3421

File tree

6 files changed

+83
-7
lines changed

6 files changed

+83
-7
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+21
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
JiraSource,
2929
JiraQuery,
3030
Weaviate,
31+
VertexFeatureStore,
3132
)
3233
from google.cloud.aiplatform_v1beta1 import (
3334
GoogleDriveSource,
@@ -68,6 +69,7 @@
6869
collection_name=TEST_WEAVIATE_COLLECTION_NAME,
6970
api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION,
7071
)
72+
TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME = "test-feature-view-resource-name"
7173
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
7274
name=TEST_RAG_CORPUS_RESOURCE_NAME,
7375
display_name=TEST_CORPUS_DISPLAY_NAME,
@@ -94,9 +96,22 @@
9496
),
9597
),
9698
)
99+
TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE = GapicRagCorpus(
100+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
101+
display_name=TEST_CORPUS_DISPLAY_NAME,
102+
description=TEST_CORPUS_DISCRIPTION,
103+
rag_vector_db_config=RagVectorDbConfig(
104+
vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
105+
feature_view_resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME
106+
),
107+
),
108+
)
97109
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
98110
publisher_model="publishers/google/models/textembedding-gecko",
99111
)
112+
TEST_VERTEX_FEATURE_STORE_CONFIG = VertexFeatureStore(
113+
resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME,
114+
)
100115
TEST_RAG_CORPUS = RagCorpus(
101116
name=TEST_RAG_CORPUS_RESOURCE_NAME,
102117
display_name=TEST_CORPUS_DISPLAY_NAME,
@@ -109,6 +124,12 @@
109124
description=TEST_CORPUS_DISCRIPTION,
110125
vector_db=TEST_WEAVIATE_CONFIG,
111126
)
127+
TEST_RAG_CORPUS_VERTEX_FEATURE_STORE = RagCorpus(
128+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
129+
display_name=TEST_CORPUS_DISPLAY_NAME,
130+
description=TEST_CORPUS_DISCRIPTION,
131+
vector_db=TEST_VERTEX_FEATURE_STORE_CONFIG,
132+
)
112133
TEST_PAGE_TOKEN = "test-page-token"
113134

114135
# RagFiles

tests/unit/vertex_rag/test_rag_data.py

+26
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,23 @@ def create_rag_corpus_mock_weaviate():
6262
yield create_rag_corpus_mock_weaviate
6363

6464

65+
@pytest.fixture
66+
def create_rag_corpus_mock_vertex_feature_store():
67+
with mock.patch.object(
68+
VertexRagDataServiceClient,
69+
"create_rag_corpus",
70+
) as create_rag_corpus_mock_vertex_feature_store:
71+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
72+
create_rag_corpus_lro_mock.done.return_value = True
73+
create_rag_corpus_lro_mock.result.return_value = (
74+
tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE
75+
)
76+
create_rag_corpus_mock_vertex_feature_store.return_value = (
77+
create_rag_corpus_lro_mock
78+
)
79+
yield create_rag_corpus_mock_vertex_feature_store
80+
81+
6582
@pytest.fixture
6683
def list_rag_corpora_pager_mock():
6784
with mock.patch.object(
@@ -216,6 +233,15 @@ def test_create_corpus_weaviate_success(self):
216233

217234
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)
218235

236+
@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_feature_store")
237+
def test_create_corpus_vertex_feature_store_success(self):
238+
rag_corpus = rag.create_corpus(
239+
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
240+
vector_db=tc.TEST_VERTEX_FEATURE_STORE_CONFIG,
241+
)
242+
243+
rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE)
244+
219245
@pytest.mark.usefixtures("rag_data_client_mock_exception")
220246
def test_create_corpus_failure(self):
221247
with pytest.raises(RuntimeError) as e:

vertexai/preview/rag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
RagResource,
4646
SlackChannel,
4747
SlackChannelsSource,
48+
VertexFeatureStore,
4849
Weaviate,
4950
)
5051

@@ -59,6 +60,7 @@
5960
"Retrieval",
6061
"SlackChannel",
6162
"SlackChannelsSource",
63+
"VertexFeatureStore",
6264
"VertexRagStore",
6365
"Weaviate",
6466
"create_corpus",

vertexai/preview/rag/rag_data.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
RagCorpus,
4949
RagFile,
5050
SlackChannelsSource,
51+
VertexFeatureStore,
5152
Weaviate,
5253
)
5354

@@ -56,7 +57,7 @@ def create_corpus(
5657
display_name: Optional[str] = None,
5758
description: Optional[str] = None,
5859
embedding_model_config: Optional[EmbeddingModelConfig] = None,
59-
vector_db: Optional[Weaviate] = None,
60+
vector_db: Optional[Union[Weaviate, VertexFeatureStore]] = None,
6061
) -> RagCorpus:
6162
"""Creates a new RagCorpus resource.
6263

vertexai/preview/rag/utils/_gapic_utils.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
RagFile,
4343
SlackChannelsSource,
4444
JiraSource,
45+
VertexFeatureStore,
4546
Weaviate,
4647
)
4748

@@ -97,14 +98,18 @@ def convert_gapic_to_embedding_model_config(
9798

9899
def convert_gapic_to_vector_db(
99100
gapic_vector_db: RagVectorDbConfig,
100-
) -> Weaviate:
101-
"""Convert Gapic RagVectorDbConfig to Weaviate."""
101+
) -> Union[Weaviate, VertexFeatureStore]:
102+
"""Convert Gapic RagVectorDbConfig to Weaviate or VertexFeatureStore."""
102103
if gapic_vector_db.__contains__("weaviate"):
103104
return Weaviate(
104105
weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
105106
collection_name=gapic_vector_db.weaviate.collection_name,
106107
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
107108
)
109+
elif gapic_vector_db.__contains__("vertex_feature_store"):
110+
return VertexFeatureStore(
111+
resource_name=gapic_vector_db.vertex_feature_store.feature_view_resource_name,
112+
)
108113
else:
109114
return None
110115

@@ -390,7 +395,7 @@ def set_embedding_model_config(
390395

391396

392397
def set_vector_db(
393-
vector_db: Weaviate,
398+
vector_db: Union[Weaviate, VertexFeatureStore],
394399
rag_corpus: GapicRagCorpus,
395400
) -> None:
396401
"""Sets the vector db configuration for the rag corpus."""
@@ -410,5 +415,13 @@ def set_vector_db(
410415
),
411416
),
412417
)
418+
elif isinstance(vector_db, VertexFeatureStore):
419+
resource_name = vector_db.resource_name
420+
421+
rag_corpus.rag_vector_db_config = RagVectorDbConfig(
422+
vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
423+
feature_view_resource_name=resource_name,
424+
),
425+
)
413426
else:
414-
raise TypeError("vector_db must be a Weaviate.")
427+
raise TypeError("vector_db must be a Weaviate or VertexFeatureStore.")

vertexai/preview/rag/utils/resources.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
import dataclasses
19-
from typing import List, Optional, Sequence
19+
from typing import List, Optional, Sequence, Union
2020

2121
from google.protobuf import timestamp_pb2
2222

@@ -85,6 +85,19 @@ class Weaviate:
8585
api_key: str
8686

8787

88+
@dataclasses.dataclass
89+
class VertexFeatureStore:
90+
"""VertexFeatureStore.
91+
92+
Attributes:
93+
resource_name: The resource name of the FeatureView. Format:
94+
``projects/{project}/locations/{location}/featureOnlineStores/
95+
{feature_online_store}/featureViews/{feature_view}``
96+
"""
97+
98+
resource_name: str
99+
100+
88101
@dataclasses.dataclass
89102
class RagCorpus:
90103
"""RAG corpus(output only).
@@ -102,7 +115,7 @@ class RagCorpus:
102115
display_name: Optional[str] = None
103116
description: Optional[str] = None
104117
embedding_model_config: Optional[EmbeddingModelConfig] = None
105-
vector_db: Optional[Weaviate] = None
118+
vector_db: Optional[Union[Weaviate, VertexFeatureStore]] = None
106119

107120

108121
@dataclasses.dataclass

0 commit comments

Comments
 (0)