Skip to content

Commit

Permalink
feat: added dimensions parameter to Azure OpenAI Embedders (#7449)
Browse files Browse the repository at this point in the history
* added dimensions parameter to AzureOpenAIEmbedders

* created releasenote

* update release note

---------

Co-authored-by: Julian Risch <[email protected]>
  • Loading branch information
nickprock and julian-risch authored Apr 2, 2024
1 parent 6e28969 commit 42c5b7a
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 2 deletions.
12 changes: 11 additions & 1 deletion haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
Expand All @@ -53,6 +54,8 @@ def __init__(
The version of the API to use.
:param azure_deployment:
The deployment of the model, usually matches the model name.
:param dimensions:
The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param api_key:
The API key used for authentication.
:param azure_ad_token:
Expand Down Expand Up @@ -90,6 +93,7 @@ def __init__(
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.prefix = prefix
self.suffix = suffix
Expand Down Expand Up @@ -124,6 +128,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
dimensions=self.dimensions,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
Expand Down Expand Up @@ -175,7 +180,12 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
batch = texts_to_embed[i : i + batch_size]
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)
if self.dimensions is not None:
response = self._client.embeddings.create(
model=self.azure_deployment, dimensions=self.dimensions, input=batch
)
else:
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)

# Append embeddings to the list
all_embeddings.extend(el.embedding for el in response.data)
Expand Down
12 changes: 11 additions & 1 deletion haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
dimensions: Optional[int] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
Expand All @@ -48,6 +49,8 @@ def __init__(
The version of the API to use.
:param azure_deployment:
The deployment of the model, usually matches the model name.
:param dimensions:
The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
:param api_key:
The API key used for authentication.
:param azure_ad_token:
Expand Down Expand Up @@ -80,6 +83,7 @@ def __init__(
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.dimensions = dimensions
self.organization = organization
self.prefix = prefix
self.suffix = suffix
Expand Down Expand Up @@ -110,6 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
dimensions=self.dimensions,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
Expand Down Expand Up @@ -156,7 +161,12 @@ def run(self, text: str):
# finally, replace newlines as recommended by OpenAI docs
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")

response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)
if self.dimensions is not None:
response = self._client.embeddings.create(
model=self.azure_deployment, dimensions=self.dimensions, input=processed_text
)
else:
response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)

return {
"embedding": response.data[0].embedding,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add the dimensions parameter to the Azure OpenAI Embedders (AzureOpenAITextEmbedder and AzureOpenAIDocumentEmbedder) to fully support new embedding models such as text-embedding-3-small and text-embedding-3-large, as well as upcoming ones.
2 changes: 2 additions & 0 deletions test/components/embedders/test_azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.dimensions is None
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
Expand All @@ -30,6 +31,7 @@ def test_to_dict(self, monkeypatch):
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"organization": None,
"prefix": "",
Expand Down
2 changes: 2 additions & 0 deletions test/components/embedders/test_azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_init_default(self, monkeypatch):

assert embedder._client.api_key == "fake-api-key"
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.dimensions is None
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
Expand All @@ -26,6 +27,7 @@ def test_to_dict(self, monkeypatch):
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"azure_deployment": "text-embedding-ada-002",
"dimensions": None,
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"api_version": "2023-05-15",
Expand Down

0 comments on commit 42c5b7a

Please sign in to comment.