Skip to content

Commit

Permalink
ci: patch embedding issue in tests (#1096)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Mar 5, 2024
1 parent d3d3612 commit 3fd568e
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 6 deletions.
23 changes: 21 additions & 2 deletions memgpt/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,25 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
credentials = MemGPTCredentials.load()

if endpoint_type == "openai":
assert credentials.openai_key is not None
from llama_index.embeddings.openai import OpenAIEmbedding

additional_kwargs = {"user_id": user_id} if user_id else {}
model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=credentials.openai_key, additional_kwargs=additional_kwargs)
model = OpenAIEmbedding(
api_base=config.embedding_endpoint,
api_key=credentials.openai_key,
additional_kwargs=additional_kwargs,
)
return model

elif endpoint_type == "azure":
assert all(
[
credentials.azure_key is not None,
credentials.azure_embedding_endpoint is not None,
credentials.azure_version is not None,
]
)
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
Expand All @@ -176,7 +189,13 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
azure_endpoint=credentials.azure_endpoint,
api_version=credentials.azure_version,
)

elif endpoint_type == "hugging-face":
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
return EmbeddingEndpoint(
model=config.embedding_model,
base_url=config.embedding_endpoint,
user=user_id,
)

else:
return default_embedding_model()
Loading

0 comments on commit 3fd568e

Please sign in to comment.