-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update inits to expose ollamachatgenerator * add ollama chat generator * add tests for ollama chat generator * add tests for init method * Change order of chat history to chronological * add test for chat history * add return type to _build_message * refactor message_to_dict to one liner * add return types to fixtures * add test for unavailable model * drop streaming references for now * drop streaming callback from tests * Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py Co-authored-by: Stefano Fiorucci <[email protected]> * Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py Co-authored-by: Stefano Fiorucci <[email protected]> * Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py Co-authored-by: Stefano Fiorucci <[email protected]> * drop _chat_history_to_dict * drop intermediate ollama to haystack response methods * change metadata to meta * lint with black * refactor chat message fixture into one list * add chat generator example * rename example -> generator example * add new chat generator example * Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py Co-authored-by: Stefano Fiorucci <[email protected]> * update test for new timeout * Update test_chat_generator.py * increase generator timeout * add docstrings * fix --------- Co-authored-by: Stefano Fiorucci <[email protected]>
- Loading branch information
1 parent
9411c99
commit 7435282
Showing
8 changed files
with
278 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# In order to run this example, you will need to have an instance of Ollama running with the | ||
# orca-mini model downloaded. We suggest you use the following commands to serve an orca-mini | ||
# model from Ollama | ||
# | ||
# docker run -d -p 11434:11434 --name ollama ollama/ollama:latest | ||
# docker exec ollama ollama pull orca-mini | ||
|
||
from haystack.dataclasses import ChatMessage | ||
|
||
from ollama_haystack import OllamaChatGenerator | ||
|
||
messages = [ | ||
ChatMessage.from_user("What's Natural Language Processing?"), | ||
ChatMessage.from_system( | ||
"Natural Language Processing (NLP) is a field of computer science and artificial " | ||
"intelligence concerned with the interaction between computers and human language" | ||
), | ||
ChatMessage.from_user("How do I get started?"), | ||
] | ||
client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat") | ||
|
||
response = client.run(messages, generation_kwargs={"temperature": 0.2}) | ||
|
||
print(response["replies"]) | ||
# | ||
# [ | ||
# ChatMessage( | ||
# content="Natural Language Processing (NLP) is a broad field of computer science and artificial intelligence " | ||
# "that involves the interaction between computers and human language. To get started in NLP, " | ||
# "you can start by learning about the different techniques and tools used in NLP such as machine " | ||
# "learning algorithms, deep learning frameworks, and natural language processing libraries. You can " | ||
# "also learn about the applications of NLP in various fields such as chatbots, sentiment analysis, " | ||
# "speech recognition, and text classification. Additionally, you can explore the available resources " | ||
# "such as online courses, tutorials, and books on NLP to gain a deeper understanding of the field.", | ||
# role=<ChatRole.ASSISTANT: 'assistant'>, | ||
# name=None, | ||
# meta={ | ||
# "model": "orca-mini", | ||
# "created_at": "2024-01-08T15:35:23.378609793Z", | ||
# "done": True, | ||
# "total_duration": 20026330217, | ||
# "load_duration": 1540167, | ||
# "prompt_eval_count": 99, | ||
# "prompt_eval_duration": 8486609000, | ||
# "eval_count": 124, | ||
# "eval_duration": 11532988000, | ||
# }, | ||
# ) | ||
# ] |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
96 changes: 96 additions & 0 deletions
96
integrations/ollama/src/ollama_haystack/chat/chat_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import requests | ||
from haystack import component | ||
from haystack.dataclasses import ChatMessage | ||
from requests import Response | ||
|
||
|
||
@component | ||
class OllamaChatGenerator: | ||
""" | ||
Chat Generator based on Ollama. Ollama is a library for easily running LLMs locally. | ||
This component provides an interface to generate text using a LLM running in Ollama. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: str = "orca-mini", | ||
url: str = "http://localhost:11434/api/chat", | ||
generation_kwargs: Optional[Dict[str, Any]] = None, | ||
template: Optional[str] = None, | ||
timeout: int = 120, | ||
): | ||
""" | ||
:param model: The name of the model to use. The model should be available in the running Ollama instance. | ||
Default is "orca-mini". | ||
:param url: The URL of the chat endpoint of a running Ollama instance. | ||
Default is "http://localhost:11434/api/chat". | ||
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, | ||
top_p, and others. See the available arguments in | ||
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). | ||
:param template: The full prompt template (overrides what is defined in the Ollama Modelfile). | ||
:param timeout: The number of seconds before throwing a timeout error from the Ollama API. | ||
Default is 120 seconds. | ||
""" | ||
|
||
self.timeout = timeout | ||
self.template = template | ||
self.generation_kwargs = generation_kwargs or {} | ||
self.url = url | ||
self.model = model | ||
|
||
def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: | ||
return {"role": message.role.value, "content": message.content} | ||
|
||
def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]: | ||
""" | ||
Returns A dictionary of JSON arguments for a POST request to an Ollama service | ||
:param messages: A history/list of chat messages | ||
:param generation_kwargs: | ||
:return: A dictionary of arguments for a POST request to an Ollama service | ||
""" | ||
generation_kwargs = generation_kwargs or {} | ||
return { | ||
"messages": [self._message_to_dict(message) for message in messages], | ||
"model": self.model, | ||
"stream": False, | ||
"template": self.template, | ||
"options": generation_kwargs, | ||
} | ||
|
||
def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage: | ||
""" | ||
Converts the non-streaming response from the Ollama API to a ChatMessage. | ||
:param ollama_response: The completion returned by the Ollama API. | ||
:return: The ChatMessage. | ||
""" | ||
json_content = ollama_response.json() | ||
message = ChatMessage.from_assistant(content=json_content["message"]["content"]) | ||
message.meta.update({key: value for key, value in json_content.items() if key != "message"}) | ||
return message | ||
|
||
@component.output_types(replies=List[ChatMessage]) | ||
def run( | ||
self, | ||
messages: List[ChatMessage], | ||
generation_kwargs: Optional[Dict[str, Any]] = None, | ||
): | ||
""" | ||
Run an Ollama Model on a given chat history. | ||
:param messages: A list of ChatMessage instances representing the input messages. | ||
:param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, | ||
top_p, etc. See the | ||
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). | ||
:return: A dictionary of the replies containing their metadata | ||
""" | ||
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} | ||
|
||
json_payload = self._create_json_payload(messages, generation_kwargs) | ||
|
||
response = requests.post(url=self.url, json=json_payload, timeout=self.timeout) | ||
|
||
# throw error on unsuccessful response | ||
response.raise_for_status() | ||
|
||
return {"replies": [self._build_message_from_ollama_response(response)]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import List | ||
from unittest.mock import Mock | ||
|
||
import pytest | ||
from haystack.dataclasses import ChatMessage, ChatRole | ||
from requests import HTTPError, Response | ||
|
||
from ollama_haystack import OllamaChatGenerator | ||
|
||
|
||
@pytest.fixture | ||
def chat_messages() -> List[ChatMessage]: | ||
return [ | ||
ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"), | ||
ChatMessage.from_assistant( | ||
"Super Mario has prevented Bowser from destroying the world", {"something": "something"} | ||
), | ||
] | ||
|
||
|
||
class TestOllamaChatGenerator: | ||
def test_init_default(self): | ||
component = OllamaChatGenerator() | ||
assert component.model == "orca-mini" | ||
assert component.url == "http://localhost:11434/api/chat" | ||
assert component.generation_kwargs == {} | ||
assert component.template is None | ||
assert component.timeout == 120 | ||
|
||
def test_init(self): | ||
component = OllamaChatGenerator( | ||
model="llama2", | ||
url="http://my-custom-endpoint:11434/api/chat", | ||
generation_kwargs={"temperature": 0.5}, | ||
timeout=5, | ||
) | ||
|
||
assert component.model == "llama2" | ||
assert component.url == "http://my-custom-endpoint:11434/api/chat" | ||
assert component.generation_kwargs == {"temperature": 0.5} | ||
assert component.template is None | ||
assert component.timeout == 5 | ||
|
||
def test_create_json_payload(self, chat_messages): | ||
observed = OllamaChatGenerator(model="some_model")._create_json_payload(chat_messages, {"temperature": 0.1}) | ||
expected = { | ||
"messages": [ | ||
{"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}, | ||
{"role": "assistant", "content": "Super Mario has prevented Bowser from destroying the world"}, | ||
], | ||
"model": "some_model", | ||
"stream": False, | ||
"template": None, | ||
"options": {"temperature": 0.1}, | ||
} | ||
|
||
assert observed == expected | ||
|
||
def test_build_message_from_ollama_response(self): | ||
model = "some_model" | ||
|
||
mock_ollama_response = Mock(Response) | ||
mock_ollama_response.json.return_value = { | ||
"model": model, | ||
"created_at": "2023-12-12T14:13:43.416799Z", | ||
"message": {"role": "assistant", "content": "Hello! How are you today?"}, | ||
"done": True, | ||
"total_duration": 5191566416, | ||
"load_duration": 2154458, | ||
"prompt_eval_count": 26, | ||
"prompt_eval_duration": 383809000, | ||
"eval_count": 298, | ||
"eval_duration": 4799921000, | ||
} | ||
|
||
observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(mock_ollama_response) | ||
|
||
assert observed.role == "assistant" | ||
assert observed.content == "Hello! How are you today?" | ||
|
||
@pytest.mark.integration | ||
def test_run(self): | ||
chat_generator = OllamaChatGenerator() | ||
|
||
user_questions_and_assistant_answers = [ | ||
("What's the capital of France?", "Paris"), | ||
("What is the capital of Canada?", "Ottawa"), | ||
("What is the capital of Ghana?", "Accra"), | ||
] | ||
|
||
for question, answer in user_questions_and_assistant_answers: | ||
message = ChatMessage.from_user(question) | ||
|
||
response = chat_generator.run([message]) | ||
|
||
assert isinstance(response, dict) | ||
assert isinstance(response["replies"], list) | ||
assert answer in response["replies"][0].content | ||
|
||
@pytest.mark.integration | ||
def test_run_with_chat_history(self): | ||
chat_generator = OllamaChatGenerator() | ||
|
||
chat_history = [ | ||
{"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, | ||
{"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, | ||
{"role": "user", "content": "And what is the second largest?"}, | ||
] | ||
|
||
chat_messages = [ | ||
ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) | ||
for message in chat_history | ||
] | ||
response = chat_generator.run(chat_messages) | ||
|
||
assert isinstance(response, dict) | ||
assert isinstance(response["replies"], list) | ||
assert "Manchester" in response["replies"][-1].content | ||
|
||
@pytest.mark.integration | ||
def test_run_model_unavailable(self): | ||
component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great") | ||
|
||
with pytest.raises(HTTPError): | ||
message = ChatMessage.from_user( | ||
"Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?" | ||
) | ||
component.run([message]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters