fix: fix ChatGPT invocation layer (and add async support) (#5979)
* ChatGPT async

* release note

* fix tests
anakin87 authored Oct 5, 2023
1 parent 282419d commit ccc9f01
Showing 4 changed files with 139 additions and 55 deletions.
156 changes: 115 additions & 41 deletions haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -1,10 +1,17 @@
import logging
from typing import Optional, List, Dict, Union, Any
from typing import Any, Dict, List, Optional, Union

from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
from haystack.utils.openai_utils import (
_check_openai_finish_reason,
check_openai_async_policy_violation,
check_openai_policy_violation,
count_openai_tokens_messages,
openai_async_request,
openai_request,
)

logger = logging.getLogger(__name__)

@@ -43,45 +50,6 @@ def __init__(
"""
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

def _execute_openai_request(
self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
):
"""
For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
"""
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()

return assistant_response

def _extract_token(self, event_data: Dict[str, Any]):
delta = event_data["choices"][0]["delta"]
if "content" in delta:
@@ -141,3 +109,109 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
and not "gpt-3.5-turbo-instruct" in model_name_or_path
)
return valid_model and not has_azure_parameters(**kwargs)

async def ainvoke(self, *args, **kwargs):
"""
Asynchronously invokes a prompt on the model. It takes either a prompt string or a list of messages in the ChatML format
and returns a list of responses using a REST invocation.
:return: The list of responses generated by the model.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)

if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []

if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = await openai_async_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()

if moderation and await check_openai_async_policy_violation(input=assistant_response, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
return []

return assistant_response

def invoke(self, *args, **kwargs):
"""
Invokes a prompt on the model. It takes either a prompt string or a list of messages in the ChatML format
and returns a list of responses using a REST invocation.
:return: The list of responses generated by the model.
Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
For more details, see OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
"""
prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)

if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
return []

if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
messages = prompt
else:
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
)
extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload}
if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload)
_check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else:
response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
)
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

# Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
# We want to exclude it to be consistent with other invocation layers
if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
stop_words = kwargs_with_defaults["stop"]
for idx, _ in enumerate(assistant_response):
for stop_word in stop_words:
assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()

if moderation and check_openai_policy_violation(input=assistant_response, headers=self.headers):
logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
return []

return assistant_response
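
For illustration only (not part of the commit): a minimal usage sketch of the fixed layer, assuming a valid OpenAI API key and that `prompt` is passed as a keyword argument, following the parent OpenAI invocation layer's convention.

```python
import asyncio

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer

# Sketch only: "OPENAI_API_KEY" is a placeholder for a real key.
layer = ChatGPTInvocationLayer(api_key="OPENAI_API_KEY", model_name_or_path="gpt-3.5-turbo")

# Synchronous call: a plain string prompt is wrapped into a single ChatML user message internally.
print(layer.invoke(prompt="Reply with one word: hello"))

# Asynchronous call added by this commit: same arguments, awaited.
print(asyncio.run(layer.ainvoke(prompt="Reply with one word: hello")))
```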
@@ -0,0 +1,6 @@
---
fixes:
- |
Fixed the bug that prevented the correct usage of the ChatGPT invocation layer
in 1.21.1.
Added async support for the ChatGPT invocation layer.
6 changes: 3 additions & 3 deletions test/prompt/invocation_layer/test_chatgpt.py
@@ -1,13 +1,13 @@
import logging
from unittest.mock import patch

import logging
import pytest

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_default_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key")
@@ -19,7 +19,7 @@ def test_default_api_base(mock_request):


@pytest.mark.unit
@patch("haystack.nodes.prompt.invocation_layer.open_ai.openai_request")
@patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request")
def test_custom_api_base(mock_request):
with patch("haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"):
invocation_layer = ChatGPTInvocationLayer(api_key="fake_api_key", api_base="https://fake_api_base.com")
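
Side note on the patch-target change above (illustrative, not part of the commit): because `chatgpt.py` now imports `openai_request` into its own namespace, mocks must patch the name where it is looked up, not where it is defined. A minimal sketch of the pattern, reusing the Chat Completions response shape:

```python
from unittest.mock import patch

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer

# Patch the name in chatgpt.py (where it is used), not in open_ai.py (where it is defined).
with patch("haystack.nodes.prompt.invocation_layer.chatgpt.openai_request") as mock_request, patch(
    "haystack.nodes.prompt.invocation_layer.open_ai.load_openai_tokenizer"
):
    mock_request.return_value = {
        "choices": [{"message": {"content": "hi", "role": "assistant"}, "finish_reason": ""}]
    }
    layer = ChatGPTInvocationLayer(api_key="fake_api_key")
    print(layer.invoke(prompt="dummy prompt"))  # expected: ['hi']
```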
26 changes: 15 additions & 11 deletions test/prompt/test_prompt_node.py
@@ -1,20 +1,20 @@
import os
import logging
from typing import Optional, Union, List, Dict, Any, Tuple
from unittest.mock import patch, Mock, MagicMock, AsyncMock
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from prompthub import Prompt

from haystack import Document, Pipeline, BaseComponent, MultiLabel
from haystack.nodes.prompt import PromptTemplate, PromptNode, PromptModel
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES
from haystack import BaseComponent, Document, MultiLabel, Pipeline
from haystack.nodes.prompt import PromptModel, PromptNode, PromptTemplate
from haystack.nodes.prompt.invocation_layer import (
AzureChatGPTInvocationLayer,
AzureOpenAIInvocationLayer,
OpenAIInvocationLayer,
ChatGPTInvocationLayer,
OpenAIInvocationLayer,
)
from haystack.nodes.prompt.prompt_template import LEGACY_DEFAULT_TEMPLATES


@pytest.fixture
@@ -1082,8 +1082,8 @@ def test_content_moderation_gpt_35():
ChatGPTInvocationLayer.
"""
prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="key", model_kwargs={"moderate_content": True})
with patch("haystack.nodes.prompt.invocation_layer.open_ai.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.open_ai.openai_request"
with patch("haystack.nodes.prompt.invocation_layer.chatgpt.check_openai_policy_violation") as mock_check, patch(
"haystack.nodes.prompt.invocation_layer.chatgpt.openai_request"
) as mock_request:
VIOLENT_TEXT = "some violent text"
mock_check.side_effect = lambda input, headers: input == VIOLENT_TEXT or input == [VIOLENT_TEXT]
@@ -1093,11 +1093,15 @@ def test_content_moderation_gpt_35():
assert prompt_node(VIOLENT_TEXT) == []
# case 2: prompt passes the moderation check but the generated output fails the check
# function should also return an empty list
mock_request.return_value = {"choices": [{"text": VIOLENT_TEXT, "finish_reason": ""}]}
mock_request.return_value = {
"choices": [{"message": {"content": VIOLENT_TEXT, "role": "assistant"}, "finish_reason": ""}]
}
assert prompt_node("normal prompt") == []
# case 3: both prompt and output pass the moderation check
# function should return the output
mock_request.return_value = {"choices": [{"text": "normal output", "finish_reason": ""}]}
mock_request.return_value = {
"choices": [{"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}]
}
assert prompt_node("normal prompt") == ["normal output"]
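
For context on the mock changes above (illustrative, not part of the commit): the Completions API returns generated text under `choices[i].text`, while the Chat Completions API used by `ChatGPTInvocationLayer` returns it under `choices[i].message.content`, which is why the mocked payloads were updated. A minimal comparison:

```python
# Completions-style response (the shape the old mocks used):
completions_like = {"choices": [{"text": "normal output", "finish_reason": ""}]}

# Chat Completions-style response (the shape ChatGPTInvocationLayer actually parses):
chat_like = {
    "choices": [{"message": {"content": "normal output", "role": "assistant"}, "finish_reason": ""}]
}

# The invocation layer extracts choice["message"]["content"] from each choice:
print([choice["message"]["content"].strip() for choice in chat_like["choices"]])  # ['normal output']
```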


