add ollama support (#314)
* untested

* patch

* updated

* clarified using tags in docs

* tested ollama, working

* fixed template issue by creating dummy template, also added missing context length indicator

* moved count_tokens to utils.py

* clean
cpacker authored Nov 6, 2023
1 parent c752242 commit 74b2d81
Showing 9 changed files with 143 additions and 21 deletions.
39 changes: 39 additions & 0 deletions docs/ollama.md
@@ -0,0 +1,39 @@
### MemGPT + Ollama

!!! warning "Be careful when downloading Ollama models!"

    Make sure to use tags when downloading Ollama models! Don't run `ollama run dolphin2.2-mistral`; instead run `ollama run dolphin2.2-mistral:7b-q6_K`.

    If you don't specify a tag, Ollama may default to a highly compressed variant of the model (e.g. Q4). We highly recommend **NOT** using heavily compressed variants (stick to Q6, Q8, or fp16 if possible): in our testing, models below Q6 start to become extremely unstable when used with MemGPT.

1. Download + install [Ollama](https://github.com/jmorganca/ollama)
2. Download a model to test with by running `ollama run <MODEL_NAME>` in the terminal (check the [Ollama model library](https://ollama.ai/library) for available models)
3. In addition to setting `OPENAI_API_BASE` and `BACKEND_TYPE`, you also need to set `OLLAMA_MODEL` (to the name of the Ollama model you downloaded)

For example, if we want to use Dolphin 2.2.1 Mistral, we can download it by running:
```sh
# Let's use the q6_K variant
ollama run dolphin2.2-mistral:7b-q6_K
```
```text
pulling manifest
pulling d8a5ee4aba09... 100% |█████████████████████████████████████████████████████████████████████████| (4.1/4.1 GB, 20 MB/s)
pulling a47b02e00552... 100% |██████████████████████████████████████████████████████████████████████████████| (106/106 B, 77 B/s)
pulling 9640c2212a51... 100% |████████████████████████████████████████████████████████████████████████████████| (41/41 B, 22 B/s)
pulling de6bcd73f9b4... 100% |████████████████████████████████████████████████████████████████████████████████| (58/58 B, 28 B/s)
pulling 95c3d8d4429f... 100% |█████████████████████████████████████████████████████████████████████████████| (455/455 B, 330 B/s)
verifying sha256 digest
writing manifest
removing any unused layers
success
```

In your terminal where you're running MemGPT, run:
```sh
# By default, Ollama runs an API server on port 11434
export OPENAI_API_BASE=http://localhost:11434
export BACKEND_TYPE=ollama

# Make sure to add the tag!
export OLLAMA_MODEL=dolphin2.2-mistral:7b-q6_K
```
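Before launching MemGPT, you can sanity-check that the Ollama server is reachable by hitting the same `/api/generate` endpoint that the new `memgpt/local_llm/ollama/api.py` uses. A minimal sketch, assuming Ollama is serving on the default port and the model above has been pulled:

```python
# Quick connectivity check against Ollama's /api/generate endpoint.
# Assumes the server is on localhost:11434 and dolphin2.2-mistral:7b-q6_K is pulled.
import requests

resp = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "dolphin2.2-mistral:7b-q6_K",
        "prompt": "Say hello.",
        "stream": False,  # same non-streaming mode that MemGPT's Ollama settings use
    },
)
resp.raise_for_status()
print(resp.json()["response"])  # the generated text
```

If this prints a completion, the `OPENAI_API_BASE`, `BACKEND_TYPE`, and `OLLAMA_MODEL` values above should work as-is.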
3 changes: 3 additions & 0 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -8,6 +8,7 @@
from .lmstudio.api import get_lmstudio_completion
from .llamacpp.api import get_llamacpp_completion
from .koboldcpp.api import get_koboldcpp_completion
from .ollama.api import get_ollama_completion
from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper
from .utils import DotDict
from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
@@ -96,6 +97,8 @@ def get_chat_completion(
        result = get_llamacpp_completion(prompt, grammar=grammar_name)
    elif HOST_TYPE == "koboldcpp":
        result = get_koboldcpp_completion(prompt, grammar=grammar_name)
    elif HOST_TYPE == "ollama":
        result = get_ollama_completion(prompt)
    else:
        raise LocalLLMError(
            f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp, ollama)"
8 changes: 1 addition & 7 deletions memgpt/local_llm/koboldcpp/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -14,11 +13,6 @@
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_koboldcpp_completion(prompt, grammar=None, settings=SIMPLE):
    """See https://lite.koboldai.net/koboldcpp_api for API spec"""
    prompt_tokens = count_tokens(prompt)
8 changes: 1 addition & 7 deletions memgpt/local_llm/llamacpp/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -14,11 +13,6 @@
DEBUG = True


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_llamacpp_completion(prompt, grammar=None, settings=SIMPLE):
    """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)
57 changes: 57 additions & 0 deletions memgpt/local_llm/ollama/api.py
@@ -0,0 +1,57 @@
import os
from urllib.parse import urljoin
import requests

from .settings import SIMPLE
from ..utils import count_tokens
from ...constants import LLM_MAX_TOKENS
from ...errors import LocalLLMError

HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
MODEL_NAME = os.getenv("OLLAMA_MODEL") # ollama API requires this in the request
OLLAMA_API_SUFFIX = "/api/generate"
DEBUG = False


def get_ollama_completion(prompt, settings=SIMPLE, grammar=None):
    """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > LLM_MAX_TOKENS:
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {LLM_MAX_TOKENS} tokens)")

    if MODEL_NAME is None:
        raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral:7b-q6_K')")

    # Settings for the generation, includes the prompt + stop tokens, max length, etc
    request = settings
    request["prompt"] = prompt
    request["model"] = MODEL_NAME

    # Set grammar
    if grammar is not None:
        # request["grammar_string"] = load_grammar_file(grammar)
        raise NotImplementedError(f"Ollama does not support grammars")

    if not HOST.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

    try:
        URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            result = response.json()
            result = result["response"]
            if DEBUG:
                print(f"json API response.text: {result}")
        else:
            raise Exception(
                f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
                + f" Make sure that the ollama API server is running and reachable at {URI}."
            )

    except:
        # TODO handle gracefully
        raise

    return result
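With the three environment variables from `docs/ollama.md` exported before the module is imported (the globals above are read at import time), the function can be exercised on its own. A minimal sketch with a placeholder prompt:

```python
# Hypothetical standalone call, assuming OPENAI_API_BASE, BACKEND_TYPE=ollama,
# and OLLAMA_MODEL are already exported as described in docs/ollama.md.
from memgpt.local_llm.ollama.api import get_ollama_completion

raw_prompt = "USER: What's the capital of France?\nASSISTANT:"  # placeholder, not a real MemGPT prompt
completion = get_ollama_completion(raw_prompt)
print(completion)  # raw continuation text taken from the "response" field
```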
34 changes: 34 additions & 0 deletions memgpt/local_llm/ollama/settings.py
@@ -0,0 +1,34 @@
from ...constants import LLM_MAX_TOKENS

# see https://github.com/jmorganca/ollama/blob/main/docs/api.md
# and https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
SIMPLE = {
    "options": {
        "stop": [
            "\nUSER:",
            "\nASSISTANT:",
            "\nFUNCTION RETURN:",
            "\nUSER",
            "\nASSISTANT",
            "\nFUNCTION RETURN",
            "\nFUNCTION",
            "\nFUNC",
            "<|im_start|>",
            "<|im_end|>",
            "<|im_sep|>",
            # '\n' +
            # '</s>',
            # '<|',
            # '\n#',
            # '\n\n\n',
        ],
        "num_ctx": LLM_MAX_TOKENS,
    },
    "stream": False,
    # turn off Ollama's own prompt formatting
    "system": "",
    "template": "{{ .Prompt }}",
    # "system": None,
    # "template": None,
    "context": None,
}
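Since `get_ollama_completion` adds `prompt` and `model` on top of this dict, the JSON body POSTed to `/api/generate` ends up shaped roughly like the sketch below (the model name, prompt, and context-length value are placeholders; the blank `system` string plus the pass-through `template` are what keep Ollama from applying its own prompt formatting):

```python
# Illustrative request body after get_ollama_completion fills in model/prompt.
example_request = {
    "options": {
        "stop": ["\nUSER:", "\nASSISTANT:", "<|im_end|>"],  # truncated stop list from SIMPLE
        "num_ctx": 8192,  # placeholder standing in for LLM_MAX_TOKENS
    },
    "stream": False,
    "system": "",                 # blank system prompt
    "template": "{{ .Prompt }}",  # dummy template: pass the prompt through verbatim
    "context": None,
    "model": "dolphin2.2-mistral:7b-q6_K",  # from OLLAMA_MODEL
    "prompt": "USER: Hi\nASSISTANT:",       # the flattened MemGPT prompt
}
```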
6 changes: 6 additions & 0 deletions memgpt/local_llm/utils.py
@@ -1,4 +1,5 @@
import os
import tiktoken


class DotDict(dict):
@@ -31,3 +32,8 @@ def load_grammar_file(grammar):
        grammar_str = file.read()

    return grammar_str


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))
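The relocated helper approximates token counts with the `gpt-4` tiktoken encoding (cl100k_base), which is only an approximation for local models but is what the context-length guards in the backend `api.py` modules rely on. A quick usage sketch:

```python
# count_tokens defaults to the gpt-4 encoding, so counts are approximate for
# non-OpenAI models, but close enough for the LLM_MAX_TOKENS guard.
from memgpt.local_llm.utils import count_tokens

n = count_tokens("The quick brown fox jumps over the lazy dog.")
print(n)  # ~10 tokens with cl100k_base
```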
8 changes: 1 addition & 7 deletions memgpt/local_llm/webui/api.py
@@ -1,10 +1,9 @@
import os
from urllib.parse import urljoin
import requests
import tiktoken

from .settings import SIMPLE
from ..utils import load_grammar_file
from ..utils import load_grammar_file, count_tokens
from ...constants import LLM_MAX_TOKENS

HOST = os.getenv("OPENAI_API_BASE")
@@ -13,11 +12,6 @@
DEBUG = False


def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(s))


def get_webui_completion(prompt, settings=SIMPLE, grammar=None):
    """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
    prompt_tokens = count_tokens(prompt)
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -21,6 +21,7 @@ nav:
    - 'LM Studio': lmstudio.md
    - 'llama.cpp': llamacpp.md
    - 'koboldcpp': koboldcpp.md
    - 'ollama': ollama.md
    - 'Troubleshooting': local_llm_faq.md
  - 'Integrations':
    - 'Autogen': autogen.md
