diff --git a/llm_mistral.py b/llm_mistral.py
index 4f260c6..cb19a82 100644
--- a/llm_mistral.py
+++ b/llm_mistral.py
@@ -3,7 +3,6 @@
 import llm
 
 
-
 @llm.hookimpl
 def register_models(register):
     register(Mistral("mistral-tiny"))
@@ -17,8 +16,33 @@ class Mistral(llm.Model):
     def __init__(self, model_id):
         self.model_id = model_id
 
+    def build_messages(self, prompt, conversation):
+        messages = []
+        if not conversation:
+            if prompt.system:
+                messages.append({"role": "system", "content": prompt.system})
+            messages.append({"role": "user", "content": prompt.prompt})
+            return messages
+        current_system = None
+        for prev_response in conversation.responses:
+            if (
+                prev_response.prompt.system
+                and prev_response.prompt.system != current_system
+            ):
+                messages.append(
+                    {"role": "system", "content": prev_response.prompt.system}
+                )
+                current_system = prev_response.prompt.system
+            messages.append({"role": "user", "content": prev_response.prompt.prompt})
+            messages.append({"role": "assistant", "content": prev_response.text()})
+        if prompt.system and prompt.system != current_system:
+            messages.append({"role": "system", "content": prompt.system})
+        messages.append({"role": "user", "content": prompt.prompt})
+        return messages
+
     def execute(self, prompt, stream, response, conversation):
         key = llm.get_key("", "mistral", "LLM_MISTRAL_KEY")
+        messages = self.build_messages(prompt, conversation)
         with httpx.Client() as client:
             with connect_sse(
                 client,
@@ -31,17 +55,16 @@ def execute(self, prompt, stream, response, conversation):
                 },
                 json={
                     "model": self.model_id,
-                    "messages": [
-                        {"role": "user", "content": prompt.prompt}
-                    ],
+                    "messages": messages,
                     "stream": True,
                 },
             ) as event_source:
                 # In case of unauthorized:
                 event_source.response.raise_for_status()
                 for sse in event_source.iter_sse():
-                    if sse.data != '[DONE]':
+                    if sse.data != "[DONE]":
                         try:
                             yield sse.json()["choices"][0]["delta"]["content"]
                         except KeyError:
                             pass
+        response._prompt_json = {"messages": messages}