Skip to content

Chroma persistence #1028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions langchain/vectorstores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,39 @@ class Chroma(VectorStore):
"""

def __init__(
self, collection_name: str, embedding_function: Optional[Embeddings] = None
self,
collection_name: str,
embedding_function: Optional[Embeddings] = None,
persist_directory: Optional[str] = None,
) -> None:
"""Initialize with Chroma client."""
try:
import chromadb
import chromadb.config
except ImportError:
raise ValueError(
"Could not import chromadb python package. "
"Please it install it with `pip install chromadb`."
)

# TODO: Add support for custom client. For now this is in-memory only.
self._client = chromadb.Client()
self._client_settings = chromadb.config.Settings()
if persist_directory is not None:
self._client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
)
self._client = chromadb.Client(self._client_settings)
self._embedding_function = embedding_function
self._persist_directory = persist_directory

# Check if the collection exists, create it if not
if collection_name in [col.name for col in self._client.list_collections()]:
self._collection = self._client.get_collection(name=collection_name)
if embedding_function is not None:
logger.warning(
f"Collection {collection_name} already exists,"
" embedding function will not be updated."
)
# TODO: Persist the user's embedding function
logger.warning(
f"Collection {collection_name} already exists,"
" Do you have the right embedding function?"
)
else:
self._collection = self._client.create_collection(
name=collection_name,
Expand Down Expand Up @@ -78,7 +88,12 @@ def add_texts(
# TODO: Handle the case where the user doesn't provide ids on the Collection
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
self._collection.add(metadatas=metadatas, documents=texts, ids=ids)
embeddings = None
if self._embedding_function is not None:
embeddings = self._embedding_function.embed_documents(list(texts))
self._collection.add(
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
)
return ids

def similarity_search(
Expand Down Expand Up @@ -116,6 +131,23 @@ def similarity_search(
]
return docs

def delete_collection(self) -> None:
    """Remove this store's backing collection from the Chroma client.

    The collection is identified by the name of the collection object
    held by this instance.
    """
    drop_collection = self._client.delete_collection
    drop_collection(self._collection.name)

def persist(self) -> None:
    """Persist the collection.

    This can be used to explicitly persist the data to disk by delegating
    to the underlying Chroma client's ``persist()``.
    NOTE(review): the data is also expected to be persisted automatically
    when the object is destroyed; that hook is not visible here, so confirm
    against the Chroma client before relying on it.

    Raises:
        ValueError: If no ``persist_directory`` was supplied when this
            object was created, since there is nowhere to write to.
    """
    if self._persist_directory is None:
        # Adjacent string literals are concatenated verbatim, so the
        # separating space must be part of the first literal.
        raise ValueError(
            "You must specify a persist_directory on "
            "creation to persist the collection."
        )
    self._client.persist()

@classmethod
def from_texts(
cls,
Expand All @@ -124,12 +156,17 @@ def from_texts(
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
collection_name: str = "langchain",
persist_directory: Optional[str] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a raw documents.

If a persist_directory is specified, the collection will be persisted there.
Otherwise, the data will be ephemeral in-memory.

Args:
collection_name (str): Name of the collection to create.
persist_directory (Optional[str]): Directory to persist the collection.
documents (List[Document]): List of documents to add.
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
Expand All @@ -139,7 +176,9 @@ def from_texts(
Chroma: Chroma vectorstore.
"""
chroma_collection = cls(
collection_name=collection_name, embedding_function=embedding
collection_name=collection_name,
embedding_function=embedding,
persist_directory=persist_directory,
)
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return chroma_collection
Expand All @@ -151,12 +190,17 @@ def from_documents(
embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None,
collection_name: str = "langchain",
persist_directory: Optional[str] = None,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a list of documents.

If a persist_directory is specified, the collection will be persisted there.
Otherwise, the data will be ephemeral in-memory.

Args:
collection_name (str): Name of the collection to create.
persist_directory (Optional[str]): Directory to persist the collection.
documents (List[Document]): List of documents to add to the vectorstore.
embedding (Optional[Embeddings]): Embedding function. Defaults to None.

Expand All @@ -166,9 +210,10 @@ def from_documents(
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
collection_name=collection_name,
texts=texts,
embedding=embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
persist_directory=persist_directory,
)
33 changes: 33 additions & 0 deletions tests/integration_tests/vectorstores/test_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,36 @@ def test_chroma_with_metadatas() -> None:
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"})]


def test_chroma_with_persistence() -> None:
    """Test end to end construction and search, with persistence.

    Builds a store backed by a persist directory, persists it explicitly,
    then opens a second store on the same directory and checks that the
    same search result comes back from the persisted data.
    """
    chroma_persist_dir = "./tests/persist_dir"
    collection_name = "test_collection"
    texts = ["foo", "bar", "baz"]
    docsearch = Chroma.from_texts(
        collection_name=collection_name,
        texts=texts,
        embedding=FakeEmbeddings(),
        persist_directory=chroma_persist_dir,
    )

    try:
        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]

        docsearch.persist()

        # Get a new VectorStore from the persisted directory
        docsearch = Chroma(
            collection_name=collection_name,
            embedding_function=FakeEmbeddings(),
            persist_directory=chroma_persist_dir,
        )
        output = docsearch.similarity_search("foo", k=1)
        # The reloaded store must return the persisted document; without
        # this check the test would pass even if nothing was reloaded.
        assert output == [Document(page_content="foo")]
    finally:
        # Clean up even when an assertion fails, so a stale collection in
        # the persist directory cannot leak into later test runs.
        docsearch.delete_collection()

    # Persist doesn't need to be called again
    # Data will be automatically persisted on object deletion
    # Or on program exit