Skip to content

Commit ac1a841

Browse files
authored
Merge pull request #660 from edlouth/pgvector_fixes
Pgvector fixes
2 parents ed26e2a + 4915144 commit ac1a841

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

src/vanna/pgvector/pgvector.py

+10-22
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ def __init__(self, config=None):
2828
if config and "embedding_function" in config:
2929
self.embedding_function = config.get("embedding_function")
3030
else:
31-
from sentence_transformers import SentenceTransformer
32-
self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2")
31+
from langchain_huggingface import HuggingFaceEmbeddings
32+
self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
3333

34-
self.sql_vectorstore = PGVector(
34+
self.sql_collection = PGVector(
3535
embeddings=self.embedding_function,
3636
collection_name="sql",
3737
connection=self.connection_string,
3838
)
39-
self.ddl_vectorstore = PGVector(
39+
self.ddl_collection = PGVector(
4040
embeddings=self.embedding_function,
4141
collection_name="ddl",
4242
connection=self.connection_string,
4343
)
44-
self.documentation_vectorstore = PGVector(
44+
self.documentation_collection = PGVector(
4545
embeddings=self.embedding_function,
4646
collection_name="documentation",
4747
connection=self.connection_string,
@@ -94,16 +94,16 @@ def get_collection(self, collection_name):
9494
case _:
9595
raise ValueError("Specified collection does not exist.")
9696

97-
async def get_similar_question_sql(self, question: str) -> list:
97+
def get_similar_question_sql(self, question: str) -> list:
9898
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
9999
return [ast.literal_eval(document.page_content) for document in documents]
100100

101-
async def get_related_ddl(self, question: str, **kwargs) -> list:
102-
documents = await self.ddl_collection.similarity_search(query=question, k=self.n_results)
101+
def get_related_ddl(self, question: str, **kwargs) -> list:
102+
documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
103103
return [document.page_content for document in documents]
104104

105-
async def get_related_documentation(self, question: str, **kwargs) -> list:
106-
documents = await self.documentation_collection.similarity_search(query=question, k=self.n_results)
105+
def get_related_documentation(self, question: str, **kwargs) -> list:
106+
documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
107107
return [document.page_content for document in documents]
108108

109109
def train(
@@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool:
251251

252252
def generate_embedding(self, *args, **kwargs):
253253
pass
254-
255-
def submit_prompt(self, *args, **kwargs):
256-
pass
257-
258-
def system_message(self, message: str) -> any:
259-
return {"role": "system", "content": message}
260-
261-
def user_message(self, message: str) -> any:
262-
return {"role": "user", "content": message}
263-
264-
def assistant_message(self, message: str) -> any:
265-
return {"role": "assistant", "content": message}

tests/test_pgvector.py

+32-7
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,47 @@
33
from dotenv import load_dotenv
44

55
# from vanna.pgvector import PG_VectorStore
6+
# from vanna.openai import OpenAI_Chat
67

8+
# assumes a .env file, placed next to this file, that provides the required env vars
79
load_dotenv()
810

9-
# Removing these tests for now until the dependencies are sorted out
1011
# def get_vanna_connection_string():
1112
# server = os.environ.get("PG_SERVER")
1213
# driver = "psycopg"
13-
# port = 5434
14+
# port = os.environ.get("PG_PORT", 5432)
1415
# database = os.environ.get("PG_DATABASE")
1516
# username = os.environ.get("PG_USERNAME")
1617
# password = os.environ.get("PG_PASSWORD")
1718

18-
# return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}"
19+
# def test_pgvector_e2e():
20+
# # configure Vanna to use OpenAI and PGVector
21+
# class VannaCustom(PG_VectorStore, OpenAI_Chat):
22+
# def __init__(self, config=None):
23+
# PG_VectorStore.__init__(self, config=config)
24+
# OpenAI_Chat.__init__(self, config=config)
25+
26+
# vn = VannaCustom(config={
27+
# 'api_key': os.environ['OPENAI_API_KEY'],
28+
# 'model': 'gpt-3.5-turbo',
29+
# "connection_string": get_vanna_connection_string(),
30+
# })
1931

32+
# # connect to SQLite database
33+
# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
34+
35+
# # train Vanna on DDLs
36+
# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
37+
# for ddl in df_ddl['sql'].to_list():
38+
# vn.train(ddl=ddl)
39+
# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default
40+
41+
# question = "What are the top 7 customers by sales?"
42+
# sql = vn.generate_sql(question)
43+
# df = vn.run_sql(sql)
44+
# assert len(df) == 7
45+
46+
# # test if Vanna can generate an answer
47+
# answer = vn.ask(question)
48+
# assert answer is not None
2049

21-
# def test_pgvector():
22-
# connection_string = get_vanna_connection_string()
23-
# pgclient = PG_VectorStore(config={"connection_string": connection_string})
24-
# assert pgclient is not None

0 commit comments

Comments
 (0)