From 2345021aafeced7bdc313b39a712669236a88f03 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 8 Oct 2024 14:57:11 -0700 Subject: [PATCH] feat: list available embedding/LLM models for ollama (#1840) --- .github/workflows/test_ollama.yml | 7 ++ .github/workflows/tests.yml | 2 +- letta/providers.py | 72 +++++++++++++------ letta/settings.py | 110 ------------------------------ tests/test_providers.py | 31 ++++----- 5 files changed, 74 insertions(+), 148 deletions(-) diff --git a/.github/workflows/test_ollama.yml b/.github/workflows/test_ollama.yml index baccde40b0..370c97b52f 100644 --- a/.github/workflows/test_ollama.yml +++ b/.github/workflows/test_ollama.yml @@ -1,5 +1,8 @@ name: Endpoint (Ollama) +env: + OLLAMA_BASE_URL: "http://localhost:11434" + on: push: branches: [ main ] @@ -36,3 +39,7 @@ jobs: - name: Test embedding endpoint run: | poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_ollama + + - name: Test provider + run: | + poetry run pytest -s -vv tests/test_providers.py::test_ollama diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 24a9b7db85..1c3580414d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,4 +69,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests + poetry run pytest -s -vv -k "not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client and not test_providers" tests diff --git a/letta/providers.py b/letta/providers.py index ac6170629c..bfa3883a76 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -122,34 +122,64 @@ def get_model_context_window(self, model_name: str): response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) response_json = response.json() - # thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675 - possible_keys = [ - # OPT - "max_position_embeddings", - # GPT-2 - "n_positions", - # MPT - "max_seq_len", - # ChatGLM2 - "seq_length", - # Command-R - "model_max_length", - # Others - "max_sequence_length", - "max_seq_length", - "seq_len", - ] - + ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675 + # possible_keys = [ + # # OPT + # "max_position_embeddings", + # # GPT-2 + # "n_positions", + # # MPT + # "max_seq_len", + # # ChatGLM2 + # "seq_length", + # # Command-R + # "model_max_length", + # # Others + # "max_sequence_length", + # "max_seq_length", + # "seq_len", + # ] # max_position_embeddings # parse model cards: nous, dolphon, llama for key, value in response_json["model_info"].items(): - if "context_window" in key: + if "context_length" in key: + return value + return None + + def get_model_embedding_dim(self, model_name: str): + import requests + + response = requests.post(f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}) + response_json = response.json() + for key, value in response_json["model_info"].items(): + if "embedding_length" in key: return value return None def list_embedding_models(self) -> List[EmbeddingConfig]: - # TODO: filter embedding models - return [] + # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models + import requests + + response = requests.get(f"{self.base_url}/api/tags") + if response.status_code != 200: + raise Exception(f"Failed to list Ollama models: {response.text}") + response_json = response.json() + + configs = [] + for model in response_json["models"]: + embedding_dim = self.get_model_embedding_dim(model["name"]) + if not embedding_dim: + continue + configs.append( + EmbeddingConfig( + embedding_model=model["name"], + embedding_endpoint_type="ollama", + embedding_endpoint=self.base_url, + embedding_dim=embedding_dim, + embedding_chunk_size=300, + ) + ) + return configs class GroqProvider(OpenAIProvider): diff --git a/letta/settings.py b/letta/settings.py index a41dcf6974..1df531e298 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -55,116 +55,6 @@ class Settings(BaseSettings): pg_port: Optional[int] = None pg_uri: Optional[str] = None # option to specifiy full uri - ## llm configuration - # llm_endpoint: Optional[str] = None - # llm_endpoint_type: Optional[str] = None - # llm_model: Optional[str] = None - # llm_context_window: Optional[int] = None - - ## embedding configuration - # embedding_endpoint: Optional[str] = None - # embedding_endpoint_type: Optional[str] = None - # embedding_dim: Optional[int] = None - # embedding_model: Optional[str] = None - # embedding_chunk_size: int = 300 - - # @property - # def llm_config(self): - - # # try to get LLM config from settings - # if self.llm_endpoint and self.llm_endpoint_type and self.llm_model and self.llm_context_window: - # return LLMConfig( - # model=self.llm_model, - # model_endpoint_type=self.llm_endpoint_type, - # model_endpoint=self.llm_endpoint, - # model_wrapper=None, - # context_window=self.llm_context_window, - # ) - # else: - # if not self.llm_endpoint: - # printd(f"No LETTA_LLM_ENDPOINT provided") - # if not self.llm_endpoint_type: - # printd(f"No LETTA_LLM_ENDPOINT_TYPE provided") - # if not self.llm_model: - # printd(f"No LETTA_LLM_MODEL provided") - # if not self.llm_context_window: - # printd(f"No LETTA_LLM_CONTEX_WINDOW provided") - - # # quickstart options - # if self.llm_model: - # try: - # return LLMConfig.default_config(self.llm_model) - # except ValueError: - # pass - - # # try to read from config file (last resort) - # from letta.config import LettaConfig - - # if LettaConfig.exists(): - # config = LettaConfig.load() - # llm_config = LLMConfig( - # model=config.default_llm_config.model, - # model_endpoint_type=config.default_llm_config.model_endpoint_type, - # model_endpoint=config.default_llm_config.model_endpoint, - # model_wrapper=config.default_llm_config.model_wrapper, - # context_window=config.default_llm_config.context_window, - # ) - # return llm_config - - # # check OpenAI API key - # if os.getenv("OPENAI_API_KEY"): - # return LLMConfig.default_config(self.llm_model if self.llm_model else "gpt-4") - - # return LLMConfig.default_config("letta") - - # @property - # def embedding_config(self): - - # # try to get LLM config from settings - # if self.embedding_endpoint and self.embedding_endpoint_type and self.embedding_model and self.embedding_dim: - # return EmbeddingConfig( - # embedding_model=self.embedding_model, - # embedding_endpoint_type=self.embedding_endpoint_type, - # embedding_endpoint=self.embedding_endpoint, - # embedding_dim=self.embedding_dim, - # embedding_chunk_size=self.embedding_chunk_size, - # ) - # else: - # if not self.embedding_endpoint: - # printd(f"No LETTA_EMBEDDING_ENDPOINT provided") - # if not self.embedding_endpoint_type: - # printd(f"No LETTA_EMBEDDING_ENDPOINT_TYPE provided") - # if not self.embedding_model: - # printd(f"No LETTA_EMBEDDING_MODEL provided") - # if not self.embedding_dim: - # printd(f"No LETTA_EMBEDDING_DIM provided") - - # # TODO - # ## quickstart options - # # if self.embedding_model: - # # try: - # # return EmbeddingConfig.default_config(self.embedding_model) - # # except ValueError as e: - # # pass - - # # try to read from config file (last resort) - # from letta.config import LettaConfig - - # if LettaConfig.exists(): - # config = LettaConfig.load() - # return EmbeddingConfig( - # embedding_model=config.default_embedding_config.embedding_model, - # embedding_endpoint_type=config.default_embedding_config.embedding_endpoint_type, - # embedding_endpoint=config.default_embedding_config.embedding_endpoint, - # embedding_dim=config.default_embedding_config.embedding_dim, - # embedding_chunk_size=config.default_embedding_config.embedding_chunk_size, - # ) - - # if os.getenv("OPENAI_API_KEY"): - # return EmbeddingConfig.default_config(self.embedding_model if self.embedding_model else "text-embedding-ada-002") - - # return EmbeddingConfig.default_config("letta") - @property def letta_pg_uri(self) -> str: if self.pg_uri: diff --git a/tests/test_providers.py b/tests/test_providers.py index fecacd79c0..f7143f6ab9 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,6 +1,11 @@ import os -from letta.providers import AnthropicProvider, GoogleAIProvider, OpenAIProvider +from letta.providers import ( + AnthropicProvider, + GoogleAIProvider, + OllamaProvider, + OpenAIProvider, +) def test_openai(): @@ -24,24 +29,18 @@ def test_anthropic(): # print(models) # # -# def test_ollama(): -# provider = OllamaProvider() -# models = provider.list_llm_models() -# print(models) -# -# +def test_ollama(): + provider = OllamaProvider(base_url=os.getenv("OLLAMA_BASE_URL")) + models = provider.list_llm_models() + print(models) + + embedding_models = provider.list_embedding_models() + print(embedding_models) + + def test_googleai(): provider = GoogleAIProvider(api_key=os.getenv("GEMINI_API_KEY")) models = provider.list_llm_models() print(models) provider.list_embedding_models() - - -# -# -test_googleai() -# test_ollama() -# test_groq() -# test_openai() -# test_anthropic()