This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Add provider config and update spice with new models #562

Merged 2 commits on Apr 10, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion benchmarks/benchmark_runner.py
@@ -59,7 +59,7 @@ async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
messages[1]["content"] = messages[1]["content"][:-chars_to_remove]

llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
- llm_grade = await llm_api_handler.call_llm_api(messages, model, False, ResponseFormat(type="json_object"))
+ llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False, ResponseFormat(type="json_object"))
content = llm_grade.text
return json.loads(content)
except Exception as e:
2 changes: 1 addition & 1 deletion benchmarks/exercism_practice.py
@@ -65,7 +65,7 @@ async def failure_analysis(exercise_runner, language):
response = ""
try:
llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
- llm_grade = await llm_api_handler.call_llm_api(messages, model, False)
+ llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False)
response = llm_grade.text
except BadRequestError:
response = "Unable to analyze test case\nreason: too many tokens to analyze"
4 changes: 2 additions & 2 deletions mentat/agent_handler.py
@@ -47,7 +47,7 @@ async def enable_agent_mode(self):
),
]
model = ctx.config.model
- response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
+ response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
content = response.text

paths = [Path(path) for path in content.strip().split("\n") if Path(path).exists()]
@@ -81,7 +81,7 @@ async def _determine_commands(self) -> List[str]:

try:
# TODO: Should this even be a separate call or should we collect commands in the edit call?
- response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
+ response = await ctx.llm_api_handler.call_llm_api(messages, model, ctx.config.provider, False)
ctx.cost_tracker.display_last_api_call()
except BadRequestError as e:
ctx.stream.send(f"Error accessing OpenAI API: {e.message}", style="error")
8 changes: 8 additions & 0 deletions mentat/config.py
@@ -4,6 +4,7 @@
from argparse import ArgumentParser, Namespace
from json import JSONDecodeError
from pathlib import Path
+ from typing import Optional

import attr
from attr import converters, validators
@@ -38,14 +39,21 @@ class Config:
default="gpt-4-0125-preview",
metadata={"auto_completions": list(known_models.keys())},
)
+ provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]})
feature_selection_model: str = attr.field(
default="gpt-4-1106-preview",
metadata={"auto_completions": list(known_models.keys())},
)
+ feature_selection_provider: Optional[str] = attr.field(
+ default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}
+ )
embedding_model: str = attr.field(
default="text-embedding-ada-002",
metadata={"auto_completions": [model.name for model in known_models.values() if model.embedding_model]},
)
+ embedding_provider: Optional[str] = attr.field(
+ default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}
+ )
temperature: float = attr.field(default=0.2, converter=float, validator=[validators.le(1), validators.ge(0)])

maximum_context: int | None = attr.field(
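The three new fields mirror the existing model fields: provider pairs with model, feature_selection_provider with feature_selection_model, and embedding_provider with embedding_model. Each defaults to None, so existing configs keep working and spice presumably resolves the provider from the model name as before. A minimal sketch of setting them programmatically, assuming the remaining Config fields keep their defaults; the model and provider values below are illustrative, not part of this diff:

# Sketch: configuring explicit providers (values here are illustrative).
from mentat.config import Config

config = Config(
    model="gpt-4-0125-preview",
    provider="openai",                # used for the main chat/edit calls
    feature_selection_provider=None,  # None keeps the previous behavior
    embedding_provider="azure",       # read by call_embedding_api
)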
1 change: 1 addition & 0 deletions mentat/conversation.py
@@ -177,6 +177,7 @@ async def _stream_model_response(
response = await llm_api_handler.call_llm_api(
messages,
config.model,
+ config.provider,
stream=True,
response_format=parser.response_format(),
)
1 change: 1 addition & 0 deletions mentat/feature_filters/llm_feature_filter.py
@@ -90,6 +90,7 @@ async def filter(
llm_response = await llm_api_handler.call_llm_api(
messages=messages,
model=model,
+ provider=config.feature_selection_provider,
stream=False,
response_format=ResponseFormat(type="json_object"),
)
8 changes: 7 additions & 1 deletion mentat/llm_api_handler.py
@@ -349,6 +349,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: Literal[False],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse:
@@ -359,6 +360,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: Literal[True],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> StreamingSpiceResponse:
@@ -369,6 +371,7 @@ async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
+ provider: Optional[str],
stream: bool,
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse | StreamingSpiceResponse:
@@ -386,6 +389,7 @@ async def call_llm_api(
if not stream:
response = await self.spice.get_response(
model=model,
+ provider=provider,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
@@ -394,6 +398,7 @@ async def call_llm_api(
else:
response = await self.spice.stream_response(
model=model,
+ provider=provider,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
@@ -403,7 +408,8 @@ async def call_llm_api(

@api_guard
def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
- return self.spice.get_embeddings_sync(input_texts, model) # pyright: ignore
+ ctx = SESSION_CONTEXT.get()
+ return self.spice.get_embeddings_sync(input_texts, model, provider=ctx.config.embedding_provider) # pyright: ignore

@api_guard
async def call_whisper_api(self, audio_path: Path) -> str:
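With this change, provider is the third positional parameter in all three call_llm_api overloads and is forwarded straight to spice.get_response / spice.stream_response, while call_embedding_api now reads embedding_provider from the session config. A rough sketch of a call site under the new signature; the helper itself is hypothetical, the SESSION_CONTEXT import path and simplified message shape are assumptions, and the call pattern follows the call sites visible in this diff:

# Hypothetical helper showing the new third positional argument; passing
# None for provider keeps the old behavior of letting spice choose.
from mentat.session_context import SESSION_CONTEXT

async def quick_completion(prompt: str) -> str:
    ctx = SESSION_CONTEXT.get()
    messages = [{"role": "user", "content": prompt}]  # simplified message shape
    response = await ctx.llm_api_handler.call_llm_api(
        messages,
        ctx.config.model,
        ctx.config.provider,  # new parameter added in this PR
        stream=False,
    )
    return response.text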
4 changes: 3 additions & 1 deletion mentat/revisor/revisor.py
@@ -66,7 +66,9 @@ async def revise_edit(file_edit: FileEdit):
"\nRevising edits for file" f" {get_relative_path(file_edit.file_path, ctx.cwd)}...",
style="info",
)
- response = await ctx.llm_api_handler.call_llm_api(messages, model=ctx.config.model, stream=False)
+ response = await ctx.llm_api_handler.call_llm_api(
+ messages, model=ctx.config.model, provider=ctx.config.provider, stream=False
+ )
message = response.text
messages.append(ChatCompletionAssistantMessageParam(content=message, role="assistant"))
ctx.conversation.add_transcript_message(
2 changes: 1 addition & 1 deletion requirements.txt
@@ -21,7 +21,7 @@ selenium==4.15.2
sentry-sdk==1.34.0
sounddevice==0.4.6
soundfile==0.12.1
- spiceai==0.1.8
+ spiceai==0.1.9
termcolor==2.3.0
textual==0.47.1
textual-autocomplete==2.1.0b0
6 changes: 3 additions & 3 deletions tests/conftest.py
@@ -98,7 +98,7 @@ def mock_call_llm_api(mocker):
completion_mock = mocker.patch.object(LlmApiHandler, "call_llm_api")

def wrap_unstreamed_string(value):
- return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True)
+ return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True, 1)

def wrap_streamed_strings(values):
class MockStreamingSpiceResponse:
@@ -115,7 +115,7 @@ async def __anext__(self):
return values[self.cur_value - 1]

def current_response(self):
- return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True)
+ return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True, 1)

mock_spice_response = MockStreamingSpiceResponse()
return mock_spice_response
@@ -131,7 +131,7 @@ def set_unstreamed_values(value):
completion_mock.set_unstreamed_values = set_unstreamed_values

def set_return_values(values):
- async def call_llm_api_mock(messages, model, stream, response_format="unused"):
+ async def call_llm_api_mock(messages, model, provider, stream, response_format="unused"):
value = call_llm_api_mock.values.pop()
if stream:
return wrap_streamed_strings([value])