Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add MistralProvider #1883

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions letta/llm_api/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import requests

from letta.utils import printd, smart_urljoin


def mistral_get_model_list(url: str, api_key: str) -> dict:
url = smart_urljoin(url, "models")

headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"

printd(f"Sending request to {url}")
response = None
try:
# TODO add query param "tool" to be true
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response_json = response.json() # convert to dict from string
return response_json
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
if response:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
if response:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
if response:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
raise e
44 changes: 44 additions & 0 deletions letta/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,50 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
return []


class MistralProvider(Provider):
name: str = "mistral"
api_key: str = Field(..., description="API key for the Mistral API.")
base_url: str = "https://api.mistral.ai/v1"

def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.mistral import mistral_get_model_list

# Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
# See: https://openrouter.ai/docs/requests
response = mistral_get_model_list(self.base_url, api_key=self.api_key)

assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

configs = []
for model in response["data"]:
# If model has chat completions and function calling enabled
if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
configs.append(
LLMConfig(
model=model["id"],
model_endpoint_type="openai",
model_endpoint=self.base_url,
context_window=model["max_context_length"],
)
)

return configs

def list_embedding_models(self) -> List[EmbeddingConfig]:
# Not supported for mistral
return []

def get_model_context_window(self, model_name: str) -> Optional[int]:
# Redoing this is fine because it's a pretty lightweight call
models = self.list_llm_models()

for m in models:
if model_name in m["id"]:
return int(m["max_context_length"])

return None


class OllamaProvider(OpenAIProvider):
name: str = "ollama"
base_url: str = Field(..., description="Base URL for the Ollama API.")
Expand Down
17 changes: 14 additions & 3 deletions tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from letta.providers import (
AnthropicProvider,
AzureProvider,
GoogleAIProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
)
Expand Down Expand Up @@ -31,10 +33,13 @@ def test_anthropic():
#


# TODO: Add this test
# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
def test_azure():
pass
provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
models = provider.list_llm_models()
print([m.model for m in models])

embed_models = provider.list_embedding_models()
print([m.embedding_model for m in embed_models])


def test_ollama():
Expand All @@ -54,6 +59,12 @@ def test_googleai():
provider.list_embedding_models()


def test_mistral():
provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
models = provider.list_llm_models()
print([m.model for m in models])


# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()
Expand Down
Loading