-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy path store_embeddings.py
261 lines (221 loc) · 13.5 KB
/
store_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import asyncio
import asyncpg
import pandas as pd
import numpy as np
from pgvector.asyncpg import register_vector
from google.cloud.sql.connector import Connector
from langchain_community.embeddings import VertexAIEmbeddings
from google.cloud import bigquery
from dbconnectors import pgconnector
from agents import EmbedderAgent
from sqlalchemy.sql import text
from utilities import VECTOR_STORE, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION, BQ_OPENDATAQNA_DATASET_NAME, BQ_REGION, EMBEDDING_MODEL
# Module-level embedding client, constructed once at import time from
# EMBEDDING_MODEL; add_sql_embedding calls embedder.create(...) on it.
embedder = EmbedderAgent(EMBEDDING_MODEL)
def _bq_string_literal(value):
    """Render ``value`` as a single-quoted GoogleSQL string literal.

    Backslashes, single quotes and line breaks are escaped, so a value that
    contains them cannot terminate the literal early and alter the statement.
    """
    escaped = (
        str(value)
        .replace("\\", "\\\\")  # must run first so later escapes are not doubled
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    return f"'{escaped}'"


def _bq_delete_existing(client, table_id, frame, key_columns):
    """Delete rows of BigQuery table ``table_id`` whose ``key_columns`` match a row of ``frame``.

    Does nothing when ``frame`` has no rows; an empty OR-list would otherwise
    produce a bare ``WHERE`` and a syntax error.
    """
    if frame.empty:
        return
    conditions = [
        "(" + " AND ".join(
            f"{column} = {_bq_string_literal(value)}"
            for column, value in zip(key_columns, key)
        ) + ")"
        for key in frame[key_columns].itertuples(index=False, name=None)
    ]
    client.query_and_wait(f"DELETE FROM `{table_id}` WHERE " + " OR ".join(conditions))


async def _store_pgvector(table_details_embeddings, tablecolumn_details_embeddings,
                          project_id, instance_name, database_name,
                          database_user, database_password, region):
    """Upsert table/column embeddings into Cloud SQL (pgvector) and ensure the example table exists."""
    loop = asyncio.get_running_loop()
    async with Connector(loop=loop) as connector:
        # Create connection to Cloud SQL database.
        conn: asyncpg.Connection = await connector.connect_async(
            f"{project_id}:{region}:{instance_name}",  # Cloud SQL instance connection name
            "asyncpg",
            user=f"{database_user}",
            password=f"{database_password}",
            db=f"{database_name}",
        )
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # Register the pgvector codec so numpy arrays can be passed for the
            # ``$N::vector`` parameters below.
            await register_vector(conn)

            await conn.execute(
                """CREATE TABLE IF NOT EXISTS table_details_embeddings(
                source_type VARCHAR(100) NOT NULL,
                user_grouping VARCHAR(100) NOT NULL,
                table_schema VARCHAR(1024) NOT NULL,
                table_name VARCHAR(1024) NOT NULL,
                content TEXT,
                embedding vector(768))"""
            )
            # Upsert keyed on (user_grouping, table_name); table_schema is not part of the key.
            for _, row in table_details_embeddings.iterrows():
                await conn.execute(
                    """
                    MERGE INTO table_details_embeddings AS target
                    USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema, $4::text AS table_name, $5::text AS content, $6::vector AS embedding) AS source
                    ON target.user_grouping = source.user_grouping AND target.table_name = source.table_name
                    WHEN MATCHED THEN
                    UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                    WHEN NOT MATCHED THEN
                    INSERT (source_type, user_grouping, table_schema, table_name, content, embedding)
                    VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.content, source.embedding);
                    """,
                    row["source_type"],
                    row["user_grouping"],
                    row["table_schema"],
                    row["table_name"],
                    row["content"],
                    np.array(row["embedding"]),
                )

            await conn.execute(
                """CREATE TABLE IF NOT EXISTS tablecolumn_details_embeddings(
                source_type VARCHAR(100) NOT NULL,
                user_grouping VARCHAR(100) NOT NULL,
                table_schema VARCHAR(1024) NOT NULL,
                table_name VARCHAR(1024) NOT NULL,
                column_name VARCHAR(1024) NOT NULL,
                content TEXT,
                embedding vector(768))"""
            )
            # Upsert keyed on (user_grouping, table_name, column_name).
            for _, row in tablecolumn_details_embeddings.iterrows():
                await conn.execute(
                    """
                    MERGE INTO tablecolumn_details_embeddings AS target
                    USING (SELECT $1::text AS source_type, $2::text AS user_grouping, $3::text AS table_schema,
                    $4::text AS table_name, $5::text AS column_name, $6::text AS content, $7::vector AS embedding) AS source
                    ON target.user_grouping = source.user_grouping
                    AND target.table_name = source.table_name
                    AND target.column_name = source.column_name
                    WHEN MATCHED THEN
                    UPDATE SET source_type = source.source_type, table_schema = source.table_schema, content = source.content, embedding = source.embedding
                    WHEN NOT MATCHED THEN
                    INSERT (source_type, user_grouping, table_schema, table_name, column_name, content, embedding)
                    VALUES (source.source_type, source.user_grouping, source.table_schema, source.table_name, source.column_name, source.content, source.embedding);
                    """,
                    row["source_type"],
                    row["user_grouping"],
                    row["table_schema"],
                    row["table_name"],
                    row["column_name"],
                    row["content"],
                    np.array(row["embedding"]),
                )

            # Table later filled by add_sql_embedding-style (question, SQL) examples.
            await conn.execute(
                """CREATE TABLE IF NOT EXISTS example_prompt_sql_embeddings(
                user_grouping VARCHAR(1024) NOT NULL,
                example_user_question text NOT NULL,
                example_generated_sql text NOT NULL,
                embedding vector(768))"""
            )
        finally:
            # Always release the connection, even if a statement above failed.
            await conn.close()


def _store_bigquery(table_details_embeddings, tablecolumn_details_embeddings, project_id, schema):
    """Replace matching table/column embeddings in BigQuery dataset ``schema`` and ensure the example table exists."""
    client = bigquery.Client(project=project_id)

    # Store table embeddings: drop rows with the same keys, then append the new ones.
    table_id = f"{project_id}.{schema}.table_details_embeddings"
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{table_id}` (
        source_type string NOT NULL, user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, content string, embedding ARRAY<FLOAT64>)''')
    _bq_delete_existing(client, table_id, table_details_embeddings,
                        ["user_grouping", "table_name"])
    # load_table_from_dataframe only starts a job; wait so failures surface here.
    client.load_table_from_dataframe(table_details_embeddings, table_id).result()

    # Store column embeddings the same way.
    column_table_id = f"{project_id}.{schema}.tablecolumn_details_embeddings"
    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{column_table_id}` (
        source_type string NOT NULL,user_grouping string NOT NULL, table_schema string NOT NULL, table_name string NOT NULL, column_name string NOT NULL,
        content string, embedding ARRAY<FLOAT64>)''')
    _bq_delete_existing(client, column_table_id, tablecolumn_details_embeddings,
                        ["user_grouping", "table_name", "column_name"])
    client.load_table_from_dataframe(tablecolumn_details_embeddings, column_table_id).result()

    client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{project_id}.{schema}.example_prompt_sql_embeddings` (
        user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
        embedding ARRAY<FLOAT64>)''')


async def store_schema_embeddings(table_details_embeddings,
                                  tablecolumn_details_embeddings,
                                  project_id,
                                  instance_name,
                                  database_name,
                                  schema,
                                  database_user,
                                  database_password,
                                  region,
                                  VECTOR_STORE):
    """
    Store the vectorised table and column details in the configured vector store.

    This code may run for a few minutes.

    For ``"cloudsql-pgvector"`` each row is upserted (MERGE) into
    ``table_details_embeddings`` / ``tablecolumn_details_embeddings`` on the
    Cloud SQL instance. For ``"bigquery-vector"`` rows with matching keys are
    deleted and the DataFrames are then loaded into the same-named tables of
    dataset ``schema``. Both backends also create
    ``example_prompt_sql_embeddings`` if it does not exist.

    Args:
        table_details_embeddings: DataFrame with columns source_type,
            user_grouping, table_schema, table_name, content, embedding.
        tablecolumn_details_embeddings: Same as above plus column_name.
        project_id: GCP project id.
        instance_name, database_name, database_user, database_password,
            region: Cloud SQL connection settings (pgvector only).
        schema: BigQuery dataset name (BigQuery only; unused for pgvector).
        VECTOR_STORE: ``"cloudsql-pgvector"`` or ``"bigquery-vector"``.
            Shadows the module-level setting of the same name.

    Returns:
        The string ``"Embeddings are stored successfully"``.

    Raises:
        ValueError: If ``VECTOR_STORE`` is not a supported value.
    """
    if VECTOR_STORE == "cloudsql-pgvector":
        await _store_pgvector(table_details_embeddings, tablecolumn_details_embeddings,
                              project_id, instance_name, database_name,
                              database_user, database_password, region)
    elif VECTOR_STORE == "bigquery-vector":
        _store_bigquery(table_details_embeddings, tablecolumn_details_embeddings,
                        project_id, schema)
    else:
        raise ValueError("Please provide a valid Vector Store.")
    return "Embeddings are stored successfully"
async def add_sql_embedding(user_question, generated_sql, database):
    """Embed a (question, SQL) pair and store it in ``example_prompt_sql_embeddings``.

    The question is embedded with the module-level ``embedder``; the backend
    is selected by the module-level ``VECTOR_STORE`` setting. Any existing row
    with the same ``user_grouping`` and question is deleted first, so the pair
    is effectively replaced.

    Args:
        user_question: Natural-language question; also the example's key.
        generated_sql: SQL for the question. CR/LF characters are replaced by
            spaces before it is stored.
        database: Value stored in the ``user_grouping`` column.

    Returns:
        1, for every ``VECTOR_STORE`` value. For an unrecognised value nothing
        is written.
    """
    emb = embedder.create(user_question)
    if VECTOR_STORE == "cloudsql-pgvector":
        loop = asyncio.get_running_loop()
        async with Connector(loop=loop) as connector:
            # Create connection to Cloud SQL database.
            conn: asyncpg.Connection = await connector.connect_async(
                f"{PROJECT_ID}:{PG_REGION}:{PG_INSTANCE}",  # Cloud SQL instance connection name
                "asyncpg",
                user=f"{PG_USER}",
                password=f"{PG_PASSWORD}",
                db=f"{PG_DATABASE}",
            )
            try:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                await register_vector(conn)
                cleaned_sql = generated_sql.replace("\r", " ").replace("\n", " ")
                # Delete + insert atomically, so a failed insert cannot leave the
                # example removed.
                async with conn.transaction():
                    await conn.execute(
                        "DELETE FROM example_prompt_sql_embeddings WHERE user_grouping= $1 and example_user_question=$2",
                        database,
                        user_question,
                    )
                    await conn.execute(
                        "INSERT INTO example_prompt_sql_embeddings (user_grouping, example_user_question, example_generated_sql, embedding) VALUES ($1, $2, $3, $4)",
                        database,
                        user_question,
                        cleaned_sql,
                        np.array(emb),
                    )
            finally:
                # The original never closed this connection.
                await conn.close()
    elif VECTOR_STORE == "bigquery-vector":
        client = bigquery.Client(project=PROJECT_ID)
        table_id = f"{PROJECT_ID}.{BQ_OPENDATAQNA_DATASET_NAME}.example_prompt_sql_embeddings"
        client.query_and_wait(f'''CREATE TABLE IF NOT EXISTS `{table_id}` (
            user_grouping string NOT NULL, example_user_question string NOT NULL, example_generated_sql string NOT NULL,
            embedding ARRAY<FLOAT64>)''')
        cleaned_sql = generated_sql.replace("\r", " ").replace("\n", " ")
        # Bind values as query parameters: questions and SQL routinely contain
        # quotes and backslashes that would break or inject into a literal.
        key_params = [
            bigquery.ScalarQueryParameter("user_grouping", "STRING", str(database)),
            bigquery.ScalarQueryParameter("user_question", "STRING", str(user_question)),
        ]
        client.query_and_wait(
            f"DELETE FROM `{table_id}` "
            "WHERE user_grouping = @user_grouping AND example_user_question = @user_question",
            job_config=bigquery.QueryJobConfig(query_parameters=key_params),
        )
        client.query_and_wait(
            f"INSERT INTO `{table_id}` "
            "(user_grouping, example_user_question, example_generated_sql, embedding) "
            "VALUES (@user_grouping, @user_question, @generated_sql, @embedding)",
            job_config=bigquery.QueryJobConfig(query_parameters=key_params + [
                bigquery.ScalarQueryParameter("generated_sql", "STRING", cleaned_sql),
                # float() also normalises numpy scalar types for the client.
                bigquery.ArrayQueryParameter("embedding", "FLOAT64", [float(v) for v in emb]),
            ]),
        )
    return 1
if __name__ == '__main__':
    # Script entry point: fetch the table/column embeddings for PG_SCHEMA via
    # retrieve_embeddings and persist them to the Cloud SQL pgvector store.
    from retrieve_embeddings import retrieve_embeddings
    from utilities import PG_SCHEMA, PROJECT_ID, PG_INSTANCE, PG_DATABASE, PG_USER, PG_PASSWORD, PG_REGION

    VECTOR_STORE = "cloudsql-pgvector"
    table_embeddings, column_embeddings = retrieve_embeddings(VECTOR_STORE, PG_SCHEMA)
    asyncio.run(
        store_schema_embeddings(
            table_details_embeddings=table_embeddings,
            tablecolumn_details_embeddings=column_embeddings,
            project_id=PROJECT_ID,
            instance_name=PG_INSTANCE,
            database_name=PG_DATABASE,
            schema=PG_SCHEMA,
            database_user=PG_USER,
            database_password=PG_PASSWORD,
            region=PG_REGION,
            VECTOR_STORE=VECTOR_STORE,
        )
    )