
Commit c21d8bf

Merge pull request #647 from andreped/pgvector-support
Added pgvector support
2 parents ba657ef + 32871cb commit c21d8bf

File tree: 5 files changed (+293, -0 lines)
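This merge adds an optional pgvector dependency, a new PG_VectorStore vector store backed by langchain-postgres, and accompanying import and integration tests.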

pyproject.toml (+1 line)

@@ -53,5 +53,6 @@ milvus = ["pymilvus[model]"]
 bedrock = ["boto3", "botocore"]
 weaviate = ["weaviate-client"]
 azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
+pgvector = ["langchain-postgres>=0.0.12"]
 faiss-cpu = ["faiss-cpu"]
 faiss-gpu = ["faiss-gpu"]
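Assuming these entries sit under the project's [project.optional-dependencies] table, the new integration and its langchain-postgres dependency should be installable via the extra, e.g. pip install "vanna[pgvector]".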

src/vanna/pgvector/__init__.py (new file, +1 line)

from .pgvector import PG_VectorStore

src/vanna/pgvector/pgvector.py (new file, +265 lines)

import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine, text

from .. import ValidationError
from ..base import VannaBase
from ..types import TrainingPlan, TrainingPlanItem


class PG_VectorStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        if config and "connection_string" in config:
            self.connection_string = config.get("connection_string")
            self.n_results = config.get("n_results", 10)

        if config and "embedding_function" in config:
            self.embedding_function = config.get("embedding_function")
        else:
            # Fall back to a local sentence-transformers model if no embedding function is supplied
            from sentence_transformers import SentenceTransformer
            self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

        # One pgvector-backed collection per training-data type
        self.sql_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="sql",
            connection=self.connection_string,
        )
        self.ddl_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="ddl",
            connection=self.connection_string,
        )
        self.documentation_collection = PGVector(
            embeddings=self.embedding_function,
            collection_name="documentation",
            connection=self.connection_string,
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        _id = str(uuid.uuid4()) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": _id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return _id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = str(uuid.uuid4()) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
        return [ast.literal_eval(document.page_content) for document in documents]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql, createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establish the connection
        engine = create_engine(self.connection_string)

        # Query the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]
            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]

            if training_data_type == "sql":
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document
            else:
                # If the suffix is not recognized, skip this row
                logging.info(f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content, "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)

        return df_processed

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
                except Exception as e:
                    # Roll back the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info("Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
            return False

        # SQL query to delete rows whose id carries the matching suffix
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Roll back in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass

    def submit_prompt(self, *args, **kwargs):
        pass

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

tests/test_imports.py (+2 lines)

@@ -18,6 +18,7 @@ def test_regular_imports():
     from vanna.openai.openai_chat import OpenAI_Chat
     from vanna.openai.openai_embeddings import OpenAI_Embeddings
     from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore
+    from vanna.pgvector.pgvector import PG_VectorStore
     from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore
     from vanna.qdrant.qdrant import Qdrant_VectorStore
     from vanna.qianfan.Qianfan_Chat import Qianfan_Chat
@@ -43,6 +44,7 @@ def test_shortcut_imports():
     from vanna.ollama import Ollama
     from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
     from vanna.opensearch import OpenSearch_VectorStore
+    from vanna.pgvector import PG_VectorStore
     from vanna.pinecone import PineconeDB_VectorStore
     from vanna.qdrant import Qdrant_VectorStore
     from vanna.qianfan import Qianfan_Chat, Qianfan_Embeddings

tests/test_pgvector.py (new file, +24 lines)

import os

from dotenv import load_dotenv

from vanna.pgvector import PG_VectorStore

load_dotenv()


def get_vanna_connection_string():
    server = os.environ.get("PG_SERVER")
    driver = "psycopg"
    port = 5434
    database = os.environ.get("PG_DATABASE")
    username = os.environ.get("PG_USERNAME")
    password = os.environ.get("PG_PASSWORD")

    return f"postgresql+{driver}://{username}:{password}@{server}:{port}/{database}"


def test_pgvector():
    connection_string = get_vanna_connection_string()
    pgclient = PG_VectorStore(config={"connection_string": connection_string})
    assert pgclient is not None
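
Running this test presumably requires PG_SERVER, PG_DATABASE, PG_USERNAME, and PG_PASSWORD to be set (directly or via a .env file picked up by load_dotenv), plus a Postgres instance with the pgvector extension listening on port 5434; with that in place it can be invoked like any other test, e.g. pytest tests/test_pgvector.py.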
