From 74b2d81f8809ea707afd3e782f9504d32544af4a Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Mon, 6 Nov 2023 15:11:22 -0800
Subject: [PATCH] add ollama support (#314)

* untested
* patch
* updated
* clarified using tags in docs
* tested ollama, working
* fixed template issue by creating dummy template, also added missing context length indicator
* moved count_tokens to utils.py
* clean
---
 docs/ollama.md                            | 39 ++++++++++++++++
 memgpt/local_llm/chat_completion_proxy.py |  3 ++
 memgpt/local_llm/koboldcpp/api.py         |  8 +---
 memgpt/local_llm/llamacpp/api.py          |  8 +---
 memgpt/local_llm/ollama/api.py            | 57 +++++++++++++++++++++++
 memgpt/local_llm/ollama/settings.py       | 34 ++++++++++++++
 memgpt/local_llm/utils.py                 |  6 +++
 memgpt/local_llm/webui/api.py             |  8 +---
 mkdocs.yml                                |  1 +
 9 files changed, 143 insertions(+), 21 deletions(-)
 create mode 100644 docs/ollama.md
 create mode 100644 memgpt/local_llm/ollama/api.py
 create mode 100644 memgpt/local_llm/ollama/settings.py

diff --git a/docs/ollama.md b/docs/ollama.md
new file mode 100644
index 0000000000..95515fd6e5
--- /dev/null
+++ b/docs/ollama.md
@@ -0,0 +1,39 @@
+### MemGPT + Ollama
+
+!!! warning "Be careful when downloading Ollama models!"
+
+    Make sure to use tags when downloading Ollama models! Don't do `ollama run dolphin2.2-mistral`, do `ollama run dolphin2.2-mistral:7b-q6_K`.
+
+    If you don't specify a tag, Ollama may default to a highly compressed model variant (e.g. Q4). We highly recommend **NOT** using a compression level below Q4 (stick to Q6, Q8, or fp16 if possible). In our testing, models below Q6 start to become extremely unstable when used with MemGPT.
+
+1. Download + install [Ollama](https://github.com/jmorganca/ollama)
+2. Download a model to test with by running `ollama run <model-name>` in the terminal (check the [Ollama model library](https://ollama.ai/library) for available models)
+3. In addition to setting `OPENAI_API_BASE` and `BACKEND_TYPE`, we also need to set `OLLAMA_MODEL` (to the name of the Ollama model, including its tag)
+
+For example, if we want to use Dolphin 2.2.1 Mistral, we can download it by running:
+```sh
+# Let's use the q6_K variant
+ollama run dolphin2.2-mistral:7b-q6_K
+```
+```text
+pulling manifest
+pulling d8a5ee4aba09... 100% |█████████████████████████████████████████████████████████████████████████| (4.1/4.1 GB, 20 MB/s)
+pulling a47b02e00552... 100% |██████████████████████████████████████████████████████████████████████████████| (106/106 B, 77 B/s)
+pulling 9640c2212a51... 100% |████████████████████████████████████████████████████████████████████████████████| (41/41 B, 22 B/s)
+pulling de6bcd73f9b4... 100% |████████████████████████████████████████████████████████████████████████████████| (58/58 B, 28 B/s)
+pulling 95c3d8d4429f... 100% |█████████████████████████████████████████████████████████████████████████████| (455/455 B, 330 B/s)
+verifying sha256 digest
+writing manifest
+removing any unused layers
+success
+```
+
+In your terminal where you're running MemGPT, run:
+```sh
+# By default, Ollama runs an API server on port 11434
+export OPENAI_API_BASE=http://localhost:11434
+export BACKEND_TYPE=ollama
+
+# Make sure to add the tag!
+export OLLAMA_MODEL=dolphin2.2-mistral:7b-q6_K +``` diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 8f5af63f02..44b3c81fa5 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -8,6 +8,7 @@ from .lmstudio.api import get_lmstudio_completion from .llamacpp.api import get_llamacpp_completion from .koboldcpp.api import get_koboldcpp_completion +from .ollama.api import get_ollama_completion from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper from .utils import DotDict from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE @@ -96,6 +97,8 @@ def get_chat_completion( result = get_llamacpp_completion(prompt, grammar=grammar_name) elif HOST_TYPE == "koboldcpp": result = get_koboldcpp_completion(prompt, grammar=grammar_name) + elif HOST_TYPE == "ollama": + result = get_ollama_completion(prompt) else: raise LocalLLMError( f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py index 1e81c59388..d345217928 100644 --- a/memgpt/local_llm/koboldcpp/api.py +++ b/memgpt/local_llm/koboldcpp/api.py @@ -1,10 +1,9 @@ import os from urllib.parse import urljoin import requests -import tiktoken from .settings import SIMPLE -from ..utils import load_grammar_file +from ..utils import load_grammar_file, count_tokens from ...constants import LLM_MAX_TOKENS HOST = os.getenv("OPENAI_API_BASE") @@ -14,11 +13,6 @@ DEBUG = True -def count_tokens(s: str, model: str = "gpt-4") -> int: - encoding = tiktoken.encoding_for_model(model) - return len(encoding.encode(s)) - - def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE): """See https://lite.koboldai.net/koboldcpp_api for API spec""" prompt_tokens = count_tokens(prompt) diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py index ea51f71759..5e2218e55e 100644 --- a/memgpt/local_llm/llamacpp/api.py +++ b/memgpt/local_llm/llamacpp/api.py @@ -1,10 +1,9 @@ import os from urllib.parse import urljoin import requests -import tiktoken from .settings import SIMPLE -from ..utils import load_grammar_file +from ..utils import load_grammar_file, count_tokens from ...constants import LLM_MAX_TOKENS HOST = os.getenv("OPENAI_API_BASE") @@ -14,11 +13,6 @@ DEBUG = True -def count_tokens(s: str, model: str = "gpt-4") -> int: - encoding = tiktoken.encoding_for_model(model) - return len(encoding.encode(s)) - - def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE): """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server""" prompt_tokens = count_tokens(prompt) diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py new file mode 100644 index 0000000000..f13af5f383 --- /dev/null +++ b/memgpt/local_llm/ollama/api.py @@ -0,0 +1,57 @@ +import os +from urllib.parse import urljoin +import requests + +from .settings import SIMPLE +from ..utils import count_tokens +from ...constants import LLM_MAX_TOKENS +from ...errors import LocalLLMError + +HOST = os.getenv("OPENAI_API_BASE") +HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion +MODEL_NAME = os.getenv("OLLAMA_MODEL") # ollama API requires this in the request +OLLAMA_API_SUFFIX = "/api/generate" +DEBUG = False + + +def get_ollama_completion(prompt, settings=SIMPLE, grammar=None): + 
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server""" + prompt_tokens = count_tokens(prompt) + if prompt_tokens > LLM_MAX_TOKENS: + raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)") + + if MODEL_NAME is None: + raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')") + + # Settings for the generation, includes the prompt + stop tokens, max length, etc + request = settings + request["prompt"] = prompt + request["model"] = MODEL_NAME + + # Set grammar + if grammar is not None: + # request["grammar_string"] = load_grammar_file(grammar) + raise NotImplementedError(f"Ollama does not support grammars") + + if not HOST.startswith(("http://", "https://")): + raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://") + + try: + URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/")) + response = requests.post(URI, json=request) + if response.status_code == 200: + result = response.json() + result = result["response"] + if DEBUG: + print(f"json API response.text: {result}") + else: + raise Exception( + f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." + + f" Make sure that the ollama API server is running and reachable at {URI}." + ) + + except: + # TODO handle gracefully + raise + + return result diff --git a/memgpt/local_llm/ollama/settings.py b/memgpt/local_llm/ollama/settings.py new file mode 100644 index 0000000000..f412361ca3 --- /dev/null +++ b/memgpt/local_llm/ollama/settings.py @@ -0,0 +1,34 @@ +from ...constants import LLM_MAX_TOKENS + +# see https://github.com/jmorganca/ollama/blob/main/docs/api.md +# and https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values +SIMPLE = { + "options": { + "stop": [ + "\nUSER:", + "\nASSISTANT:", + "\nFUNCTION RETURN:", + "\nUSER", + "\nASSISTANT", + "\nFUNCTION RETURN", + "\nFUNCTION", + "\nFUNC", + "<|im_start|>", + "<|im_end|>", + "<|im_sep|>", + # '\n' + + # '', + # '<|', + # '\n#', + # '\n\n\n', + ], + "num_ctx": LLM_MAX_TOKENS, + }, + "stream": False, + # turn off Ollama's own prompt formatting + "system": "", + "template": "{{ .Prompt }}", + # "system": None, + # "template": None, + "context": None, +} diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py index 2456776171..f6c44eaef7 100644 --- a/memgpt/local_llm/utils.py +++ b/memgpt/local_llm/utils.py @@ -1,4 +1,5 @@ import os +import tiktoken class DotDict(dict): @@ -31,3 +32,8 @@ def load_grammar_file(grammar): grammar_str = file.read() return grammar_str + + +def count_tokens(s: str, model: str = "gpt-4") -> int: + encoding = tiktoken.encoding_for_model(model) + return len(encoding.encode(s)) diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py index 97a5c8858d..163c403516 100644 --- a/memgpt/local_llm/webui/api.py +++ b/memgpt/local_llm/webui/api.py @@ -1,10 +1,9 @@ import os from urllib.parse import urljoin import requests -import tiktoken from .settings import SIMPLE -from ..utils import load_grammar_file +from ..utils import load_grammar_file, count_tokens from ...constants import LLM_MAX_TOKENS HOST = os.getenv("OPENAI_API_BASE") @@ -13,11 +12,6 @@ DEBUG = False -def count_tokens(s: str, model: str = "gpt-4") -> int: - encoding = tiktoken.encoding_for_model(model) - return len(encoding.encode(s)) - - def 
get_webui_completion(prompt, settings=SIMPLE, grammar=None): """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server""" prompt_tokens = count_tokens(prompt) diff --git a/mkdocs.yml b/mkdocs.yml index f1f1712c22..6ca2b38ffa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,6 +21,7 @@ nav: - 'LM Studio': lmstudio.md - 'llama.cpp': llamacpp.md - 'koboldcpp': koboldcpp.md + - 'ollama': ollama.md - 'Troubleshooting': local_llm_faq.md - 'Integrations': - 'Autogen': autogen.md
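
For reference, here is a minimal standalone Python sketch (not part of the patch itself) of the request that `get_ollama_completion` ends up sending: a POST to Ollama's `/api/generate` endpoint with Ollama's own prompt templating disabled, reading the generated text back from the `response` field. It assumes an Ollama server at the default `http://localhost:11434`, that the `dolphin2.2-mistral:7b-q6_K` tag from the docs above has already been pulled, and uses `8192` as a stand-in for MemGPT's `LLM_MAX_TOKENS` constant; the prompt string is purely illustrative.

```python
import requests

OLLAMA_BASE = "http://localhost:11434"  # default Ollama API address (see docs/ollama.md)
MODEL = "dolphin2.2-mistral:7b-q6_K"    # always include the tag

request = {
    "model": MODEL,
    "prompt": "### Instruction:\nSay hello.\n### Response:",  # illustrative prompt only
    "stream": False,  # return one JSON object instead of a token stream
    # Disable Ollama's own prompt formatting so the caller's wrapper controls the layout
    "system": "",
    "template": "{{ .Prompt }}",
    "context": None,
    "options": {
        # Stop sequences and context window, mirroring SIMPLE in ollama/settings.py
        "stop": ["\nUSER:", "\nASSISTANT:", "<|im_start|>", "<|im_end|>"],
        "num_ctx": 8192,  # assumption: should match LLM_MAX_TOKENS in memgpt/constants.py
    },
}

response = requests.post(f"{OLLAMA_BASE}/api/generate", json=request)
response.raise_for_status()
print(response.json()["response"])
```

The empty `system` string and the pass-through `"{{ .Prompt }}"` template are the "dummy template" fix mentioned in the commit message: without them, Ollama layers the model's own chat template on top of the prompt that MemGPT's wrapper has already formatted.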