Skip to content

Commit

Permalink
System prompts and conversation support, refs #1
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Dec 15, 2023
1 parent 4c8fcfb commit 534eb86
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions llm_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import llm



@llm.hookimpl
def register_models(register):
register(Mistral("mistral-tiny"))
Expand All @@ -17,8 +16,33 @@ class Mistral(llm.Model):
def __init__(self, model_id):
self.model_id = model_id

def build_messages(self, prompt, conversation):
messages = []
if not conversation:
if prompt.system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages
current_system = None
for prev_response in conversation.responses:
if (
prev_response.prompt.system
and prev_response.prompt.system != current_system
):
messages.append(
{"role": "system", "content": prev_response.prompt.system}
)
current_system = prev_response.prompt.system
messages.append({"role": "user", "content": prev_response.prompt.prompt})
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages

def execute(self, prompt, stream, response, conversation):
key = llm.get_key("", "mistral", "LLM_MISTRAL_KEY")
messages = self.build_messages(prompt, conversation)
with httpx.Client() as client:
with connect_sse(
client,
Expand All @@ -31,17 +55,16 @@ def execute(self, prompt, stream, response, conversation):
},
json={
"model": self.model_id,
"messages": [
{"role": "user", "content": prompt.prompt}
],
"messages": messages,
"stream": True,
},
) as event_source:
# In case of unauthorized:
event_source.response.raise_for_status()
for sse in event_source.iter_sse():
if sse.data != '[DONE]':
if sse.data != "[DONE]":
try:
yield sse.json()["choices"][0]["delta"]["content"]
except KeyError:
pass
response._prompt_json = {"messages": messages}

0 comments on commit 534eb86

Please sign in to comment.