
new: allow users to override providers #214

Merged: 10 commits, May 3, 2024
39 changes: 32 additions & 7 deletions README.md
@@ -14,12 +14,18 @@ The default text embedding (`TextEmbedding`) model is Flag Embedding, presented

## 🚀 Installation

-To install the FastEmbed library, pip works:
+To install the FastEmbed library, pip works best. You can install it with or without GPU support:

```bash
pip install fastembed
```

### ⚡️ With GPU

```bash
pip install fastembed-gpu
```

## 📖 Quickstart

```python
@@ -42,6 +48,23 @@
embeddings_list = list(embedding_model.embed(documents))
len(embeddings_list[0]) # Vector of 384 dimensions
```

### ⚡️ FastEmbed on a GPU

FastEmbed supports running on GPU devices; this requires the `fastembed-gpu` package.
Make sure the CPU-only `fastembed` package is not installed alongside it, as the two packages can interfere with each other.

```bash
pip install fastembed-gpu
```

```python
from fastembed import TextEmbedding

embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5", providers=["CUDAExecutionProvider"])
print("The model BAAI/bge-small-en-v1.5 is ready to use on a GPU.")
```
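
Beyond plain provider names, the `providers` argument also accepts `(name, options)` tuples, matching the `OnnxProvider` type introduced in this PR (`Union[str, Tuple[str, Dict[Any, Any]]]`); the options dict is forwarded to ONNX Runtime. A minimal sketch of pinning the model to a specific GPU, assuming a working CUDA setup; `device_id` is a standard ONNX Runtime CUDA provider option, not a FastEmbed parameter:

```python
from fastembed import TextEmbedding

# Each entry in providers may be a provider name or a (name, options) tuple;
# the options dict is passed through to ONNX Runtime's provider configuration.
embedding_model = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    providers=[("CUDAExecutionProvider", {"device_id": 0})],
)
embeddings = list(embedding_model.embed(["This is a test document."]))
```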

## Usage with Qdrant

Installation with Qdrant Client in Python:
@@ -50,7 +73,13 @@
pip install qdrant-client[fastembed]
```

or

```bash
pip install qdrant-client[fastembed-gpu]
```

-You might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.
+You might have to use quotes ```pip install 'qdrant-client[fastembed]'``` on zsh.

```python
from qdrant_client import QdrantClient
@@ -85,8 +114,4 @@
search_result = client.query(
query_text="This is a query document"
)
print(search_result)
```

-#### Similar Work
-
-Ilyas M. wrote about using [FlagEmbeddings with Optimum](https://twitter.com/IlysMoutawwakil/status/1705215192425288017) over CUDA.
3 changes: 3 additions & 0 deletions fastembed/common/__init__.py
@@ -0,0 +1,3 @@
from fastembed.common.onnx_model import OnnxProvider

__all__ = ["OnnxProvider"]
28 changes: 26 additions & 2 deletions fastembed/common/onnx_model.py
@@ -1,7 +1,19 @@
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
-from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    Sequence,
+)

import numpy as np
import onnxruntime as ort
@@ -14,6 +26,8 @@
# Holds type of the embedding result
T = TypeVar("T")

OnnxProvider = Union[str, Tuple[str, Dict[Any, Any]]]


class OnnxModel(Generic[T]):
@classmethod
@@ -39,11 +53,21 @@
def load_onnx_model(
    model_dir: Path,
    model_file: str,
    threads: Optional[int],
    providers: Optional[Sequence[OnnxProvider]] = None,
) -> None:
    model_path = model_dir / model_file

    # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
-    onnx_providers = ["CPUExecutionProvider"]
+    onnx_providers = ["CPUExecutionProvider"] if providers is None else list(providers)
    available_providers = ort.get_available_providers()
    for provider in onnx_providers:
        # check that each requested provider is actually available
        provider_name = provider if isinstance(provider, str) else provider[0]
        if provider_name not in available_providers:
            raise ValueError(
                f"Provider {provider_name} is not available. Available providers: {available_providers}"
            )
> Review comment (Contributor) on lines +63 to +70: Thanks for adding the available_providers check!
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
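
The availability check above makes a bad provider fail fast at load time instead of silently falling back to CPU. A hedged sketch of the resulting behavior (not part of this diff), assuming a CPU-only onnxruntime build:

```python
import onnxruntime as ort

from fastembed import TextEmbedding

print(ort.get_available_providers())  # e.g. ['CPUExecutionProvider'] on a CPU-only build

try:
    # load_onnx_model raises during model construction if a requested
    # provider is missing from the installed onnxruntime build.
    TextEmbedding(
        model_name="BAAI/bge-small-en-v1.5",
        providers=["CUDAExecutionProvider"],
    )
except ValueError as err:
    print(err)  # "Provider CUDAExecutionProvider is not available. ..."
```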
8 changes: 6 additions & 2 deletions fastembed/sparse/sparse_text_embedding.py
@@ -1,5 +1,6 @@
-from typing import List, Type, Dict, Any, Union, Iterable, Optional
+from typing import List, Type, Dict, Any, Union, Iterable, Optional, Sequence

from fastembed.common import OnnxProvider
from fastembed.sparse.sparse_embedding_base import SparseTextEmbeddingBase, SparseEmbedding
from fastembed.sparse.splade_pp import SpladePP

@@ -42,14 +43,17 @@
def __init__(
    model_name: str,
    cache_dir: Optional[str] = None,
    threads: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    **kwargs,
):
    super().__init__(model_name, cache_dir, threads, **kwargs)

    for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
        supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
        if any(model_name.lower() == model["model"].lower() for model in supported_models):
-            self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
+            self.model = EMBEDDING_MODEL_TYPE(
+                model_name, cache_dir, threads, providers=providers, **kwargs
+            )
            return

    raise ValueError(
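
`SparseTextEmbedding` threads `providers` through to the underlying SPLADE++ ONNX model in the same way. A brief usage sketch; the model name `prithivida/Splade_PP_en_v1` is an assumption about FastEmbed's sparse model registry:

```python
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding

# providers is forwarded to the registered model class (SpladePP here).
model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    providers=["CPUExecutionProvider"],
)
sparse_embeddings = list(model.embed(["a sample document"]))
```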
6 changes: 4 additions & 2 deletions fastembed/sparse/splade_pp.py
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, Sequence

import numpy as np

-from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel
+from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxProvider
from fastembed.common.utils import define_cache_dir
from fastembed.sparse.sparse_embedding_base import SparseEmbedding, SparseTextEmbeddingBase

@@ -63,6 +63,7 @@
def __init__(
    model_name: str,
    cache_dir: Optional[str] = None,
    threads: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    **kwargs,
):
    """
@@ -88,6 +89,7 @@
    model_dir=model_dir,
    model_file=model_description["model_file"],
    threads=threads,
    providers=providers,
)

def embed(
6 changes: 4 additions & 2 deletions fastembed/text/onnx_embedding.py
@@ -1,8 +1,8 @@
-from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any
+from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any, Sequence

import numpy as np

-from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker
+from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, OnnxProvider
from fastembed.common.models import normalize
from fastembed.common.utils import define_cache_dir
from fastembed.text.text_embedding_base import TextEmbeddingBase
@@ -211,6 +211,7 @@
def __init__(
    model_name: str = "BAAI/bge-small-en-v1.5",
    cache_dir: Optional[str] = None,
    threads: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    **kwargs,
):
    """
@@ -235,6 +236,7 @@
    model_dir=model_dir,
    model_file=model_description["model_file"],
    threads=threads,
    providers=providers,
)

def embed(
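
For context on where the providers list ends up: `load_onnx_model` hands it to `onnxruntime.InferenceSession`. A condensed, behavior-level sketch of that wiring under the PR's defaults (FastEmbed's actual session setup may differ in detail):

```python
from pathlib import Path
from typing import Optional, Sequence

import onnxruntime as ort


def make_session(
    model_path: Path, providers: Optional[Sequence] = None
) -> ort.InferenceSession:
    # Mirror the PR's fallback: CPU-only when no providers are given.
    onnx_providers = ["CPUExecutionProvider"] if providers is None else list(providers)
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    return ort.InferenceSession(str(model_path), sess_options=so, providers=onnx_providers)
```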
8 changes: 6 additions & 2 deletions fastembed/text/text_embedding.py
@@ -1,7 +1,8 @@
-from typing import Any, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Type, Union, Sequence

import numpy as np

from fastembed.common import OnnxProvider
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbedding
@@ -49,14 +50,17 @@
def __init__(
    model_name: str = "BAAI/bge-small-en-v1.5",
    cache_dir: Optional[str] = None,
    threads: Optional[int] = None,
    providers: Optional[Sequence[OnnxProvider]] = None,
    **kwargs,
):
    super().__init__(model_name, cache_dir, threads, **kwargs)

    for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
        supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
        if any(model_name.lower() == model["model"].lower() for model in supported_models):
-            self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
+            self.model = EMBEDDING_MODEL_TYPE(
+                model_name, cache_dir, threads, providers=providers, **kwargs
+            )
            return

    raise ValueError(