This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Add provider config and update spice with new models #562

Merged
merged 2 commits on Apr 10, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion README.md
@@ -69,7 +69,7 @@ There are a few options to provide Mentat with your OpenAI API key:
2. Run `export OPENAI_API_KEY=<your key here>` prior to running Mentat
3. Place the previous command in your `.bashrc` or `.zshrc` to export your key on every terminal startup

- If you want to use a models through Azure, Ollama or other service see [this doc](https://docs.mentat.ai/en/latest/user/alternative_models.html) for details.
+ If you want to use a models through Azure, Ollama or other services see [this doc](https://docs.mentat.ai/en/latest/user/alternative_models.html) for details.

# 🚀 Usage

2 changes: 1 addition & 1 deletion benchmarks/benchmark_runner.py
@@ -59,7 +59,7 @@ async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
messages[1]["content"] = messages[1]["content"][:-chars_to_remove]

llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
- llm_grade = await llm_api_handler.call_llm_api(messages, model, False, ResponseFormat(type="json_object"))
+ llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False, ResponseFormat(type="json_object"))
content = llm_grade.text
return json.loads(content)
except Exception as e:
2 changes: 1 addition & 1 deletion benchmarks/exercism_practice.py
@@ -65,7 +65,7 @@ async def failure_analysis(exercise_runner, language):
response = ""
try:
llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
- llm_grade = await llm_api_handler.call_llm_api(messages, model, False)
+ llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False)
response = llm_grade.text
except BadRequestError:
response = "Unable to analyze test case\nreason: too many tokens to analyze"
4 changes: 2 additions & 2 deletions mentat/agent_handler.py
@@ -47,7 +47,7 @@ async def enable_agent_mode(self):
),
]
model = ctx.config.model
- response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
+ response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
content = response.text

paths = [Path(path) for path in content.strip().split("\n") if Path(path).exists()]
@@ -81,7 +81,7 @@ async def _determine_commands(self) -> List[str]:

try:
# TODO: Should this even be a separate call or should we collect commands in the edit call?
- response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
+ response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
ctx.cost_tracker.display_last_api_call()
except BadRequestError as e:
ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
8 changes: 8 additions & 0 deletions mentat/config.py
@@ -4,6 +4,7 @@
from argparse import ArgumentParser, Namespace
from json import JSONDecodeError
from pathlib import Path
+ from typing import Optional

import attr
from attr import converters, validators
@@ -38,14 +39,21 @@ class Config:
default="gpt-4-0125-preview",
metadata={"auto_completions": list(known_models.keys())},
)
+ provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]})
feature_selection_model: str = attr.field(
default="gpt-4-1106-preview",
metadata={"auto_completions": list(known_models.keys())},
)
+ feature_selection_provider: Optional[str] = attr.field(
+ default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}
+ )
embedding_model: str = attr.field(
default="text-embedding-ada-002",
metadata={"auto_completions": [model.name for model in known_models.values() if model.embedding_model]},
)
+ embedding_provider: Optional[str] = attr.field(
+ default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}
+ )
temperature: float = attr.field(default=0.2, converter=float, validator=[validators.le(1), validators.ge(0)])

maximum_context: int | None = attr.field(
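
The three new config fields above pair a provider override with each existing model field. Below is a minimal sketch of an explicit configuration, assuming Config accepts these fields as keyword arguments through its attrs-generated initializer; the provider values chosen here are placeholders, not taken from this PR:

```python
# Hypothetical sketch: each *_provider field pins the backend for its matching model.
# Leaving a provider as None presumably keeps the previous behavior, with spice
# inferring the provider from the model name.
from mentat.config import Config

config = Config(
    model="gpt-4-0125-preview",
    provider="azure",                          # route the chat model through Azure
    feature_selection_model="gpt-4-1106-preview",
    feature_selection_provider=None,           # None -> inferred as before
    embedding_model="text-embedding-ada-002",
    embedding_provider="openai",
)
```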
1 change: 1 addition & 0 deletions mentat/conversation.py
@@ -177,6 +177,7 @@ async def _stream_model_response(
response = await llm_api_handler.call_llm_api(
messages,
config.model,
+ config.provider,
stream=True,
response_format=parser.response_format(),
)
1 change: 1 addition & 0 deletions mentat/feature_filters/llm_feature_filter.py
@@ -90,6 +90,7 @@ async def filter(
llm_response = await llm_api_handler.call_llm_api(
messages=messages,
model=model,
+ provider=config.feature_selection_provider,
stream=False,
response_format=ResponseFormat(type="json_object"),
)
29 changes: 21 additions & 8 deletions mentat/llm_api_handler.py
@@ -31,10 +31,11 @@
)
from openai.types.chat.completion_create_params import ResponseFormat
from PIL import Image
- from spice import APIConnectionError, Spice, SpiceError, SpiceMessage, SpiceResponse, StreamingSpiceResponse
+ from spice import APIConnectionError, Spice, SpiceMessage, SpiceResponse, StreamingSpiceResponse
from spice.errors import NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI
+ from spice.spice import InvalidModelError

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
@@ -72,9 +73,12 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return await func(*args, **kwargs)
except APIConnectionError:
- raise MentatError("API connection error: please check your internet connection and" " try again.")
- except SpiceError as e:
- raise MentatError(f"API error: {e}")
+ raise MentatError("API connection error: please check your internet connection and try again.")
+ except InvalidModelError:
+ SESSION_CONTEXT.get().stream.send(
+ "Unknown model. Use /config provider <provider> and try again.", style="error"
+ )
+ raise ReturnToUser()

return async_wrapper
else:
@@ -84,9 +88,12 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return func(*args, **kwargs)
except APIConnectionError:
- raise MentatError("API connection error: please check your internet connection and" " try again.")
- except SpiceError as e:
- raise MentatError(f"API error: {e}")
+ raise MentatError("API connection error: please check your internet connection and try again.")
+ except InvalidModelError:
+ SESSION_CONTEXT.get().stream.send(
+ "Unknown model. Use /config provider <provider> and try again.", style="error"
+ )
+ raise ReturnToUser()

return sync_wrapper

@@ -349,6 +356,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: Literal[False],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse:
@@ -359,6 +367,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: Literal[True],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> StreamingSpiceResponse:
@@ -369,6 +378,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: bool,
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse | StreamingSpiceResponse:
@@ -386,6 +396,7 @@ async def call_llm_api(
if not stream:
response = await self.spice.get_response(
model=model,
+ provider=provider,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
@@ -394,6 +405,7 @@ async def call_llm_api(
else:
response = await self.spice.stream_response(
model=model,
+ provider=provider,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
@@ -403,7 +415,8 @@ async def call_llm_api(

@api_guard
def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
- return self.spice.get_embeddings_sync(input_texts, model) # pyright: ignore
+ ctx = SESSION_CONTEXT.get()
+ return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider) # pyright: ignore

@api_guard
async def call_whisper_api(self, audio_path: Path) -> str:
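
Every call site in this PR threads a provider through call_llm_api as the new third argument, ahead of stream, and the handler forwards it to spice's get_response and stream_response. A minimal usage sketch under that signature follows; the helper function and prompt are illustrative, not part of the PR:

```python
# Hypothetical helper showing the updated call shape. provider may be None, in which
# case spice presumably resolves the backend from the model name as before.
from mentat.session_context import SESSION_CONTEXT

async def ask(prompt: str) -> str:
    ctx = SESSION_CONTEXT.get()
    messages = [{"role": "user", "content": prompt}]
    response = await ctx.llm_api_handler.call_llm_api(
        messages,
        ctx.config.model,
        ctx.config.provider,  # new third positional argument
        False,                # stream
    )
    return response.text
```

If spice raises InvalidModelError, the reworked api_guard no longer surfaces a MentatError: it sends "Unknown model. Use /config provider <provider> and try again." to the stream and raises ReturnToUser so control returns to the prompt.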
4 changes: 3 additions & 1 deletion mentat/revisor/revisor.py
@@ -66,7 +66,9 @@ async def revise_edit(file_edit: FileEdit):
"\nRevising edits for file" f" {get_relative_path(file_edit.file_path, ctx.cwd)}...",
style="info",
)
- response = await ctx.llm_api_handler.call_llm_api(messages, model=ctx.config.model, stream=False)
+ response = await ctx.llm_api_handler.call_llm_api(
+ messages, model=ctx.config.model, provider=ctx.config.provider, stream=False
+ )
message = response.text
messages.append(ChatCompletionAssistantMessageParam(content=message, role="assistant"))
ctx.conversation.add_transcript_message(
1 change: 1 addition & 0 deletions mentat/session.py
@@ -227,6 +227,7 @@ async def _main(self):
stream.send(None, channel="client_exit")
break
except ReturnToUser:
+ stream.send(None, channel="loading", terminate=True)
need_user_request = True
continue
except (
2 changes: 1 addition & 1 deletion requirements.txt
@@ -21,7 +21,7 @@ selenium==4.15.2
sentry-sdk==1.34.0
sounddevice==0.4.6
soundfile==0.12.1
- spiceai==0.1.8
+ spiceai==0.1.9
termcolor==2.3.0
textual==0.47.1
textual-autocomplete==2.1.0b0
6 changes: 3 additions & 3 deletions tests/conftest.py
@@ -98,7 +98,7 @@ def mock_call_llm_api(mocker):
completion_mock = mocker.patch.object(LlmApiHandler, "call_llm_api")

def wrap_unstreamed_string(value):
- return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True)
+ return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True, 1)

def wrap_streamed_strings(values):
class MockStreamingSpiceResponse:
@@ -115,7 +115,7 @@ async def __anext__(self):
return values[self.cur_value - 1]

def current_response(self):
- return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True)
+ return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True, 1)

mock_spice_response = MockStreamingSpiceResponse()
return mock_spice_response
@@ -131,7 +131,7 @@ def set_unstreamed_values(value):
completion_mock.set_unstreamed_values = set_unstreamed_values

def set_return_values(values):
- async def call_llm_api_mock(messages, model, stream, response_format="unused"):
+ async def call_llm_api_mock(messages, model, provider, stream, response_format="unused"):
value = call_llm_api_mock.values.pop()
if stream:
return wrap_streamed_strings([value])
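
The mock's signature mirrors the new provider argument, so tests built on this fixture keep working. A hedged sketch of a test driving the mocked handler directly; the test body and the session_context fixture named here are hypothetical, while set_return_values follows the fixture pattern above:

```python
# Hypothetical test sketch: the mocked call_llm_api now expects a provider in third
# position, matching call_llm_api_mock(messages, model, provider, stream, ...).
async def exercise_mock(session_context, mock_call_llm_api):
    mock_call_llm_api.set_return_values(["mocked answer"])
    handler = session_context.llm_api_handler
    response = await handler.call_llm_api([], "gpt-3.5-turbo", None, False)
    return response.text  # "mocked answer" under the fixture's unstreamed wrapper
```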