Skip to content

Commit e317d45

Browse files
authored
Bug-Fix[Community] Fix FastEmbedEmbeddings (langchain-ai#26764)
langchain-ai#26759 - Fix langchain-ai#26759 - Change `model` param from private to public, which may not be initiated. - Add test case
1 parent a8e1577 commit e317d45

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

libs/community/langchain_community/embeddings/fastembed.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
6565
Defaults to `None`.
6666
"""
6767

68-
_model: Any = None # : :meta private:
68+
model: Any = None # : :meta private:
6969

7070
model_config = ConfigDict(extra="allow", protected_namespaces=())
7171

@@ -91,7 +91,7 @@ def validate_environment(cls, values: Dict) -> Dict:
9191
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
9292
)
9393

94-
values["_model"] = fastembed.TextEmbedding(
94+
values["model"] = fastembed.TextEmbedding(
9595
model_name=model_name,
9696
max_length=max_length,
9797
cache_dir=cache_dir,
@@ -110,11 +110,11 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
110110
"""
111111
embeddings: List[np.ndarray]
112112
if self.doc_embed_type == "passage":
113-
embeddings = self._model.passage_embed(
113+
embeddings = self.model.passage_embed(
114114
texts, batch_size=self.batch_size, parallel=self.parallel
115115
)
116116
else:
117-
embeddings = self._model.embed(
117+
embeddings = self.model.embed(
118118
texts, batch_size=self.batch_size, parallel=self.parallel
119119
)
120120
return [e.tolist() for e in embeddings]
@@ -129,7 +129,7 @@ def embed_query(self, text: str) -> List[float]:
129129
Embeddings for the text.
130130
"""
131131
query_embeddings: np.ndarray = next(
132-
self._model.query_embed(
132+
self.model.query_embed(
133133
text, batch_size=self.batch_size, parallel=self.parallel
134134
)
135135
)

libs/community/tests/integration_tests/embeddings/test_fastembed.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,11 @@ async def test_fastembed_async_embedding_query(
8080
embedding = FastEmbedEmbeddings(model_name=model_name, max_length=max_length) # type: ignore[call-arg]
8181
output = await embedding.aembed_query(document)
8282
assert len(output) == 384
83+
84+
85+
def test_fastembed_embedding_query_with_default_params() -> None:
86+
"""Test fastembed embeddings for query with default model params"""
87+
document = "foo bar"
88+
embedding = FastEmbedEmbeddings() # type: ignore[call-arg]
89+
output = embedding.embed_query(document)
90+
assert len(output) == 384

0 commit comments

Comments
 (0)