|
15 | 15 | # limitations under the License.
|
16 | 16 | #
|
17 | 17 |
|
18 |
| -from vertexai.preview.rag.utils.resources import ( |
19 |
| - EmbeddingModelConfig, |
20 |
| - RagCorpus, |
21 |
| - RagFile, |
22 |
| - RagResource, |
23 |
| -) |
| 18 | + |
24 | 19 | from google.cloud import aiplatform
|
| 20 | + |
| 21 | +from vertexai.preview import rag |
25 | 22 | from google.cloud.aiplatform_v1beta1 import (
|
26 | 23 | GoogleDriveSource,
|
27 | 24 | RagFileChunkingConfig,
|
28 | 25 | ImportRagFilesConfig,
|
29 | 26 | ImportRagFilesRequest,
|
30 | 27 | ImportRagFilesResponse,
|
| 28 | + JiraSource as GapicJiraSource, |
31 | 29 | RagCorpus as GapicRagCorpus,
|
32 | 30 | RagFile as GapicRagFile,
|
| 31 | + SlackSource as GapicSlackSource, |
33 | 32 | RagContexts,
|
34 | 33 | RetrieveContextsResponse,
|
35 | 34 | )
|
| 35 | +from google.cloud.aiplatform_v1beta1.types import api_auth |
| 36 | +from google.protobuf import timestamp_pb2 |
36 | 37 |
|
37 | 38 |
|
38 | 39 | TEST_PROJECT = "test-project"
|
|
55 | 56 | TEST_PROJECT, TEST_REGION
|
56 | 57 | )
|
57 | 58 | )
|
58 |
| -TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig( |
| 59 | +TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig( |
59 | 60 | publisher_model="publishers/google/models/textembedding-gecko",
|
60 | 61 | )
|
61 |
| -TEST_RAG_CORPUS = RagCorpus( |
| 62 | +TEST_RAG_CORPUS = rag.RagCorpus( |
62 | 63 | name=TEST_RAG_CORPUS_RESOURCE_NAME,
|
63 | 64 | display_name=TEST_CORPUS_DISPLAY_NAME,
|
64 | 65 | description=TEST_CORPUS_DISCRIPTION,
|
|
144 | 145 | display_name=TEST_FILE_DISPLAY_NAME,
|
145 | 146 | description=TEST_FILE_DESCRIPTION,
|
146 | 147 | )
|
147 |
| -TEST_RAG_FILE = RagFile( |
| 148 | +TEST_RAG_FILE = rag.RagFile( |
148 | 149 | name=TEST_RAG_FILE_RESOURCE_NAME,
|
149 | 150 | display_name=TEST_FILE_DISPLAY_NAME,
|
150 | 151 | description=TEST_FILE_DESCRIPTION,
|
151 | 152 | )
|
| 153 | +# Slack sources |
| 154 | +TEST_SLACK_CHANNEL_ID = "123" |
| 155 | +TEST_SLACK_CHANNEL_ID_2 = "456" |
| 156 | +TEST_SLACK_START_TIME = timestamp_pb2.Timestamp() |
| 157 | +TEST_SLACK_START_TIME.GetCurrentTime() |
| 158 | +TEST_SLACK_END_TIME = timestamp_pb2.Timestamp() |
| 159 | +TEST_SLACK_END_TIME.GetCurrentTime() |
| 160 | +TEST_SLACK_API_KEY_SECRET_VERSION = ( |
| 161 | + "projects/test-project/secrets/test-secret/versions/1" |
| 162 | +) |
| 163 | +TEST_SLACK_API_KEY_SECRET_VERSION_2 = ( |
| 164 | + "projects/test-project/secrets/test-secret/versions/2" |
| 165 | +) |
| 166 | +TEST_SLACK_SOURCE = rag.SlackChannelsSource( |
| 167 | + channels=[ |
| 168 | + rag.SlackChannel( |
| 169 | + channel_id=TEST_SLACK_CHANNEL_ID, |
| 170 | + api_key=TEST_SLACK_API_KEY_SECRET_VERSION, |
| 171 | + start_time=TEST_SLACK_START_TIME, |
| 172 | + end_time=TEST_SLACK_END_TIME, |
| 173 | + ), |
| 174 | + rag.SlackChannel( |
| 175 | + channel_id=TEST_SLACK_CHANNEL_ID_2, |
| 176 | + api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2, |
| 177 | + ), |
| 178 | + ], |
| 179 | +) |
| 180 | +TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig( |
| 181 | + rag_file_chunking_config=RagFileChunkingConfig( |
| 182 | + chunk_size=TEST_CHUNK_SIZE, |
| 183 | + chunk_overlap=TEST_CHUNK_OVERLAP, |
| 184 | + ) |
| 185 | +) |
| 186 | +TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [ |
| 187 | + GapicSlackSource.SlackChannels( |
| 188 | + channels=[ |
| 189 | + GapicSlackSource.SlackChannels.SlackChannel( |
| 190 | + channel_id=TEST_SLACK_CHANNEL_ID, |
| 191 | + start_time=TEST_SLACK_START_TIME, |
| 192 | + end_time=TEST_SLACK_END_TIME, |
| 193 | + ), |
| 194 | + ], |
| 195 | + api_key_config=api_auth.ApiAuth.ApiKeyConfig( |
| 196 | + api_key_secret_version=TEST_SLACK_API_KEY_SECRET_VERSION |
| 197 | + ), |
| 198 | + ), |
| 199 | + GapicSlackSource.SlackChannels( |
| 200 | + channels=[ |
| 201 | + GapicSlackSource.SlackChannels.SlackChannel( |
| 202 | + channel_id=TEST_SLACK_CHANNEL_ID_2, |
| 203 | + start_time=None, |
| 204 | + end_time=None, |
| 205 | + ), |
| 206 | + ], |
| 207 | + api_key_config=api_auth.ApiAuth.ApiKeyConfig( |
| 208 | + api_key_secret_version=TEST_SLACK_API_KEY_SECRET_VERSION_2 |
| 209 | + ), |
| 210 | + ), |
| 211 | +] |
| 212 | +TEST_IMPORT_REQUEST_SLACK_SOURCE = ImportRagFilesRequest( |
| 213 | + parent=TEST_RAG_CORPUS_RESOURCE_NAME, |
| 214 | + import_rag_files_config=TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE, |
| 215 | +) |
| 216 | +# Jira sources |
| 217 | +TEST_JIRA_EMAIL = "[email protected]" |
| 218 | +TEST_JIRA_PROJECT = "test-project" |
| 219 | +TEST_JIRA_CUSTOM_QUERY = "test-custom-query" |
| 220 | +TEST_JIRA_SERVER_URI = "test.atlassian.net" |
| 221 | +TEST_JIRA_API_KEY_SECRET_VERSION = ( |
| 222 | + "projects/test-project/secrets/test-secret/versions/1" |
| 223 | +) |
| 224 | +TEST_JIRA_SOURCE = rag.JiraSource( |
| 225 | + queries=[ |
| 226 | + rag.JiraQuery( |
| 227 | + email=TEST_JIRA_EMAIL, |
| 228 | + jira_projects=[TEST_JIRA_PROJECT], |
| 229 | + custom_queries=[TEST_JIRA_CUSTOM_QUERY], |
| 230 | + api_key=TEST_JIRA_API_KEY_SECRET_VERSION, |
| 231 | + server_uri=TEST_JIRA_SERVER_URI, |
| 232 | + ) |
| 233 | + ], |
| 234 | +) |
| 235 | +TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig( |
| 236 | + rag_file_chunking_config=RagFileChunkingConfig( |
| 237 | + chunk_size=TEST_CHUNK_SIZE, |
| 238 | + chunk_overlap=TEST_CHUNK_OVERLAP, |
| 239 | + ) |
| 240 | +) |
| 241 | +TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [ |
| 242 | + GapicJiraSource.JiraQueries( |
| 243 | + custom_queries=[TEST_JIRA_CUSTOM_QUERY], |
| 244 | + projects=[TEST_JIRA_PROJECT], |
| 245 | + email=TEST_JIRA_EMAIL, |
| 246 | + server_uri=TEST_JIRA_SERVER_URI, |
| 247 | + api_key_config=api_auth.ApiAuth.ApiKeyConfig( |
| 248 | + api_key_secret_version=TEST_JIRA_API_KEY_SECRET_VERSION |
| 249 | + ), |
| 250 | + ) |
| 251 | +] |
| 252 | +TEST_IMPORT_REQUEST_JIRA_SOURCE = ImportRagFilesRequest( |
| 253 | + parent=TEST_RAG_CORPUS_RESOURCE_NAME, |
| 254 | + import_rag_files_config=TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE, |
| 255 | +) |
152 | 256 |
|
153 | 257 | # Retrieval
|
154 | 258 | TEST_QUERY_TEXT = "What happen to the fox and the dog?"
|
|
162 | 266 | ]
|
163 | 267 | )
|
164 | 268 | TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)
|
165 |
| -TEST_RAG_RESOURCE = RagResource( |
| 269 | +TEST_RAG_RESOURCE = rag.RagResource( |
166 | 270 | rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME,
|
167 | 271 | rag_file_ids=[TEST_RAG_FILE_ID],
|
168 | 272 | )
|
169 |
| -TEST_RAG_RESOURCE_INVALID_NAME = RagResource( |
| 273 | +TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource( |
170 | 274 | rag_corpus="213lkj-1/23jkl/",
|
171 | 275 | rag_file_ids=[TEST_RAG_FILE_ID],
|
172 | 276 | )
|
0 commit comments