Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: HuggingFaceAPIChatGenerator #7480

Merged
merged 30 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
fd4b333
draft
anakin87 Apr 4, 2024
f27bc9c
docstrings and more tests
anakin87 Apr 4, 2024
3eab79a
deprecation; reno
anakin87 Apr 4, 2024
69f731a
pydoc config
anakin87 Apr 4, 2024
04d84e4
better error messages
anakin87 Apr 4, 2024
0c6c982
wip
anakin87 Apr 4, 2024
5d3a0ec
add test
anakin87 Apr 4, 2024
b6275ba
better docstrings
anakin87 Apr 4, 2024
7499e52
deprecation; reno
anakin87 Apr 4, 2024
9d4acfb
pylint
anakin87 Apr 4, 2024
d1db792
typo
anakin87 Apr 4, 2024
d2e6c61
rm unneeded else
anakin87 Apr 5, 2024
f89630d
Merge branch 'hfapigenerator' into hfapichatgenerator
anakin87 Apr 5, 2024
08935d2
rm unneeded else
anakin87 Apr 5, 2024
b676a05
fixes from feedback
anakin87 Apr 5, 2024
58481c6
docstring showing the enum
anakin87 Apr 5, 2024
e174b4f
improve docstring
anakin87 Apr 5, 2024
d5507fb
make params mandatory
anakin87 Apr 5, 2024
537008b
Apply suggestions from code review
anakin87 Apr 5, 2024
430c7d6
document enum
anakin87 Apr 5, 2024
2f2c1cd
Merge branch 'hfapigenerator' into hfapichatgenerator
anakin87 Apr 5, 2024
9f2c9c0
Update haystack/utils/hf.py
anakin87 Apr 5, 2024
4104043
mandatory params
anakin87 Apr 5, 2024
8253515
Merge branch 'main' into hfapigenerator
anakin87 Apr 5, 2024
1c4f00f
Merge branch 'hfapigenerator' into hfapichatgenerator
anakin87 Apr 5, 2024
61b3374
fix test
anakin87 Apr 5, 2024
f006160
Merge branch 'hfapichatgenerator' of https://github.com/deepset-ai/ha…
anakin87 Apr 5, 2024
b8e9984
fix test
anakin87 Apr 5, 2024
b893ba8
Merge branch 'main' into hfapigenerator
anakin87 Apr 5, 2024
aa6774b
Merge branch 'hfapigenerator' into hfapichatgenerator
anakin87 Apr 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/pydoc/config/generators_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ loaders:
"azure",
"hugging_face_local",
"hugging_face_tgi",
"hugging_face_api",
"openai",
"chat/azure",
"chat/hugging_face_local",
"chat/hugging_face_tgi",
"chat/hugging_face_api",
"chat/openai",
]
ignore_when_discovered: ["__init__"]
Expand Down
9 changes: 8 additions & 1 deletion haystack/components/generators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,12 @@
from haystack.components.generators.azure import AzureOpenAIGenerator
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.components.generators.hugging_face_api import HuggingFaceAPIGenerator

__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "AzureOpenAIGenerator"]
__all__ = [
"HuggingFaceLocalGenerator",
"HuggingFaceTGIGenerator",
"HuggingFaceAPIGenerator",
"OpenAIGenerator",
"AzureOpenAIGenerator",
]
2 changes: 2 additions & 0 deletions haystack/components/generators/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator
from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator
from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator

__all__ = [
"HuggingFaceLocalChatGenerator",
"HuggingFaceTGIChatGenerator",
"HuggingFaceAPIChatGenerator",
"OpenAIChatGenerator",
"AzureOpenAIChatGenerator",
]
237 changes: 237 additions & 0 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.22.0\"'") as huggingface_hub_import:
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient


logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIChatGenerator:
"""
This component can be used to generate text using different Hugging Face APIs with the ChatMessage format:
- [free Serverless Inference API](https://huggingface.co/inference-api)
- [paid Inference Endpoints](https://huggingface.co/inference-endpoints)
- [self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

Input and Output Format:
- ChatMessage Format: This component uses the ChatMessage format to structure both input and output,
ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
ChatMessage format can be found [here](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage).


Example usage with the free Serverless Inference API:
```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack.utils.hf import HFGenerationAPIType

messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]

# the api_type can be expressed using the HFGenerationAPIType enum or as a string
api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
api_type = "serverless_inference_api" # this is equivalent to the above

generator = HuggingFaceAPIChatGenerator(api_type=api_type,
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_token("<your-api-key>"))

result = generator.run(messages)
print(result)
```

Example usage with paid Inference Endpoints:
```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]

generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
api_params={"url": "<your-inference-endpoint-url>"},
token=Secret.from_token("<your-api-key>"))

result = generator.run(messages)
print(result)

Example usage with self-hosted Text Generation Inference:
```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage

messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]

generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
api_params={"url": "http://localhost:8080"})

result = generator.run(messages)
print(result)
```
"""

def __init__(
self,
api_type: Union[HFGenerationAPIType, str] = HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params: Optional[Dict[str, str]] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.

:param api_type:
The type of Hugging Face API to use.
:param api_params:
A dictionary containing the following keys:
- `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_GENERATION_INFERENCE`.
:param token: The HuggingFace token to use as HTTP bearer authorization
You can find your HF token in your [account settings](https://huggingface.co/settings/tokens)
:param generation_kwargs:
A dictionary containing keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_p`...
See Hugging Face's documentation for more information at: [chat_completion](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""

huggingface_hub_import.check()

if isinstance(api_type, str):
api_type = HFGenerationAPIType.from_str(api_type)

api_params = api_params or {}

if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
model = api_params.get("model")
if model is None:
raise ValueError(
"To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
)
check_valid_model(model, HFModelType.GENERATION, token)
model_or_url = model
elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
url = api_params.get("url")
if url is None:
raise ValueError(
"To use Text Generation Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`."
)
if not is_valid_http_url(url):
raise ValueError(f"Invalid URL: {url}")
model_or_url = url

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
generation_kwargs["stop"] = generation_kwargs.get("stop", [])
generation_kwargs["stop"].extend(stop_words or [])
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
generation_kwargs.setdefault("max_tokens", 512)

self.api_type = api_type
self.api_params = api_params
self.token = token
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.

:returns:
A dictionary containing the serialized component.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
api_type=self.api_type,
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invoke the text generation inference based on the provided messages and generation parameters.

:param messages: A list of ChatMessage instances representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:return: A list containing the generated responses as ChatMessage instances.
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
"""
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

formatted_messages = [m.to_openai_format() for m in messages]

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)

return self._run_non_streaming(formatted_messages, generation_kwargs)

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
messages, stream=True, **generation_kwargs
)

generated_text = ""

for chunk in api_output: # pylint: disable=not-an-iterable
text = chunk.choices[0].delta.content
if text:
generated_text += text
finish_reason = chunk.choices[0].finish_reason

meta = {}
if finish_reason:
meta["finish_reason"] = finish_reason

stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)

message = ChatMessage.from_assistant(generated_text)
message.meta.update({"model": self._client.model, "finish_reason": finish_reason, "index": 0})
return {"replies": [message]}

def _run_non_streaming(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
) -> Dict[str, List[ChatMessage]]:
chat_messages: List[ChatMessage] = []

api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)

for choice in api_chat_output.choices:
message = ChatMessage.from_assistant(choice.message.content)
message.meta.update(
{"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
)
chat_messages.append(message)
return {"replies": chat_messages}
6 changes: 6 additions & 0 deletions haystack/components/generators/chat/hugging_face_tgi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from dataclasses import asdict
from typing import Any, Callable, Dict, Iterable, List, Optional
from urllib.parse import urlparse
Expand Down Expand Up @@ -113,6 +114,11 @@ def __init__(
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
warnings.warn(
"`HuggingFaceTGIChatGenerator` is deprecated and will be removed in Haystack 2.3.0."
"Use `HuggingFaceAPIChatGenerator` instead.",
DeprecationWarning,
)
transformers_import.check()

if url:
Expand Down
Loading