
Commit 0ac5536

community: add support for using GPUs with FastEmbedEmbeddings (#29627)
- **Description:** Add a `gpu: bool = False` field to the `FastEmbedEmbeddings` class which enables using a GPU (through the ONNX CUDA provider) when generating embeddings with any fastembed model. It just requires the user to install a different dependency, and we use a different provider when instantiating `fastembed.TextEmbedding` (see the usage sketch below).
- **Issue:** When generating embeddings for a really large number of documents this drastically increases performance (honestly, that is a must-have in some situations; you can't just use the CPU, it is way too slow).
- **Dependencies:** No direct change to dependencies, but internally users will need to install `fastembed-gpu` instead of `fastembed`. I made all the changes to the init function so it properly tells the user which dependency to install depending on whether they enabled `gpu`. See the fastembed docs about GPU for more details: https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/

I did not add tests because they would require access to a GPU in the testing environment.
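A minimal usage sketch (not part of the commit) showing the new flag from the LangChain side, assuming `fastembed-gpu` and a CUDA-capable GPU are available:

```python
from langchain_community.embeddings import FastEmbedEmbeddings

# gpu=True makes the wrapper request the ONNX CUDAExecutionProvider when it
# instantiates fastembed.TextEmbedding; it expects `pip install fastembed-gpu`
# instead of `pip install fastembed`.
embeddings = FastEmbedEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",  # fastembed's default model
    gpu=True,
)

doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("a search query")
```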
1 parent 0ceda55 commit 0ac5536

File tree: 1 file changed (+16 −5 lines)


libs/community/langchain_community/embeddings/fastembed.py

+16 −5
@@ -65,6 +65,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
     Defaults to `None`.
     """
 
+    gpu: bool = False
+    """Enable the use of GPU through CUDA. This requires to install `fastembed-gpu`
+    instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
+    for more details.
+    Defaults to False.
+    """
+
     model: Any = None  #: :meta private:
 
     model_config = ConfigDict(extra="allow", protected_namespaces=())
@@ -76,26 +83,30 @@ def validate_environment(cls, values: Dict) -> Dict:
         max_length = values.get("max_length")
         cache_dir = values.get("cache_dir")
         threads = values.get("threads")
+        gpu = values.get("gpu")
+        pkg_to_import = "fastembed-gpu" if gpu else "fastembed"
 
         try:
-            fastembed = importlib.import_module("fastembed")
+            fastembed = importlib.import_module(pkg_to_import)
 
         except ModuleNotFoundError:
             raise ImportError(
-                "Could not import 'fastembed' Python package. "
-                "Please install it with `pip install fastembed`."
+                f"Could not import '{pkg_to_import}' Python package. "
+                f"Please install it with `pip install {pkg_to_import}`."
             )
 
-        if importlib.metadata.version("fastembed") < MIN_VERSION:
+        if importlib.metadata.version(pkg_to_import) < MIN_VERSION:
             raise ImportError(
-                'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
+                f"FastEmbedEmbeddings requires "
+                f'`pip install -U "{pkg_to_import}>={MIN_VERSION}"`.'
             )
 
         values["model"] = fastembed.TextEmbedding(
             model_name=model_name,
             max_length=max_length,
             cache_dir=cache_dir,
             threads=threads,
+            providers=["CUDAExecutionProvider"] if gpu else None,
         )
         return values
 
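For context (not part of the diff), the `providers` argument added above is forwarded to fastembed itself; a standalone sketch of the equivalent call, assuming `fastembed-gpu` (or plain `fastembed` for CPU) is installed:

```python
from fastembed import TextEmbedding

gpu = True  # what FastEmbedEmbeddings(gpu=True) toggles internally

# Request the ONNX CUDA execution provider when gpu is enabled; passing None
# falls back to fastembed's default (CPU) ONNX providers.
model = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    providers=["CUDAExecutionProvider"] if gpu else None,
)

vectors = list(model.embed(["some text", "more text"]))  # generator of numpy arrays
```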

0 commit comments
