
Commit

Revert "feat: embedding progress bar (#71)"
This reverts commit 2c7fee3.
NirantK authored Dec 13, 2023
1 parent 5537953 commit 9c546e2
Showing 1 changed file with 37 additions and 48 deletions.
85 changes: 37 additions & 48 deletions fastembed/embedding.py
@@ -1,5 +1,4 @@
 import json
-import math
 import os
 import shutil
 import tarfile
@@ -462,8 +461,8 @@ def __init__(
         Args:
             model_name (str): The name of the model to use.
             max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
-            cache_dir (str, optional): The path to the cache directory. \
-                Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
@@ -485,7 +484,7 @@ def __init__(
                                    max_threads=threads)
 
     def embed(
-        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
+        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -498,7 +497,6 @@ def embed(
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                 If 0, use all available cores.
                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
-            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
         Returns:
             List of embeddings, one per document
@@ -515,26 +513,22 @@ def embed(
 
         if parallel == 0:
             parallel = os.cpu_count()
-
-        with tqdm(total=len(documents), disable=not show_progress) as progress_bar:
-            batch_iterable = iter_batch(documents, batch_size)
-            if parallel is None or is_small:
-                for batch in batch_iterable:
-                    embeddings, _ = self.model.onnx_embed(batch)
-                    yield from normalize(embeddings[:, 0]).astype(np.float32)
-                    progress_bar.update(len(embeddings))
-            else:
-                start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-                params = {
-                    "path": self._model_dir,
-                    "model_name": self.model_name,
-                    "max_length": self._max_length,
-                }
-                pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
-                for batch in pool.ordered_map(batch_iterable, **params):
-                    embeddings, _ = batch
-                    yield from normalize(embeddings[:, 0]).astype(np.float32)
-                    progress_bar.update(len(embeddings))
+
+        if parallel is None or is_small:
+            for batch in iter_batch(documents, batch_size):
+                embeddings, _ = self.model.onnx_embed(batch)
+                yield from normalize(embeddings[:, 0]).astype(np.float32)
+        else:
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            params = {
+                "path": self._model_dir,
+                "model_name": self.model_name,
+                "max_length": self._max_length,
+            }
+            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
+            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
+                embeddings, _ = batch
+                yield from normalize(embeddings[:, 0]).astype(np.float32)
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
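
A minimal caller-side sketch (not part of this diff) of how progress reporting could still be done after this revert, since embed() no longer accepts show_progress. The FlagEmbedding class name and the model name are assumptions for illustration; they are not shown in this hunk:

    from tqdm import tqdm

    from fastembed.embedding import FlagEmbedding  # assumed import path

    model = FlagEmbedding(model_name="BAAI/bge-small-en")  # model name is illustrative only
    documents = ["fastembed is light and fast", "progress can be tracked by the caller"]

    # embed() yields one np.ndarray per document, so wrapping the returned
    # iterable in tqdm restores per-document progress reporting externally.
    embeddings = list(tqdm(model.embed(documents, batch_size=256), total=len(documents)))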
@@ -587,8 +581,8 @@ def __init__(
         Args:
             model_name (str): The name of the model to use.
             max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
-            cache_dir (str, optional): The path to the cache directory. \
-                Can be set using the `FASTEMBED_CACHE_PATH` env variable. \
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
         Raises:
@@ -609,7 +603,7 @@ def __init__(
                                    max_threads=threads)
 
     def embed(
-        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None, show_progress: bool = True
+        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
     ) -> Iterable[np.ndarray]:
         """
         Encode a list of documents into list of embeddings.
@@ -621,7 +615,6 @@ def embed(
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                 If 0, use all available cores.
                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
-            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
         Returns:
             List of embeddings, one per document
         """
@@ -638,25 +631,21 @@ def embed(
         if parallel == 0:
             parallel = os.cpu_count()
 
-        with tqdm(total=len(documents), disable=not show_progress) as progress_bar:
-            batch_iterable = iter_batch(documents, batch_size)
-            if parallel is None or is_small:
-                for batch in batch_iterable:
-                    embeddings, attn_mask = self.model.onnx_embed(batch)
-                    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
-                    progress_bar.update(len(embeddings))
-            else:
-                start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-                params = {
-                    "path": self._model_dir,
-                    "model_name": self.model_name,
-                    "max_length": self._max_length,
-                }
-                pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
-                for batch in pool.ordered_map(batch_iterable, **params):
-                    embeddings, attn_mask = batch
-                    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
-                    progress_bar.update(len(embeddings))
+        if parallel is None or is_small:
+            for batch in iter_batch(documents, batch_size):
+                embeddings, attn_mask = self.model.onnx_embed(batch)
+                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
+        else:
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            params = {
+                "path": self._model_dir,
+                "model_name": self.model_name,
+                "max_length": self._max_length,
+            }
+            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
+            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
+                embeddings, attn_mask = batch
+                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
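
For the data-parallel path described in the docstring above, a short sketch (not part of this diff) of how a caller might use it. DefaultEmbedding and the corpus are illustrative assumptions:

    import numpy as np

    from fastembed.embedding import DefaultEmbedding  # assumed import path

    model = DefaultEmbedding()
    corpus = [f"document number {i}" for i in range(10_000)]

    # parallel=0 uses all available cores; parallel=None keeps the default
    # onnxruntime threading (see the Args section above).
    vectors = np.stack(list(model.embed(corpus, batch_size=256, parallel=0)))
    print(vectors.shape)  # (10000, <embedding dimension>)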
