feat: list available embedding/LLM models for ollama #1840

Merged: 6 commits, Oct 8, 2024

7 changes: 7 additions & 0 deletions .github/workflows/test_ollama.yml
@@ -1,5 +1,8 @@
name: Endpoint (Ollama)

env:
  OLLAMA_BASE_URL: "http://localhost:11434"

on:
  push:
    branches: [ main ]
@@ -36,3 +39,7 @@ jobs:
      - name: Test embedding endpoint
        run: |
          poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_ollama

      - name: Test provider
        run: |
          poetry run pytest -s -vv tests/test_providers.py::test_ollama
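
The new "Test provider" step assumes an Ollama server is already reachable at OLLAMA_BASE_URL with at least one model pulled; earlier steps of this workflow (outside these hunks) are expected to take care of that. A minimal sketch of that precondition check, with the fallback URL and the requests usage assumed rather than taken from this diff:

import os

import requests

# Assumed to mirror the workflow env var; falls back to Ollama's usual local port.
base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

# GET /api/tags lists locally pulled models; a connection error or an empty list
# means tests/test_providers.py::test_ollama would have nothing to exercise.
resp = requests.get(f"{base_url}/api/tags", timeout=5)
resp.raise_for_status()
print([m["name"] for m in resp.json().get("models", [])])
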
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -69,4 +69,4 @@ jobs:
          LETTA_SERVER_PASS: test_server_token
          PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
        run: |
          poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
          poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests
72 changes: 51 additions & 21 deletions letta/providers.py
@@ -122,34 +122,64 @@ def get_model_context_window(self, model_name: str):
        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()

        # thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # ChatGLM2
            "seq_length",
            # Command-R
            "model_max_length",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]

        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
        # possible_keys = [
        #     # OPT
        #     "max_position_embeddings",
        #     # GPT-2
        #     "n_positions",
        #     # MPT
        #     "max_seq_len",
        #     # ChatGLM2
        #     "seq_length",
        #     # Command-R
        #     "model_max_length",
        #     # Others
        #     "max_sequence_length",
        #     "max_seq_length",
        #     "seq_len",
        # ]
        # max_position_embeddings
        # parse model cards: nous, dolphon, llama
        for key, value in response_json["model_info"].items():
            if "context_window" in key:
            if "context_length" in key:
                return value
        return None

    def get_model_embedding_dim(self, model_name: str):
        import requests

        response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True})
        response_json = response.json()
        for key, value in response_json["model_info"].items():
            if "embedding_length" in key:
                return value
        return None

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # TODO: filter embedding models
        return []
        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
        import requests

        response = requests.get(f"{self.base_url}/api/tags")
        if response.status_code != 200:
            raise Exception(f"Failed to list Ollama models: {response.text}")
        response_json = response.json()

        configs = []
        for model in response_json["models"]:
            embedding_dim = self.get_model_embedding_dim(model["name"])
            if not embedding_dim:
                continue
            configs.append(
                EmbeddingConfig(
                    embedding_model=model["name"],
                    embedding_endpoint_type="ollama",
                    embedding_endpoint=self.base_url,
                    embedding_dim=embedding_dim,
                    embedding_chunk_size=300,
                )
            )
        return configs


class GroqProvider(OpenAIProvider):
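A note on the key matching above: Ollama's /api/show response nests these values under model_info with architecture-prefixed keys (for example llama.context_length or bert.embedding_length), which is why get_model_embedding_dim and the updated get_model_context_window scan for a substring instead of a fixed key. A standalone sketch of that lookup; the prefix format, default URL, and model name are assumptions rather than part of this diff:

import requests

base_url = "http://localhost:11434"  # assumed local default, matching the workflow env var
model_name = "mxbai-embed-large"  # illustrative; any locally pulled model works

# POST /api/show returns model metadata; "verbose": True matches the provider's call.
resp = requests.post(f"{base_url}/api/show", json={"name": model_name, "verbose": True})
model_info = resp.json().get("model_info", {})

# Mirror the provider's suffix matching, since the architecture prefix varies per model.
embedding_dim = next((v for k, v in model_info.items() if "embedding_length" in k), None)
context_window = next((v for k, v in model_info.items() if "context_length" in k), None)
print(embedding_dim, context_window)
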
110 changes: 0 additions & 110 deletions letta/settings.py
@@ -55,116 +55,6 @@ class Settings(BaseSettings):
pg_port: Optional[int] = None
pg_uri: Optional[str] = None # option to specifiy full uri

## llm configuration
# llm_endpoint: Optional[str] = None
# llm_endpoint_type: Optional[str] = None
# llm_model: Optional[str] = None
# llm_context_window: Optional[int] = None

## embedding configuration
# embedding_endpoint: Optional[str] = None
# embedding_endpoint_type: Optional[str] = None
# embedding_dim: Optional[int] = None
# embedding_model: Optional[str] = None
# embedding_chunk_size: int = 300

# @property
# def llm_config(self):

# # try to get LLM config from settings
# if self.llm_endpoint and self.llm_endpoint_type and self.llm_model and self.llm_context_window:
# return LLMConfig(
# model=self.llm_model,
# model_endpoint_type=self.llm_endpoint_type,
# model_endpoint=self.llm_endpoint,
# model_wrapper=None,
# context_window=self.llm_context_window,
# )
# else:
# if not self.llm_endpoint:
# printd(f"No LETTA_LLM_ENDPOINT provided")
# if not self.llm_endpoint_type:
# printd(f"No LETTA_LLM_ENDPOINT_TYPE provided")
# if not self.llm_model:
# printd(f"No LETTA_LLM_MODEL provided")
# if not self.llm_context_window:
# printd(f"No LETTA_LLM_CONTEX_WINDOW provided")

# # quickstart options
# if self.llm_model:
# try:
# return LLMConfig.default_config(self.llm_model)
# except ValueError:
# pass

# # try to read from config file (last resort)
# from letta.config import LettaConfig

# if LettaConfig.exists():
# config = LettaConfig.load()
# llm_config = LLMConfig(
# model=config.default_llm_config.model,
# model_endpoint_type=config.default_llm_config.model_endpoint_type,
# model_endpoint=config.default_llm_config.model_endpoint,
# model_wrapper=config.default_llm_config.model_wrapper,
# context_window=config.default_llm_config.context_window,
# )
# return llm_config

# # check OpenAI API key
# if os.getenv("OPENAI_API_KEY"):
# return LLMConfig.default_config(self.llm_model if self.llm_model else "gpt-4")

# return LLMConfig.default_config("letta")

# @property
# def embedding_config(self):

# # try to get LLM config from settings
# if self.embedding_endpoint and self.embedding_endpoint_type and self.embedding_model and self.embedding_dim:
# return EmbeddingConfig(
# embedding_model=self.embedding_model,
# embedding_endpoint_type=self.embedding_endpoint_type,
# embedding_endpoint=self.embedding_endpoint,
# embedding_dim=self.embedding_dim,
# embedding_chunk_size=self.embedding_chunk_size,
# )
# else:
# if not self.embedding_endpoint:
# printd(f"No LETTA_EMBEDDING_ENDPOINT provided")
# if not self.embedding_endpoint_type:
# printd(f"No LETTA_EMBEDDING_ENDPOINT_TYPE provided")
# if not self.embedding_model:
# printd(f"No LETTA_EMBEDDING_MODEL provided")
# if not self.embedding_dim:
# printd(f"No LETTA_EMBEDDING_DIM provided")

# # TODO
# ## quickstart options
# # if self.embedding_model:
# # try:
# # return EmbeddingConfig.default_config(self.embedding_model)
# # except ValueError as e:
# # pass

# # try to read from config file (last resort)
# from letta.config import LettaConfig

# if LettaConfig.exists():
# config = LettaConfig.load()
# return EmbeddingConfig(
# embedding_model=config.default_embedding_config.embedding_model,
# embedding_endpoint_type=config.default_embedding_config.embedding_endpoint_type,
# embedding_endpoint=config.default_embedding_config.embedding_endpoint,
# embedding_dim=config.default_embedding_config.embedding_dim,
# embedding_chunk_size=config.default_embedding_config.embedding_chunk_size,
# )

# if os.getenv("OPENAI_API_KEY"):
# return EmbeddingConfig.default_config(self.embedding_model if self.embedding_model else "text-embedding-ada-002")

# return EmbeddingConfig.default_config("letta")

@property
def letta_pg_uri(self) -> str:
if self.pg_uri:
31 changes: 15 additions & 16 deletions tests/test_providers.py
@@ -1,6 +1,11 @@
import os

from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider
from letta.providers import (
    AnthropicProvider,
    GoogleAIProvider,
    OllamaProvider,
    OpenAIProvider,
)


def test_openai():
@@ -24,24 +29,18 @@ def test_anthropic():
# print(models)
#
#
# def test_ollama():
# provider = OllamaProvider()
# models = provider.list_llm_models()
# print(models)
#
#
def test_ollama():
    provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL"))
    models = provider.list_llm_models()
    print(models)

    embedding_models = provider.list_embedding_models()
    print(embedding_models)


def test_googleai():
    provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY"))
    models = provider.list_llm_models()
    print(models)

    provider.list_embedding_models()


#
#
test_googleai()
# test_ollama()
# test_groq()
# test_openai()
# test_anthropic()