From df01658f02ee90e5c59d0ffe990cd7e4d6f8049a Mon Sep 17 00:00:00 2001 From: Paul Swingle Date: Wed, 3 Jan 2024 15:28:58 -0800 Subject: [PATCH 1/4] add litellm --- mentat/code_context.py | 4 +- mentat/config.py | 12 +- mentat/conversation.py | 7 - mentat/errors.py | 3 +- mentat/feature_filters/default_filter.py | 4 +- mentat/llm_api_handler.py | 176 +++++++++++------------ mentat/session.py | 12 +- requirements.txt | 1 + tests/conftest.py | 4 +- 9 files changed, 100 insertions(+), 123 deletions(-) diff --git a/mentat/code_context.py b/mentat/code_context.py index 362682f47..876f414b2 100644 --- a/mentat/code_context.py +++ b/mentat/code_context.py @@ -11,7 +11,7 @@ split_file_into_intervals, ) from mentat.diff_context import DiffContext -from mentat.errors import ContextSizeInsufficient, PathValidationError +from mentat.errors import PathValidationError, ReturnToUser from mentat.feature_filters.default_filter import DefaultFilter from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter from mentat.git_handler import get_paths_with_git_diffs @@ -152,7 +152,7 @@ async def get_code_message( prompt_tokens + meta_tokens + include_files_tokens + config.token_buffer ) if not is_context_sufficient(tokens_used): - raise ContextSizeInsufficient() + raise ReturnToUser() auto_tokens = min(get_max_tokens() - tokens_used, config.auto_context_tokens) # Get auto included features diff --git a/mentat/config.py b/mentat/config.py index 4f423932e..73d1004e7 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -9,7 +9,7 @@ from attr import converters, validators from mentat.git_handler import get_git_root_for_path -from mentat.llm_api_handler import known_models +from mentat.llm_api_handler import available_embedding_models, available_models from mentat.parsers.parser import Parser from mentat.parsers.parser_map import parser_map from mentat.session_context import SESSION_CONTEXT @@ -35,19 +35,15 @@ class Config: # Model specific settings model: str = attr.field( default="gpt-4-1106-preview", - metadata={"auto_completions": list(known_models.keys())}, + metadata={"auto_completions": available_models()}, ) feature_selection_model: str = attr.field( default="gpt-4-1106-preview", - metadata={"auto_completions": list(known_models.keys())}, + metadata={"auto_completions": available_models()}, ) embedding_model: str = attr.field( default="text-embedding-ada-002", - metadata={ - "auto_completions": [ - model.name for model in known_models.values() if model.embedding_model - ] - }, + metadata={"auto_completions": available_embedding_models()}, ) temperature: float = attr.field( default=0.2, converter=float, validator=[validators.le(1), validators.ge(0)] diff --git a/mentat/conversation.py b/mentat/conversation.py index 413731c0e..58248bdc6 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -15,7 +15,6 @@ ChatCompletionUserMessageParam, ) -from mentat.errors import MentatError from mentat.llm_api_handler import ( TOKEN_COUNT_WARNING, count_tokens, @@ -41,13 +40,7 @@ async def display_token_count(self): stream = session_context.stream config = session_context.config code_context = session_context.code_context - llm_api_handler = session_context.llm_api_handler - if not await llm_api_handler.is_model_available(config.model): - raise MentatError( - f"Model {config.model} is not available. Please try again with a" - " different model." - ) if "gpt-4" not in config.model: stream.send( "Warning: Mentat has only been tested on GPT-4. You may experience" diff --git a/mentat/errors.py b/mentat/errors.py index c328fed10..0e341e852 100644 --- a/mentat/errors.py +++ b/mentat/errors.py @@ -44,8 +44,7 @@ class PathValidationError(Exception): pass -class ContextSizeInsufficient(Exception): +class ReturnToUser(Exception): """ - Raised when trying to call the API with too many tokens for that model. Will give control back to the user after being raised. """ diff --git a/mentat/feature_filters/default_filter.py b/mentat/feature_filters/default_filter.py index 76354ac58..dde60bdc6 100644 --- a/mentat/feature_filters/default_filter.py +++ b/mentat/feature_filters/default_filter.py @@ -1,7 +1,7 @@ from typing import Optional from mentat.code_feature import CodeFeature -from mentat.errors import ContextSizeInsufficient, ModelError +from mentat.errors import ModelError, ReturnToUser from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter from mentat.feature_filters.feature_filter import FeatureFilter from mentat.feature_filters.llm_feature_filter import LLMFeatureFilter @@ -39,7 +39,7 @@ async def filter(self, features: list[CodeFeature]) -> list[CodeFeature]: self.expected_edits, (0.5 if self.user_prompt != "" else 1) * self.loading_multiplier, ).filter(features) - except (ModelError, ContextSizeInsufficient): + except (ModelError, ReturnToUser): ctx.stream.send( "Feature-selection LLM response invalid. Using TruncateFilter" " instead." diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py index 50a9efda9..e67057048 100644 --- a/mentat/llm_api_handler.py +++ b/mentat/llm_api_handler.py @@ -1,7 +1,5 @@ from __future__ import annotations -import base64 -import io import os import sys from pathlib import Path @@ -10,17 +8,16 @@ Any, AsyncIterator, Callable, - Dict, List, Literal, Optional, + TypedDict, cast, overload, ) -import attr +import litellm import sentry_sdk -import tiktoken from dotenv import load_dotenv from openai import ( APIConnectionError, @@ -29,16 +26,16 @@ AsyncStream, AuthenticationError, ) +from openai.types import CreateEmbeddingResponse +from openai.types.audio import Transcription from openai.types.chat import ( ChatCompletion, ChatCompletionChunk, - ChatCompletionContentPartParam, ChatCompletionMessageParam, ) from openai.types.chat.completion_create_params import ResponseFormat -from PIL import Image -from mentat.errors import ContextSizeInsufficient, MentatError, UserError +from mentat.errors import MentatError, ReturnToUser from mentat.session_context import SESSION_CONTEXT from mentat.utils import mentat_dir_path @@ -56,7 +53,7 @@ def is_test_environment(): def api_guard(func: Callable[..., Any]) -> Callable[..., Any]: - """Decorator that should be used on any function that calls the OpenAI API + """Decorator that should be used on any function that calls an LLM API It does two things: 1. Raises if the function is called in tests (that aren't benchmarks) @@ -92,6 +89,9 @@ def count_tokens(message: str, model: str, full_message: bool) -> int: Use prompt_tokens to get the exact amount of tokens for a prompt. If full_message is true, will include the extra 4 tokens used in a chat completion by this message if this message is part of a prompt. You do NOT want full_message to be true for a response. + """ + return litellm.token_counter(model, text=message) # pyright: ignore + """ try: encoding = tiktoken.encoding_for_model(model) @@ -100,13 +100,17 @@ def count_tokens(message: str, model: str, full_message: bool) -> int: return len(encoding.encode(message, disallowed_special=())) + ( 4 if full_message else 0 ) + """ -def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str): +def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str) -> int: """ Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion. Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens """ + + return litellm.token_counter(model, messages=messages) # pyright: ignore + """ try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -142,52 +146,44 @@ def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str): num_tokens -= 1 # role is always required and always 1 token num_tokens += 2 # every reply is primed with <|start|>assistant return num_tokens + """ -@attr.define -class Model: - name: str = attr.field() - context_size: int = attr.field() - input_cost: float = attr.field() - output_cost: float = attr.field() - embedding_model: bool = attr.field(default=False) - - -known_models: Dict[str, Model] = { - "gpt-4-1106-preview": Model("gpt-4-1106-preview", 128000, 0.01, 0.03), - # model name on Azure - "gpt-4-1106-Preview": Model("gpt-4-1106-Preview", 128000, 0.01, 0.03), - "gpt-4-vision-preview": Model("gpt-4-vision-preview", 128000, 0.01, 0.03), - "gpt-4": Model("gpt-4", 8192, 0.03, 0.06), - "gpt-4-32k": Model("gpt-4-32k", 32768, 0.06, 0.12), - "gpt-4-0613": Model("gpt-4-0613", 8192, 0.03, 0.06), - "gpt-4-32k-0613": Model("gpt-4-32k-0613", 32768, 0.06, 0.12), - "gpt-4-0314": Model("gpt-4-0314", 8192, 0.03, 0.06), - "gpt-4-32k-0314": Model("gpt-4-32k-0314", 32768, 0.06, 0.12), - "gpt-3.5-turbo-1106": Model("gpt-3.5-turbo-1106", 16385, 0.001, 0.002), - "gpt-3.5-turbo": Model("gpt-3.5-turbo", 16385, 0.001, 0.002), - "gpt-3.5-turbo-0613": Model("gpt-3.5-turbo-0613", 4096, 0.0015, 0.002), - "gpt-3.5-turbo-16k-0613": Model("gpt-3.5-turbo-16k-0613", 16385, 0.003, 0.004), - "gpt-3.5-turbo-0301": Model("gpt-3.5-turbo-0301", 4096, 0.0015, 0.002), - "text-embedding-ada-002": Model( - "text-embedding-ada-002", 8191, 0.0001, 0, embedding_model=True - ), -} +class Model(TypedDict): + max_tokens: int + input_cost_per_token: float + output_cost_per_token: float + litellm_provider: str + mode: str -def model_context_size(model: str) -> Optional[int]: - if model not in known_models: +def _get_model_info(model: str) -> Optional[Model]: + try: + return litellm.get_model_info(model) # pyright: ignore + except Exception: return None - else: - return known_models[model].context_size + + +def available_models() -> List[str]: + return litellm.model_list # pyright: ignore + + +def available_embedding_models() -> List[str]: + return litellm.all_embedding_models # pyright: ignore + + +def model_context_size(model: str) -> Optional[int]: + model_info = _get_model_info(model) + return model_info["max_tokens"] if model_info is not None else None def model_price_per_1000_tokens(model: str) -> Optional[tuple[float, float]]: - """Returns (input, output) cost per 1000 tokens in USD""" - if model not in known_models: - return None - else: - return known_models[model].input_cost, known_models[model].output_cost + model_info = _get_model_info(model) + return ( + (model_info["input_cost_per_token"], model_info["output_cost_per_token"]) + if model_info is not None + else None + ) def get_max_tokens() -> int: @@ -205,12 +201,16 @@ def get_max_tokens() -> int: elif maximum_context is not None: return maximum_context else: + maximum_context = 4096 + # This attr has a converter from str to int + config.maximum_context = str(maximum_context) stream.send( - f"Context size for {config.model} is not known. Please set" - " maximum-context with `/config maximum_context `.", - color="light_red", + f"Context size for {config.model} is not known. Set maximum-context" + " with `/config maximum_context `. Using a default value of" + f" {maximum_context}.", + color="yellow", ) - raise ContextSizeInsufficient() + return maximum_context def is_context_sufficient(tokens: int) -> bool: @@ -232,33 +232,35 @@ def is_context_sufficient(tokens: int) -> bool: class LlmApiHandler: """Used for any functions that require calling the external LLM API""" - def initialize_client(self): + def load_env(self): + ctx = SESSION_CONTEXT.get() + if not load_dotenv(mentat_dir_path / ".env"): load_dotenv() + key = os.getenv("OPENAI_API_KEY") base_url = os.getenv("OPENAI_API_BASE") azure_key = os.getenv("AZURE_OPENAI_KEY") azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") - if not key and not azure_key: - raise UserError( - "No OpenAI api key detected.\nEither place your key into a .env" - " file or export it as an environment variable." - ) - # We don't have any use for a synchronous client, but if we ever do we can easily make it here - if azure_endpoint: + if azure_endpoint and azure_key: self.async_client = AsyncAzureOpenAI( api_key=azure_key, api_version="2023-12-01-preview", azure_endpoint=azure_endpoint, ) - else: + elif key: self.async_client = AsyncOpenAI(api_key=key, base_url=base_url) - try: - self.async_client.models.list() # Test the key - except AuthenticationError as e: - raise UserError(f"API gave an Authentication Error:\n{e}") + else: + self.async_client = None + + if self.async_client is not None: + try: + self.async_client.models.list() # Test the key + except AuthenticationError as e: + ctx.stream.send(f"OpenAI API gave Authentication Error:\n{e}") + self.async_client = None @overload async def call_llm_api( @@ -293,30 +295,21 @@ async def call_llm_api( # Confirm that model has enough tokens remaining. tokens = prompt_tokens(messages, model) if not is_context_sufficient(tokens): - raise ContextSizeInsufficient() + raise ReturnToUser() start_time = default_timer() with sentry_sdk.start_span(description="LLM Call") as span: span.set_tag("model", model) - # OpenAI's API is bugged; when gpt-4-vision-preview is used, including the response format - # at all returns a 400 error. Additionally, gpt-4-vision-preview has a max response of 30 tokens by default. - # Until this is fixed, we have to use this workaround. - if model == "gpt-4-vision-preview": - response = await self.async_client.chat.completions.create( + response = cast( + ChatCompletion | AsyncIterator[ChatCompletionChunk], + await litellm.acompletion( # pyright: ignore model=model, messages=messages, temperature=config.temperature, stream=stream, - max_tokens=4096, - ) - else: - response = await self.async_client.chat.completions.create( - model=model, - messages=messages, - temperature=config.temperature, - stream=stream, - response_format=response_format, - ) + response_format=response_format, # pyright: ignore + ), + ) # We have to cast response since pyright isn't smart enough to connect # the dots between stream and the overloaded create function @@ -342,22 +335,23 @@ async def call_llm_api( async def call_embedding_api( self, input_texts: list[str], model: str = "text-embedding-ada-002" ) -> list[list[float]]: - response = await self.async_client.embeddings.create( - input=input_texts, model=model + response = cast( + CreateEmbeddingResponse, + await litellm.aembedding(input=input_texts, model=model), # pyright: ignore ) return [embedding.embedding for embedding in response.data] - @api_guard - async def is_model_available(self, model: str) -> bool: - available_models: list[str] = [ - model.id async for model in self.async_client.models.list() - ] - return model in available_models - @api_guard async def call_whisper_api(self, audio_path: Path) -> str: + ctx = SESSION_CONTEXT.get() + + if self.async_client is None: + ctx.stream.send( + "You must provide a valid OpenAI API key to use the Whisper API." + ) + raise ReturnToUser() audio_file = open(audio_path, "rb") - transcript = await self.async_client.audio.transcriptions.create( + transcript: Transcription = await self.async_client.audio.transcriptions.create( model="whisper-1", file=audio_file, ) diff --git a/mentat/session.py b/mentat/session.py index 84789d20c..ea6e6eea0 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -20,13 +20,7 @@ from mentat.conversation import Conversation from mentat.cost_tracker import CostTracker from mentat.ctags import ensure_ctags_installed -from mentat.errors import ( - ContextSizeInsufficient, - MentatError, - SampleError, - SessionExit, - UserError, -) +from mentat.errors import MentatError, ReturnToUser, SampleError, SessionExit, UserError from mentat.git_handler import get_git_root_for_path from mentat.llm_api_handler import LlmApiHandler, is_test_environment from mentat.logging_config import setup_logging @@ -147,7 +141,7 @@ async def _main(self): if session_context.config.auto_context_tokens > 0: ensure_ctags_installed() - session_context.llm_api_handler.initialize_client() + session_context.llm_api_handler.load_env() code_context.display_context() await conversation.display_token_count() @@ -214,7 +208,7 @@ async def _main(self): stream.send(bool(file_edits), channel="edits_complete") except SessionExit: break - except ContextSizeInsufficient: + except ReturnToUser: need_user_request = True continue except (APITimeoutError, RateLimitError, BadRequestError) as e: diff --git a/requirements.txt b/requirements.txt index 1110a6015..20408439a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ fire==0.5.0 gitpython==3.1.37 jinja2==3.1.2 jsonschema>=4.17.0 +litellm==1.16.11 numpy==1.26.0 openai==1.3.0 pillow==10.1.0 diff --git a/tests/conftest.py b/tests/conftest.py index 1d4b2bc9b..8c83ad487 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -242,9 +242,9 @@ def mock_model_available(mocker): @pytest.fixture(autouse=True, scope="function") -def mock_initialize_client(mocker, request): +def mock_load_env(mocker, request): if not request.config.getoption("--benchmark"): - mocker.patch.object(LlmApiHandler, "initialize_client") + mocker.patch.object(LlmApiHandler, "load_env") # ContextVars need to be set in a synchronous fixture due to pytest not propagating From 861e8ba400efd2c625c2d02a9fa065214f2d9cb6 Mon Sep 17 00:00:00 2001 From: Paul Swingle Date: Wed, 3 Jan 2024 15:53:48 -0800 Subject: [PATCH 2/4] fix images --- mentat/llm_api_handler.py | 184 ++++++++++++++++++++++---------------- 1 file changed, 108 insertions(+), 76 deletions(-) diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py index e67057048..480025128 100644 --- a/mentat/llm_api_handler.py +++ b/mentat/llm_api_handler.py @@ -1,5 +1,7 @@ from __future__ import annotations +import base64 +import io import os import sys from pathlib import Path @@ -18,6 +20,7 @@ import litellm import sentry_sdk +import tiktoken from dotenv import load_dotenv from openai import ( APIConnectionError, @@ -31,9 +34,11 @@ from openai.types.chat import ( ChatCompletion, ChatCompletionChunk, + ChatCompletionContentPartParam, ChatCompletionMessageParam, ) from openai.types.chat.completion_create_params import ResponseFormat +from PIL import Image from mentat.errors import MentatError, ReturnToUser from mentat.session_context import SESSION_CONTEXT @@ -83,72 +88,6 @@ def chunk_to_lines(chunk: ChatCompletionChunk) -> list[str]: return ("" if content is None else content).splitlines(keepends=True) -def count_tokens(message: str, model: str, full_message: bool) -> int: - """ - Calculates the tokens in this message. Will NOT be accurate for a full prompt! - Use prompt_tokens to get the exact amount of tokens for a prompt. - If full_message is true, will include the extra 4 tokens used in a chat completion by this message - if this message is part of a prompt. You do NOT want full_message to be true for a response. - """ - return litellm.token_counter(model, text=message) # pyright: ignore - - """ - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - return len(encoding.encode(message, disallowed_special=())) + ( - 4 if full_message else 0 - ) - """ - - -def prompt_tokens(messages: list[ChatCompletionMessageParam], model: str) -> int: - """ - Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion. - Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens - """ - - return litellm.token_counter(model, messages=messages) # pyright: ignore - """ - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - # every message follows <|start|>{role/name}\n{content}<|end|>\n - # this has 5 tokens (start token, role, \n, end token, \n), but we count the role token later - num_tokens += 4 - for key, value in message.items(): - if isinstance(value, list) and key == "content": - value = cast(List[ChatCompletionContentPartParam], value) - for entry in value: - if entry["type"] == "text": - num_tokens += len(encoding.encode(entry["text"])) - if entry["type"] == "image_url": - image_base64: str = entry["image_url"]["url"].split(",")[1] - image_bytes: bytes = base64.b64decode(image_base64) - image = Image.open(io.BytesIO(image_bytes)) - size = image.size - # As described here: https://platform.openai.com/docs/guides/vision/calculating-costs - scale = min(1, 2048 / max(size)) - size = (int(size[0] * scale), int(size[1] * scale)) - scale = min(1, 768 / min(size)) - size = (int(size[0] * scale), int(size[1] * scale)) - num_tokens += 85 + 170 * ((size[0] + 511) // 512) * ( - (size[1] + 511) // 512 - ) - elif isinstance(value, str): - num_tokens += len(encoding.encode(value)) - if key == "name": # if there's a name, the role is omitted - num_tokens -= 1 # role is always required and always 1 token - num_tokens += 2 # every reply is primed with <|start|>assistant - return num_tokens - """ - - class Model(TypedDict): max_tokens: int input_cost_per_token: float @@ -229,6 +168,84 @@ def is_context_sufficient(tokens: int) -> bool: return True +# litellm's token counting functions are inaccurate and don't count picture tokens; +# if we are using OpenAI, use our functions instead. +def _open_ai_count_tokens(message: str, model: str, full_message: bool) -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + return len(encoding.encode(message, disallowed_special=())) + ( + 4 if full_message else 0 + ) + + +def _open_ai_prompt_tokens( + messages: List[ChatCompletionMessageParam], model: str +) -> int: + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + # every message follows <|start|>{role/name}\n{content}<|end|>\n + # this has 5 tokens (start token, role, \n, end token, \n), but we count the role token later + num_tokens += 4 + for key, value in message.items(): + if isinstance(value, list) and key == "content": + value = cast(List[ChatCompletionContentPartParam], value) + for entry in value: + if entry["type"] == "text": + num_tokens += len(encoding.encode(entry["text"])) + if entry["type"] == "image_url": + image_base64: str = entry["image_url"]["url"].split(",")[1] + image_bytes: bytes = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_bytes)) + size = image.size + # As described here: https://platform.openai.com/docs/guides/vision/calculating-costs + scale = min(1, 2048 / max(size)) + size = (int(size[0] * scale), int(size[1] * scale)) + scale = min(1, 768 / min(size)) + size = (int(size[0] * scale), int(size[1] * scale)) + num_tokens += 85 + 170 * ((size[0] + 511) // 512) * ( + (size[1] + 511) // 512 + ) + elif isinstance(value, str): + num_tokens += len(encoding.encode(value)) + if key == "name": # if there's a name, the role is omitted + num_tokens -= 1 # role is always required and always 1 token + num_tokens += 2 # every reply is primed with <|start|>assistant + return num_tokens + + +def count_tokens(message: str, model: str, full_message: bool) -> int: + """ + Calculates the tokens in this message. Will NOT be accurate for a full prompt! + Use prompt_tokens to get the exact amount of tokens for a prompt. + If full_message is true, will include the extra 4 tokens used in a chat completion by this message + if this message is part of a prompt. You do NOT want full_message to be true for a response. + """ + model_info = _get_model_info(model) + if model_info is not None and model_info["litellm_provider"] == "openai": + return _open_ai_count_tokens(message, model, full_message) + else: + return litellm.token_counter(model, text=message) # pyright: ignore + + +def prompt_tokens(messages: List[ChatCompletionMessageParam], model: str) -> int: + """ + Returns the number of tokens used by a prompt if it was sent to OpenAI for a chat completion. + Adapted from https://platform.openai.com/docs/guides/text-generation/managing-tokens + """ + model_info = _get_model_info(model) + if model_info is not None and model_info["litellm_provider"] == "openai": + return _open_ai_prompt_tokens(messages, model) + else: + return litellm.token_counter(model, messages=messages) # pyright: ignore + + class LlmApiHandler: """Used for any functions that require calling the external LLM API""" @@ -300,16 +317,31 @@ async def call_llm_api( start_time = default_timer() with sentry_sdk.start_span(description="LLM Call") as span: span.set_tag("model", model) - response = cast( - ChatCompletion | AsyncIterator[ChatCompletionChunk], - await litellm.acompletion( # pyright: ignore - model=model, - messages=messages, - temperature=config.temperature, - stream=stream, - response_format=response_format, # pyright: ignore - ), - ) + # OpenAI's API is bugged; when gpt-4-vision-preview is used, including the response format + # at all returns a 400 error. Additionally, gpt-4-vision-preview has a max response of 30 tokens by default. + # Until this is fixed, we have to use this workaround. + if model == "gpt-4-vision-preview": + response = cast( + ChatCompletion | AsyncIterator[ChatCompletionChunk], + await litellm.acompletion( # pyright: ignore + model=model, + messages=messages, + temperature=config.temperature, + stream=stream, + max_tokens=4096, + ), + ) + else: + response = cast( + ChatCompletion | AsyncIterator[ChatCompletionChunk], + await litellm.acompletion( # pyright: ignore + model=model, + messages=messages, + temperature=config.temperature, + stream=stream, + response_format=response_format, # pyright: ignore + ), + ) # We have to cast response since pyright isn't smart enough to connect # the dots between stream and the overloaded create function From 946e2daa593b4d6369a1ef894a3113af18c134b2 Mon Sep 17 00:00:00 2001 From: Paul Swingle Date: Wed, 3 Jan 2024 16:00:59 -0800 Subject: [PATCH 3/4] fix tests --- tests/code_context_test.py | 4 ++-- tests/conftest.py | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/code_context_test.py b/tests/code_context_test.py index 63808e168..2db02107e 100644 --- a/tests/code_context_test.py +++ b/tests/code_context_test.py @@ -8,7 +8,7 @@ from mentat.code_context import CodeContext from mentat.config import Config -from mentat.errors import ContextSizeInsufficient +from mentat.errors import ReturnToUser from mentat.feature_filters.default_filter import DefaultFilter from mentat.git_handler import get_non_gitignored_files from mentat.include_files import is_file_text_encoded @@ -222,7 +222,7 @@ async def _count_max_tokens_where(tokens_used: int) -> int: return count_tokens(code_message, "gpt-4", full_message=True) assert await _count_max_tokens_where(0) == 89 # Code - with pytest.raises(ContextSizeInsufficient): + with pytest.raises(ReturnToUser): await _count_max_tokens_where(1e6) diff --git a/tests/conftest.py b/tests/conftest.py index 8c83ad487..7530d455c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -234,13 +234,6 @@ def set_embedding_values(value): ### Auto-used fixtures -@pytest.fixture(autouse=True, scope="function") -def mock_model_available(mocker): - model_available_mock = mocker.patch.object(LlmApiHandler, "is_model_available") - model_available_mock.return_value = True - return model_available_mock - - @pytest.fixture(autouse=True, scope="function") def mock_load_env(mocker, request): if not request.config.getoption("--benchmark"): From 5357530092a0d0f6f239925cb98512fc7196abb2 Mon Sep 17 00:00:00 2001 From: Paul Swingle Date: Wed, 3 Jan 2024 16:58:56 -0800 Subject: [PATCH 4/4] add custom llm provider to config and update docs --- README.md | 14 ++++++--- docs/configuration.md | 21 +++++-------- mentat/config.py | 11 +++++++ mentat/llm_api_handler.py | 65 ++++++++++++++++++++++++--------------- 4 files changed, 69 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 7714b854f..6abc090aa 100644 --- a/README.md +++ b/README.md @@ -58,9 +58,13 @@ cd mentat pip install -e . ``` -## Add your OpenAI API Key +## Selecting which LLM Model to use -You'll need to have API access to GPT-4 to run Mentat. There are a few options to provide Mentat with your OpenAI API key: +We highly recommend using the default model, `gpt-4-1106-preview`, as it performs vastly better than any other model benchmarked so far. However, if you wish to use a different model, jump [here](#alternative-models). + +### Add your OpenAI API Key + +There are a few options to provide Mentat with your OpenAI API key: 1. Create a `.env` file with the line `OPENAI_API_KEY=` in the directory you plan to run mentat in or in `~/.mentat/.env` 2. Run `export OPENAI_API_KEY=` prior to running Mentat @@ -68,11 +72,11 @@ You'll need to have API access to GPT-4 to run Mentat. There are a few options t ### Azure OpenAI -Mentat also works with the Azure OpenAI API. To use the Azure API, provide the `AZURE_OPENAI_ENDPOINT` (`https://.openai.azure.com/`) and `AZURE_OPENAI_KEY` environment variables instead of `OPENAI_API_KEY`. +Mentat also works with the Azure OpenAI API. To use the Azure API, provide the `AZURE_API_BASE` (`https://.openai.azure.com/`), `AZURE_API_KEY`, and `AZURE_API_VERSION` environment variables instead of `OPENAI_API_KEY`. Then, set the model as described in [configuration.md](docs/configuration.md) to your Azure model. -In addition, Mentat uses the `gpt-4-1106-preview` by default. On Azure, this model is available under a different name: `gpt-4-1106-Preview` (with a capital P). To use it, override the default model as described in [configuration.md](docs/configuration.md). +### Alternative Models -> **_Important:_** Due to changes in the OpenAI Python SDK, you can no longer use `OPENAI_API_BASE` to access the Azure API with Mentat. +Mentat uses [litellm](https://github.com/BerriAI/litellm) to retrieve chat completions from models. To use a model other than openai, simply set the model (and possibly the llm_provider) as described in [configuration.md](docs/configuration.md). Additionally, check litellm documentation for the provider that your model is under and supply any needed environment variables. ## Configuration diff --git a/docs/configuration.md b/docs/configuration.md index 3bcbd720a..37f4144a5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -41,30 +41,25 @@ A list of key-value pairs defining a custom [Pygment Style](https://pygments.org } ``` -### Maximum Context +## 🦙 Alternative Models -If you're using a model other than gpt-3.5 or gpt-4 we won't be able to infer the model's context size so you need to manually set the maximum context like so. +Mentat uses [litellm](https://github.com/BerriAI/litellm), so you can direct it to use any local or hosted model. See their documentation to assist setting up any required environment variables, and set the model (and possibly llm_provider, if litellm doesn't automatically recognize the model) in `~/.mentat/.mentat_config.json`: ```json { - "maximum-context": 16000 + "model": "", + "llm-provider": "" } ``` -This can also be used to save costs for instance if you want to use a maximum of 16k tokens when using gpt-4-32k. -## 🦙 Alternative Models +### Maximum Context -Mentat is powered with openai's sdk so you can direct it to use a local model, or any hosted model which conforms to OpenAi's API spec. For example if you host a Llama instance following the directions [here](https://github.com/abetlen/llama-cpp-python#web-server) then you use that model with Mentat by exporting its path e.g. -```bash -export OPENAI_API_BASE="http://localhost:8000/v1 -``` -and then setting your model in `~/.mentat/.mentat_config.json`: +If you use a model unknown to litellm, you can manually set the maximum context of the model like so: ```json { - "model": "/absolute/path/to/7B/llama-model.gguf" - "maximum-context": 2048 + "maximum-context": 16000 } ``` -For models other than gpt-3.5 and gpt-4 we may not be able to infer a maximum context size so you'll also have to set the maximum-context. +This can also be used to save costs by setting a more conservative limit on models with larger context sizes. ### Alternative Formats diff --git a/mentat/config.py b/mentat/config.py index 73d1004e7..94f656cbf 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -6,6 +6,7 @@ from pathlib import Path import attr +import litellm from attr import converters, validators from mentat.git_handler import get_git_root_for_path @@ -37,6 +38,16 @@ class Config: default="gpt-4-1106-preview", metadata={"auto_completions": available_models()}, ) + llm_provider: str | None = attr.field( + default=None, + metadata={ + "description": ( + "The llm provider to use. See https://github.com/BerriAI/litellm for a" + " list of all providers and supported models." + ), + "auto_completions": litellm.provider_list, # pyright: ignore + }, + ) feature_selection_model: str = attr.field( default="gpt-4-1106-preview", metadata={"auto_completions": available_models()}, diff --git a/mentat/llm_api_handler.py b/mentat/llm_api_handler.py index 480025128..e55e9b1ee 100644 --- a/mentat/llm_api_handler.py +++ b/mentat/llm_api_handler.py @@ -317,31 +317,48 @@ async def call_llm_api( start_time = default_timer() with sentry_sdk.start_span(description="LLM Call") as span: span.set_tag("model", model) - # OpenAI's API is bugged; when gpt-4-vision-preview is used, including the response format - # at all returns a 400 error. Additionally, gpt-4-vision-preview has a max response of 30 tokens by default. - # Until this is fixed, we have to use this workaround. - if model == "gpt-4-vision-preview": - response = cast( - ChatCompletion | AsyncIterator[ChatCompletionChunk], - await litellm.acompletion( # pyright: ignore - model=model, - messages=messages, - temperature=config.temperature, - stream=stream, - max_tokens=4096, - ), - ) - else: - response = cast( - ChatCompletion | AsyncIterator[ChatCompletionChunk], - await litellm.acompletion( # pyright: ignore - model=model, - messages=messages, - temperature=config.temperature, - stream=stream, - response_format=response_format, # pyright: ignore - ), + + try: + # OpenAI's API is bugged; when gpt-4-vision-preview is used, including the response format + # at all returns a 400 error. + # Additionally, gpt-4-vision-preview has a max response of 30 tokens by default. + # Until this is fixed, we have to use this workaround. + if model == "gpt-4-vision-preview": + response = cast( + ChatCompletion | AsyncIterator[ChatCompletionChunk], + await litellm.acompletion( # pyright: ignore + model=model, + messages=messages, + temperature=config.temperature, + stream=stream, + custom_llm_provider=config.llm_provider, + max_tokens=4096, + ), + ) + else: + response = cast( + ChatCompletion | AsyncIterator[ChatCompletionChunk], + await litellm.acompletion( # pyright: ignore + model=model, + messages=messages, + temperature=config.temperature, + stream=stream, + custom_llm_provider=config.llm_provider, + response_format=response_format, # pyright: ignore + ), + ) + except litellm.APIError as e: + session_context.stream.send(f"Error accessing LLM: {e}", color="red") + raise ReturnToUser() + except litellm.NotFoundError: + llm_provider_error_message = f" for llm_provider {config.llm_provider}" + session_context.stream.send( + "Unknown model" + f" {model}{llm_provider_error_message if config.llm_provider is not None else ''}." + " Please use `/context model ` to switch models.", + color="red", ) + raise ReturnToUser() # We have to cast response since pyright isn't smart enough to connect # the dots between stream and the overloaded create function