feat: list available embedding/LLM models for ollama (#1840)
sarahwooders authored Oct 8, 2024
1 parent 7200bc2 commit 2345021
Showing 5 changed files with 74 additions and 148 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test_ollama.yml
@@ -1,5 +1,8 @@
 name: Endpoint (Ollama)
 
+env:
+  OLLAMA_BASE_URL: "http://localhost:11434"
+
 on:
   push:
     branches: [ main ]
@@ -36,3 +39,7 @@ jobs:
       - name: Test embedding endpoint
         run: |
           poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_ollama
+      - name: Test provider
+        run: |
+          poetry run pytest -s -vv tests/test_providers.py::test_ollama
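Note: the new env block and "Test provider" step assume an Ollama server is already listening at http://localhost:11434 when pytest runs; the workflow itself does not start the server. A minimal reachability check expressing that precondition (a hypothetical standalone script, not part of this commit):

import os

import requests

# Fail fast if the Ollama server the test step expects is unreachable.
base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
resp = requests.get(f"{base_url}/api/tags", timeout=5)
resp.raise_for_status()
print("Ollama reachable; local models:", [m["name"] for m in resp.json()["models"]])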
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -69,4 +69,4 @@ jobs:
          LETTA_SERVER_PASS: test_server_token
          PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
        run: |
-          poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
+          poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests
72 changes: 51 additions & 21 deletions letta/providers.py
@@ -122,34 +122,64 @@ def get_model_context_window(self, model_name: str):
         response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
         response_json = response.json()
 
-        # thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # ChatGLM2
-            "seq_length",
-            # Command-R
-            "model_max_length",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-
+        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+        # possible_keys = [
+        #     # OPT
+        #     "max_position_embeddings",
+        #     # GPT-2
+        #     "n_positions",
+        #     # MPT
+        #     "max_seq_len",
+        #     # ChatGLM2
+        #     "seq_length",
+        #     # Command-R
+        #     "model_max_length",
+        #     # Others
+        #     "max_sequence_length",
+        #     "max_seq_length",
+        #     "seq_len",
+        # ]
+        # max_position_embeddings
+        # parse model cards: nous, dolphin, llama
         for key, value in response_json["model_info"].items():
-            if "context_window" in key:
+            if "context_length" in key:
                 return value
         return None
 
+    def get_model_embedding_dim(self, model_name: str):
+        import requests
+
+        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
+        response_json = response.json()
+        for key, value in response_json["model_info"].items():
+            if "embedding_length" in key:
+                return value
+        return None
+
     def list_embedding_models(self) -> List[EmbeddingConfig]:
-        # TODO: filter embedding models
-        return []
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            embedding_dim = self.get_model_embedding_dim(model["name"])
+            if not embedding_dim:
+                continue
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model["name"],
+                    embedding_endpoint_type="ollama",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=embedding_dim,
+                    embedding_chunk_size=300,
+                )
+            )
+        return configs
 
 
 class GroqProvider(OpenAIProvider):
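Both new methods follow the same pattern: POST to Ollama's /api/show endpoint and scan the returned model_info dictionary for a substring match, since its keys are prefixed by model architecture (for example, llama.context_length and llama.embedding_length; exact key names vary by model and are illustrative here). A standalone sketch of that probing pattern, assuming a local server with at least one pulled model:

import requests

BASE_URL = "http://localhost:11434"  # assumption: default local Ollama endpoint

for model in requests.get(f"{BASE_URL}/api/tags").json()["models"]:
    name = model["name"]
    # Same request shape the provider uses in get_model_context_window /
    # get_model_embedding_dim.
    info = requests.post(f"{BASE_URL}/api/show", json={"name": name, "verbose": True}).json()["model_info"]
    # Substring match, mirroring the provider: keys arrive architecture-prefixed,
    # e.g. "llama.context_length" rather than a bare "context_length".
    context_window = next((v for k, v in info.items() if "context_length" in k), None)
    embedding_dim = next((v for k, v in info.items() if "embedding_length" in k), None)
    print(f"{name}: context_window={context_window}, embedding_dim={embedding_dim}")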
110 changes: 0 additions & 110 deletions letta/settings.py
@@ -55,116 +55,6 @@ class Settings(BaseSettings):
     pg_port: Optional[int] = None
     pg_uri: Optional[str] = None  # option to specify full uri
 
-    ## llm configuration
-    # llm_endpoint: Optional[str] = None
-    # llm_endpoint_type: Optional[str] = None
-    # llm_model: Optional[str] = None
-    # llm_context_window: Optional[int] = None
-
-    ## embedding configuration
-    # embedding_endpoint: Optional[str] = None
-    # embedding_endpoint_type: Optional[str] = None
-    # embedding_dim: Optional[int] = None
-    # embedding_model: Optional[str] = None
-    # embedding_chunk_size: int = 300
-
-    # @property
-    # def llm_config(self):
-
-    #     # try to get LLM config from settings
-    #     if self.llm_endpoint and self.llm_endpoint_type and self.llm_model and self.llm_context_window:
-    #         return LLMConfig(
-    #             model=self.llm_model,
-    #             model_endpoint_type=self.llm_endpoint_type,
-    #             model_endpoint=self.llm_endpoint,
-    #             model_wrapper=None,
-    #             context_window=self.llm_context_window,
-    #         )
-    #     else:
-    #         if not self.llm_endpoint:
-    #             printd(f"No LETTA_LLM_ENDPOINT provided")
-    #         if not self.llm_endpoint_type:
-    #             printd(f"No LETTA_LLM_ENDPOINT_TYPE provided")
-    #         if not self.llm_model:
-    #             printd(f"No LETTA_LLM_MODEL provided")
-    #         if not self.llm_context_window:
-    #             printd(f"No LETTA_LLM_CONTEX_WINDOW provided")
-
-    #     # quickstart options
-    #     if self.llm_model:
-    #         try:
-    #             return LLMConfig.default_config(self.llm_model)
-    #         except ValueError:
-    #             pass
-
-    #     # try to read from config file (last resort)
-    #     from letta.config import LettaConfig
-
-    #     if LettaConfig.exists():
-    #         config = LettaConfig.load()
-    #         llm_config = LLMConfig(
-    #             model=config.default_llm_config.model,
-    #             model_endpoint_type=config.default_llm_config.model_endpoint_type,
-    #             model_endpoint=config.default_llm_config.model_endpoint,
-    #             model_wrapper=config.default_llm_config.model_wrapper,
-    #             context_window=config.default_llm_config.context_window,
-    #         )
-    #         return llm_config
-
-    #     # check OpenAI API key
-    #     if os.getenv("OPENAI_API_KEY"):
-    #         return LLMConfig.default_config(self.llm_model if self.llm_model else "gpt-4")
-
-    #     return LLMConfig.default_config("letta")
-
-    # @property
-    # def embedding_config(self):
-
-    #     # try to get LLM config from settings
-    #     if self.embedding_endpoint and self.embedding_endpoint_type and self.embedding_model and self.embedding_dim:
-    #         return EmbeddingConfig(
-    #             embedding_model=self.embedding_model,
-    #             embedding_endpoint_type=self.embedding_endpoint_type,
-    #             embedding_endpoint=self.embedding_endpoint,
-    #             embedding_dim=self.embedding_dim,
-    #             embedding_chunk_size=self.embedding_chunk_size,
-    #         )
-    #     else:
-    #         if not self.embedding_endpoint:
-    #             printd(f"No LETTA_EMBEDDING_ENDPOINT provided")
-    #         if not self.embedding_endpoint_type:
-    #             printd(f"No LETTA_EMBEDDING_ENDPOINT_TYPE provided")
-    #         if not self.embedding_model:
-    #             printd(f"No LETTA_EMBEDDING_MODEL provided")
-    #         if not self.embedding_dim:
-    #             printd(f"No LETTA_EMBEDDING_DIM provided")
-
-    #     # TODO
-    #     ## quickstart options
-    #     # if self.embedding_model:
-    #     #     try:
-    #     #         return EmbeddingConfig.default_config(self.embedding_model)
-    #     #     except ValueError as e:
-    #     #         pass
-
-    #     # try to read from config file (last resort)
-    #     from letta.config import LettaConfig
-
-    #     if LettaConfig.exists():
-    #         config = LettaConfig.load()
-    #         return EmbeddingConfig(
-    #             embedding_model=config.default_embedding_config.embedding_model,
-    #             embedding_endpoint_type=config.default_embedding_config.embedding_endpoint_type,
-    #             embedding_endpoint=config.default_embedding_config.embedding_endpoint,
-    #             embedding_dim=config.default_embedding_config.embedding_dim,
-    #             embedding_chunk_size=config.default_embedding_config.embedding_chunk_size,
-    #         )
-
-    #     if os.getenv("OPENAI_API_KEY"):
-    #         return EmbeddingConfig.default_config(self.embedding_model if self.embedding_model else "text-embedding-ada-002")
-
-    #     return EmbeddingConfig.default_config("letta")
-
     @property
     def letta_pg_uri(self) -> str:
         if self.pg_uri:
31 changes: 15 additions & 16 deletions tests/test_providers.py
@@ -1,6 +1,11 @@
 import os
 
-from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider
+from letta.providers import (
+    AnthropicProvider,
+    GoogleAIProvider,
+    OllamaProvider,
+    OpenAIProvider,
+)
 
 
 def test_openai():
@@ -24,24 +29,18 @@ def test_anthropic():
 #     print(models)
 #
 #
-# def test_ollama():
-#     provider = OllamaProvider()
-#     models = provider.list_llm_models()
-#     print(models)
-#
-#
+def test_ollama():
+    provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
+    models = provider.list_llm_models()
+    print(models)
+
+    embedding_models = provider.list_embedding_models()
+    print(embedding_models)
 
 
 def test_googleai():
     provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
     models = provider.list_llm_models()
     print(models)
 
     provider.list_embedding_models()
 
 
-#
-#
-test_googleai()
-# test_ollama()
-# test_groq()
-# test_openai()
-# test_anthropic()

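To reproduce the new test outside CI, the same calls can be made directly. This sketch assumes a running Ollama server and at least one pulled model whose model_info reports an embedding_length:

import os

from letta.providers import OllamaProvider

# Mirrors tests/test_providers.py::test_ollama, with a default fallback URL.
provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
print(provider.list_llm_models())        # one config per local model
print(provider.list_embedding_models())  # only models with a discoverable embedding dim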