diff --git a/README.md b/README.md
index 4bda7b4f..55274446 100644
--- a/README.md
+++ b/README.md
@@ -14,12 +14,18 @@ The default text embedding (`TextEmbedding`) model is Flag Embedding, presented
 
 ## 🚀 Installation
 
-To install the FastEmbed library, pip works:
+To install the FastEmbed library, use pip. You can install it with or without GPU support:
 
 ```bash
 pip install fastembed
 ```
 
+### ⚡️ With GPU
+
+```bash
+pip install fastembed-gpu
+```
+
 ## 📖 Quickstart
 
 ```python
@@ -42,6 +48,23 @@ embeddings_list = list(embedding_model.embed(documents))
 len(embeddings_list[0]) # Vector of 384 dimensions
 ```
 
+### ⚡️ FastEmbed on a GPU
+
+FastEmbed supports running on GPU devices via the `fastembed-gpu` package.
+Make sure the plain `fastembed` package is not installed alongside it, as the two can interfere with each other.
+
+```bash
+pip install fastembed-gpu
+```
+
+```python
+from fastembed import TextEmbedding
+
+embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5", providers=["CUDAExecutionProvider"])
+print("The model BAAI/bge-small-en-v1.5 is ready to use on a GPU.")
+
+```
+
 ## Usage with Qdrant
 
 Installation with Qdrant Client in Python:
@@ -50,7 +73,13 @@ Installation with Qdrant Client in Python:
 pip install qdrant-client[fastembed]
 ```
 
-You might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.
+or, with GPU support:
+
+```bash
+pip install qdrant-client[fastembed-gpu]
+```
+
+On zsh, you may need to quote the package name: `pip install 'qdrant-client[fastembed]'`.
 
 ```python
 from qdrant_client import QdrantClient
@@ -85,8 +114,4 @@ search_result = client.query(
     query_text="This is a query document"
 )
 print(search_result)
-```
-
-#### Similar Work
-
-Ilyas M. wrote about using [FlagEmbeddings with Optimum](https://twitter.com/IlysMoutawwakil/status/1705215192425288017) over CUDA.
+```
\ No newline at end of file
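Note on the README changes above: before relying on the new `fastembed-gpu` instructions, it helps to confirm that ONNX Runtime actually exposes the CUDA provider in the target environment. A minimal sketch, not part of this patch, assuming the GPU build of onnxruntime that `fastembed-gpu` pulls in is installed:

```python
import onnxruntime as ort

# CUDAExecutionProvider only shows up here when the GPU build of
# onnxruntime and its CUDA dependencies are installed correctly.
print(ort.get_available_providers())
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a working GPU setup
```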
diff --git a/fastembed/common/__init__.py b/fastembed/common/__init__.py
index e69de29b..06e249f0 100644
--- a/fastembed/common/__init__.py
+++ b/fastembed/common/__init__.py
@@ -0,0 +1,3 @@
+from fastembed.common.onnx_model import OnnxProvider
+
+__all__ = ["OnnxProvider"]
diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py
index 588b7bf5..5ce48719 100644
--- a/fastembed/common/onnx_model.py
+++ b/fastembed/common/onnx_model.py
@@ -1,7 +1,19 @@
 import os
 from multiprocessing import get_all_start_methods
 from pathlib import Path
-from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    Sequence,
+)
 
 import numpy as np
 import onnxruntime as ort
@@ -14,6 +26,8 @@
 # Holds type of the embedding result
 T = TypeVar("T")
 
+OnnxProvider = Union[str, Tuple[str, Dict[Any, Any]]]
+
 
 class OnnxModel(Generic[T]):
     @classmethod
@@ -39,11 +53,21 @@ def load_onnx_model(
         model_dir: Path,
         model_file: str,
         threads: Optional[int],
+        providers: Optional[Sequence[OnnxProvider]] = None,
     ) -> None:
         model_path = model_dir / model_file
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
-        onnx_providers = ["CPUExecutionProvider"]
+
+        onnx_providers = ["CPUExecutionProvider"] if providers is None else list(providers)
+        available_providers = ort.get_available_providers()
+        for provider in onnx_providers:
+            # Check that each requested provider is available in this onnxruntime build
+            provider_name = provider if isinstance(provider, str) else provider[0]
+            if provider_name not in available_providers:
+                raise ValueError(
+                    f"Provider {provider_name} is not available. Available providers: {available_providers}"
+                )
 
         so = ort.SessionOptions()
         so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
diff --git a/fastembed/sparse/sparse_text_embedding.py b/fastembed/sparse/sparse_text_embedding.py
index aa838ee8..f55f7c56 100644
--- a/fastembed/sparse/sparse_text_embedding.py
+++ b/fastembed/sparse/sparse_text_embedding.py
@@ -1,5 +1,6 @@
-from typing import List, Type, Dict, Any, Union, Iterable, Optional
+from typing import List, Type, Dict, Any, Union, Iterable, Optional, Sequence
 
+from fastembed.common import OnnxProvider
 from fastembed.sparse.sparse_embedding_base import SparseTextEmbeddingBase, SparseEmbedding
 from fastembed.sparse.splade_pp import SpladePP
 
@@ -42,6 +43,7 @@ def __init__(
         model_name: str,
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
@@ -49,7 +51,9 @@ def __init__(
         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
             if any(model_name.lower() == model["model"].lower() for model in supported_models):
-                self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
+                self.model = EMBEDDING_MODEL_TYPE(
+                    model_name, cache_dir, threads, providers=providers, **kwargs
+                )
                 return
 
         raise ValueError(
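Note on the `OnnxProvider` alias added in `fastembed/common/onnx_model.py`: it accepts either a bare provider name or a `(name, options)` tuple, the same shapes `onnxruntime.InferenceSession` accepts in its `providers` argument, and the validation loop above unpacks `provider[0]` for the tuple case. A hedged usage sketch, assuming the list is forwarded unchanged to the session (`device_id` is a documented `CUDAExecutionProvider` option, but the available options depend on the ONNX Runtime version):

```python
from fastembed import TextEmbedding

# A bare provider name pins inference to the CPU provider.
cpu_model = TextEmbedding("BAAI/bge-small-en-v1.5", providers=["CPUExecutionProvider"])

# A (name, options) tuple can carry provider options, e.g. the CUDA device to use.
gpu_model = TextEmbedding(
    "BAAI/bge-small-en-v1.5",
    providers=[("CUDAExecutionProvider", {"device_id": 0})],
)
```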
diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py
index f59114ca..cfb130c1 100644
--- a/fastembed/sparse/splade_pp.py
+++ b/fastembed/sparse/splade_pp.py
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, Sequence
 
 import numpy as np
 
-from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel
+from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxProvider
 from fastembed.common.utils import define_cache_dir
 from fastembed.sparse.sparse_embedding_base import SparseEmbedding, SparseTextEmbeddingBase
 
@@ -63,6 +63,7 @@ def __init__(
         model_name: str,
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
         **kwargs,
     ):
         """
@@ -88,6 +89,7 @@
             model_dir=model_dir,
             model_file=model_description["model_file"],
             threads=threads,
+            providers=providers,
         )
 
     def embed(
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 059cce20..f2ae46fc 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -1,8 +1,8 @@
-from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any
+from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any, Sequence
 
 import numpy as np
 
-from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker
+from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, OnnxProvider
 from fastembed.common.models import normalize
 from fastembed.common.utils import define_cache_dir
 from fastembed.text.text_embedding_base import TextEmbeddingBase
@@ -211,6 +211,7 @@ def __init__(
         model_name: str = "BAAI/bge-small-en-v1.5",
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
         **kwargs,
     ):
         """
@@ -235,6 +236,7 @@
             model_dir=model_dir,
             model_file=model_description["model_file"],
             threads=threads,
+            providers=providers,
         )
 
     def embed(
diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py
index 10b43625..d9923b3d 100644
--- a/fastembed/text/text_embedding.py
+++ b/fastembed/text/text_embedding.py
@@ -1,7 +1,8 @@
-from typing import Any, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union, Sequence
 
 import numpy as np
 
+from fastembed.common import OnnxProvider
 from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
 from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
 from fastembed.text.onnx_embedding import OnnxTextEmbedding
@@ -49,6 +50,7 @@ def __init__(
         model_name: str = "BAAI/bge-small-en-v1.5",
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
@@ -56,7 +58,9 @@ def __init__(
         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
             if any(model_name.lower() == model["model"].lower() for model in supported_models):
-                self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
+                self.model = EMBEDDING_MODEL_TYPE(
+                    model_name, cache_dir, threads, providers=providers, **kwargs
+                )
                 return
 
         raise ValueError(
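Since `load_onnx_model` now raises `ValueError` when a requested provider is missing from `ort.get_available_providers()`, and the embedding classes load their model inside `__init__`, the error surfaces at construction time. A minimal sketch of a CPU fallback built on that behavior, for illustration only:

```python
from fastembed import TextEmbedding

try:
    model = TextEmbedding("BAAI/bge-small-en-v1.5", providers=["CUDAExecutionProvider"])
except ValueError:
    # Raised by load_onnx_model when CUDAExecutionProvider is unavailable.
    model = TextEmbedding("BAAI/bge-small-en-v1.5", providers=["CPUExecutionProvider"])
```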