fix: add sentence trimming to OpenAIWrapper #1526

Merged: 18 commits, Dec 4, 2024
Changes from 9 commits
48 changes: 42 additions & 6 deletions mteb/models/openai_models.py
@@ -15,27 +15,48 @@


class OpenAIWrapper(Wrapper):
def __init__(self, model_name: str, embed_dim: int | None = None, **kwargs) -> None:
def __init__(
self,
model_name: str,
max_tokens: int,
tokenizer_name: str = "cl100k_base", # since all models use this tokenizer now
embed_dim: int | None = None,
**kwargs,
) -> None:
requires_package(self, "openai", "Openai text embedding")
from openai import OpenAI

self._client = OpenAI()
self._model_name = model_name
self._embed_dim = embed_dim
self._max_tokens = max_tokens
self._tokenizer_name = tokenizer_name

def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
requires_package(self, "openai", "Openai text embedding")
requires_package(self, "tiktoken", "Tiktoken package")
import tiktoken
from openai import NotGiven

if self._model_name == "text-embedding-ada-002" and self._embed_dim is not None:
logger.warning(
"Reducing embedding size available only for text-embedding-3-* models"
)

trimmed_sentences = []
for sentence in sentences:
encoding = tiktoken.get_encoding(self._tokenizer_name)
encoded_sentence = encoding.encode(sentence)
if len(encoded_sentence) > self._max_tokens:
trimmed_sentence = encoding.decode(encoded_sentence[: self._max_tokens])
trimmed_sentences.append(trimmed_sentence)
else:
trimmed_sentences.append(sentence)

max_batch_size = 2048
sublists = [
sentences[i : i + max_batch_size]
for i in range(0, len(sentences), max_batch_size)
trimmed_sentences[i : i + max_batch_size]
for i in range(0, len(trimmed_sentences), max_batch_size)
]

all_embeddings = []
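
As a side note, the trimming step added in this hunk can be exercised on its own. A minimal sketch, assuming tiktoken is installed and using a deliberately small, hypothetical token budget of 8 so the cut is visible (the wrapper itself passes the model's real limit, e.g. 8192):

    import tiktoken

    max_tokens = 8  # hypothetical small budget for illustration only
    encoding = tiktoken.get_encoding("cl100k_base")  # tokenizer used by current OpenAI embedding models

    sentence = "A deliberately long sentence that will not fit into this small token budget."
    tokens = encoding.encode(sentence)
    if len(tokens) > max_tokens:
        # keep only the first max_tokens tokens, mirroring the loop above
        sentence = encoding.decode(tokens[:max_tokens])

    print(sentence)  # the sentence, cut off after 8 tokens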
@@ -60,7 +81,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
revision="1",
release_date="2024-01-25",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-3-small"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-3-small",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
max_tokens=8191,
embed_dim=1536,
open_weights=False,
@@ -77,7 +103,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
revision="1",
release_date="2024-01-25",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-3-large"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-3-large",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
max_tokens=8191,
embed_dim=3072,
open_weights=False,
@@ -91,7 +122,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
revision="1",
release_date="2022-12-15",
languages=None, # supported languages not specified
loader=partial(OpenAIWrapper, model_name="text-embedding-ada-002"),
loader=partial(
OpenAIWrapper,
model_name="text-embedding-ada-002",
tokenizer_name="cl100k_base",
max_tokens=8192,
),
max_tokens=8191,
embed_dim=1536,
open_weights=False,
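
For reference, a hedged sketch of how the updated wrapper might be constructed directly now that max_tokens is required; it mirrors the loader partials above, assumes the module path shown in the diff header, and requires the OPENAI_API_KEY environment variable to be set. The input string is illustrative.

    from mteb.models.openai_models import OpenAIWrapper

    # mirrors the loader partials above; the OpenAI client reads OPENAI_API_KEY
    wrapper = OpenAIWrapper(
        model_name="text-embedding-3-small",
        max_tokens=8192,
        tokenizer_name="cl100k_base",
    )

    # inputs longer than 8192 tokens are trimmed before the API call
    embeddings = wrapper.encode(["A very long document ..."])
    print(embeddings.shape)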