Skip to content

Commit 4915144

Browse files
authored
Merge branch 'main' into pgvector_fixes
2 parents 85586ac + ed26e2a commit 4915144

File tree

5 files changed

+77
-62
lines changed

5 files changed

+77
-62
lines changed

papers/ai-sql-accuracy-2023-08-17.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ By providing just those 3 example queries, we see substantial improvements to th
229229

230230
Enterprise data warehouses often contain hundreds (or even thousands) of tables, and an order of magnitude more queries that cover all the use cases within their organizations. Given the limited size of the context windows of modern LLMs, we can’t just shove all the prior queries and schema definitions into the prompt.
231231

232-
Our final approach to context is a more sophisticated ML approach - load embeddings of prior queries and the table schemas into a vector database, and only choose the most relevant queries / tables to the question asked. Here's a diagram of what we are doing - note the contextual relevance search in the red box -
232+
Our final approach to context is a more sophisticated ML technique: load embeddings of prior queries and the table schemas into a vector database, and choose only the queries / tables most relevant to the question asked. Here's a diagram of what we are doing - note the contextual relevance search in the green box:
233233

234234
![](https://raw.githubusercontent.com/vanna-ai/vanna/main/papers/img/using-contextually-relevant-examples.png)
235235

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
3333
snowflake = ["snowflake-connector-python"]
3434
duckdb = ["duckdb"]
3535
google = ["google-generativeai", "google-cloud-aiplatform"]
36-
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "botocore"]
36+
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres"]
3737
test = ["tox"]
3838
chromadb = ["chromadb"]
3939
openai = ["openai"]

src/vanna/google/bigquery_vector.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import os
33
import uuid
44
from typing import List, Optional
5+
from vertexai.language_models import (
6+
TextEmbeddingInput,
7+
TextEmbeddingModel
8+
)
59

610
import pandas as pd
711
from google.cloud import bigquery
@@ -23,17 +27,15 @@ def __init__(self, config: dict, **kwargs):
2327
or set as an environment variable, assign it.
2428
"""
2529
print("Configuring genai")
30+
self.type = "GEMINI"
2631
import google.generativeai as genai
2732

2833
genai.configure(api_key=config["api_key"])
2934

3035
self.genai = genai
3136
else:
37+
self.type = "VERTEX_AI"
3238
# Authenticate using VertexAI
33-
from vertexai.language_models import (
34-
TextEmbeddingInput,
35-
TextEmbeddingModel,
36-
)
3739

3840
if self.config.get("project_id"):
3941
self.project_id = self.config.get("project_id")
@@ -139,25 +141,42 @@ def fetch_similar_training_data(self, training_data_type: str, question: str, n_
139141
results = self.conn.query(query).result().to_dataframe()
140142
return results
141143

142-
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
143-
result = self.genai.embed_content(
144+
def get_embeddings(self, data: str, task: str) -> List[float]:
145+
embeddings = None
146+
147+
if self.type == "VERTEX_AI":
148+
input = [TextEmbeddingInput(data, task)]
149+
model = TextEmbeddingModel.from_pretrained("text-embedding-004")
150+
151+
result = model.get_embeddings(input)
152+
153+
if len(result) > 0:
154+
embeddings = result[0].values
155+
else:
156+
# Use Gemini Consumer API
157+
result = self.genai.embed_content(
144158
model="models/text-embedding-004",
145159
content=data,
146-
task_type="retrieval_query")
160+
task_type=task)
147161

148-
if 'embedding' in result:
149-
return result['embedding']
162+
if 'embedding' in result:
163+
embeddings = result['embedding']
164+
165+
return embeddings
166+
167+
def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
168+
result = self.get_embeddings(data, "RETRIEVAL_QUERY")
169+
170+
if result != None:
171+
return result
150172
else:
151173
raise ValueError("No embeddings returned")
152174

153175
def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
154-
result = self.genai.embed_content(
155-
model="models/text-embedding-004",
156-
content=data,
157-
task_type="retrieval_document")
176+
result = self.get_embeddings(data, "RETRIEVAL_DOCUMENT")
158177

159-
if 'embedding' in result:
160-
return result['embedding']
178+
if result != None:
179+
return result
161180
else:
162181
raise ValueError("No embeddings returned")
163182

src/vanna/google/gemini_chat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, config=None):
1515
if "model_name" in config:
1616
model_name = config["model_name"]
1717
else:
18-
model_name = "gemini-1.0-pro"
18+
model_name = "gemini-1.5-pro"
1919

2020
self.google_api_key = None
2121

@@ -30,7 +30,7 @@ def __init__(self, config=None):
3030
self.chat_model = genai.GenerativeModel(model_name)
3131
else:
3232
# Authenticate using VertexAI
33-
from vertexai.preview.generative_models import GenerativeModel
33+
from vertexai.generative_models import GenerativeModel
3434
self.chat_model = GenerativeModel(model_name)
3535

3636
def system_message(self, message: str) -> any:

tests/test_pgvector.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,48 @@
22

33
from dotenv import load_dotenv
44

5-
from vanna.pgvector import PG_VectorStore
6-
from vanna.openai import OpenAI_Chat
7-
5+
# from vanna.pgvector import PG_VectorStore
6+
# from vanna.openai import OpenAI_Chat
87

98
# assumes a .env file providing the required env vars is placed next to this file
109
load_dotenv()
1110

12-
13-
def get_vanna_connection_string():
14-
server = os.environ.get("PG_SERVER")
15-
driver = "psycopg"
16-
port = os.environ.get("PG_PORT", 5432)
17-
database = os.environ.get("PG_DATABASE")
18-
username = os.environ.get("PG_USERNAME")
19-
password = os.environ.get("PG_PASSWORD")
20-
21-
return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}"
22-
23-
24-
def test_pgvector_e2e():
25-
# configure Vanna to use OpenAI and PGVector
26-
class VannaCustom(PG_VectorStore, OpenAI_Chat):
27-
def __init__(self, config=None):
28-
PG_VectorStore.__init__(self, config=config)
29-
OpenAI_Chat.__init__(self, config=config)
11+
# def get_vanna_connection_string():
12+
# server = os.environ.get("PG_SERVER")
13+
# driver = "psycopg"
14+
# port = os.environ.get("PG_PORT", 5432)
15+
# database = os.environ.get("PG_DATABASE")
16+
# username = os.environ.get("PG_USERNAME")
17+
# password = os.environ.get("PG_PASSWORD")
18+
19+
# def test_pgvector_e2e():
20+
# # configure Vanna to use OpenAI and PGVector
21+
# class VannaCustom(PG_VectorStore, OpenAI_Chat):
22+
# def __init__(self, config=None):
23+
# PG_VectorStore.__init__(self, config=config)
24+
# OpenAI_Chat.__init__(self, config=config)
3025

31-
vn = VannaCustom(config={
32-
'api_key': os.environ['OPENAI_API_KEY'],
33-
'model': 'gpt-3.5-turbo',
34-
"connection_string": get_vanna_connection_string(),
35-
})
36-
37-
# connect to SQLite database
38-
vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
39-
40-
# train Vanna on DDLs
41-
df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
42-
for ddl in df_ddl['sql'].to_list():
43-
vn.train(ddl=ddl)
44-
assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default
26+
# vn = VannaCustom(config={
27+
# 'api_key': os.environ['OPENAI_API_KEY'],
28+
# 'model': 'gpt-3.5-turbo',
29+
# "connection_string": get_vanna_connection_string(),
30+
# })
31+
32+
# # connect to SQLite database
33+
# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
34+
35+
# # train Vanna on DDLs
36+
# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
37+
# for ddl in df_ddl['sql'].to_list():
38+
# vn.train(ddl=ddl)
39+
# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default
4540

46-
question = "What are the top 7 customers by sales?"
47-
sql = vn.generate_sql(question)
48-
df = vn.run_sql(sql)
49-
assert len(df) == 7
50-
51-
# test if Vanna can generate an answer
52-
answer = vn.ask(question)
53-
assert answer is not None
41+
# question = "What are the top 7 customers by sales?"
42+
# sql = vn.generate_sql(question)
43+
# df = vn.run_sql(sql)
44+
# assert len(df) == 7
45+
46+
# # test if Vanna can generate an answer
47+
# answer = vn.ask(question)
48+
# assert answer is not None
49+

0 commit comments

Comments
 (0)