Skip to content

Commit

Permalink
Merge branch 'main' into add_MarkdownToTextDocument
Browse files Browse the repository at this point in the history
  • Loading branch information
awinml authored Oct 23, 2023
2 parents cdc7c90 + dd24210 commit f91c4ff
Show file tree
Hide file tree
Showing 38 changed files with 542 additions and 224 deletions.
14 changes: 14 additions & 0 deletions haystack/preview/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,17 @@
from canals.errors import DeserializationError, ComponentError
from haystack.preview.pipeline import Pipeline
from haystack.preview.dataclasses import Document, Answer, GeneratedAnswer, ExtractedAnswer


# Explicit public API of the `haystack.preview` package: only these names are
# re-exported by `from haystack.preview import *`, and linters treat them as
# the intentionally public surface of the re-exports above.
__all__ = [
    "component",
    "default_from_dict",
    "default_to_dict",
    "DeserializationError",
    "ComponentError",
    "Pipeline",
    "Document",
    "Answer",
    "GeneratedAnswer",
    "ExtractedAnswer",
]
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __init__(
:param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""

# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ def run(self, documents: List[Document]):
normalize_embeddings=self.normalize_embeddings,
)

documents_with_embeddings = []
for doc, emb in zip(documents, embeddings):
doc_as_dict = doc.to_dict()
doc_as_dict["embedding"] = emb
documents_with_embeddings.append(Document.from_dict(doc_as_dict))
doc.embedding = emb

return {"documents": documents_with_embeddings}
return {"documents": documents}
22 changes: 17 additions & 5 deletions haystack/preview/components/file_converters/azure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Optional
import os

from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, Document, default_to_dict
Expand All @@ -22,22 +23,33 @@ class AzureOCRDocumentConverter:
to set up your resource.
"""

def __init__(self, endpoint: str, api_key: str, model_id: str = "prebuilt-read"):
def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
"""
Create an AzureOCRDocumentConverter component.
:param endpoint: The endpoint of your Azure resource.
:param api_key: The key of your Azure resource.
:param api_key: The key of your Azure resource. It can be
explicitly provided or automatically read from the
environment variable AZURE_AI_API_KEY (recommended).
:param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
for a list of available models. Default: `"prebuilt-read"`.
"""
azure_import.check()

if api_key is None:
try:
api_key = os.environ["AZURE_AI_API_KEY"]
except KeyError as e:
raise ValueError(
"AzureOCRDocumentConverter expects an Azure Credential key. "
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
) from e

self.api_key = api_key
self.document_analysis_client = DocumentAnalysisClient(
endpoint=endpoint, credential=AzureKeyCredential(api_key)
)
self.endpoint = endpoint
self.api_key = api_key
self.model_id = model_id

@component.output_types(documents=List[Document], azure=List[Dict])
Expand Down Expand Up @@ -70,7 +82,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id)
return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)

@staticmethod
def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: str) -> Document:
Expand Down
23 changes: 18 additions & 5 deletions haystack/preview/components/generators/openai/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, asdict
import os

import openai

Expand Down Expand Up @@ -43,7 +44,7 @@ class GPTGenerator:

def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable] = None,
Expand All @@ -53,7 +54,8 @@ def __init__(
"""
Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's GPT-3.5 model.
:param api_key: The OpenAI API key.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
Typically, a conversation is formatted with a system message first, followed by alternating messages from
Expand Down Expand Up @@ -84,12 +86,25 @@ def __init__(
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
"""
self.api_key = api_key
# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"GPTGenerator expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key

self.model_name = model_name
self.system_prompt = system_prompt
self.model_parameters = kwargs
self.streaming_callback = streaming_callback

self.api_base_url = api_base_url
openai.api_base = api_base_url

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -112,7 +127,6 @@ def to_dict(self) -> Dict[str, Any]:

return default_to_dict(
self,
api_key=self.api_key,
model_name=self.model_name,
system_prompt=self.system_prompt,
streaming_callback=callback_name,
Expand Down Expand Up @@ -155,7 +169,6 @@ def run(self, prompt: str):

completion = openai.ChatCompletion.create(
model=self.model_name,
api_key=self.api_key,
messages=[asdict(message) for message in chat],
stream=self.streaming_callback is not None,
**self.model_parameters,
Expand Down
20 changes: 13 additions & 7 deletions haystack/preview/components/websearch/serper_dev.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import logging
from typing import Dict, List, Optional, Any

Expand Down Expand Up @@ -26,20 +27,29 @@ class SerperDevWebSearch:

def __init__(
self,
api_key: str,
api_key: Optional[str] = None,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SerperDev API.
:param api_key: API key for the SerperDev API. It can be
explicitly provided or automatically read from the
environment variable SERPERDEV_API_KEY (recommended).
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details.
"""
if api_key is None:
try:
api_key = os.environ["SERPERDEV_API_KEY"]
except KeyError as e:
raise ValueError(
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
) from e
raise ValueError("API key for SerperDev API must be set.")
self.api_key = api_key
self.top_k = top_k
Expand All @@ -51,11 +61,7 @@ def to_dict(self) -> Dict[str, Any]:
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
api_key=self.api_key,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
)

@component.output_types(documents=List[Document], links=List[str])
Expand Down
4 changes: 3 additions & 1 deletion haystack/preview/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from haystack.preview.dataclasses.document import Document
from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer
from haystack.preview.dataclasses.byte_stream import ByteStream
from haystack.preview.dataclasses.chat_message import ChatMessage
from haystack.preview.dataclasses.chat_message import ChatRole

__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream"]
__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream", "ChatMessage", "ChatRole"]
11 changes: 6 additions & 5 deletions haystack/preview/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Any
from typing import Optional, Dict, Any


@dataclass(frozen=True)
Expand All @@ -11,27 +11,28 @@ class ByteStream:

data: bytes
metadata: Dict[str, Any] = field(default_factory=dict, hash=False)
mime_type: Optional[str] = field(default=None)

def to_file(self, destination_path: Path):
with open(destination_path, "wb") as fd:
fd.write(self.data)

@classmethod
def from_file_path(cls, filepath: Path) -> "ByteStream":
def from_file_path(cls, filepath: Path, mime_type: Optional[str] = None) -> "ByteStream":
"""
Create a ByteStream from the contents read from a file.
:param filepath: A valid path to a file.
"""
with open(filepath, "rb") as fd:
return cls(data=fd.read())
return cls(data=fd.read(), mime_type=mime_type)

@classmethod
def from_string(cls, text: str, encoding: str = "utf-8") -> "ByteStream":
def from_string(cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None) -> "ByteStream":
"""
Create a ByteStream encoding a string.
:param text: The string to encode
:param encoding: The encoding used to convert the string into bytes
"""
return cls(data=text.encode(encoding))
return cls(data=text.encode(encoding), mime_type=mime_type)
79 changes: 79 additions & 0 deletions haystack/preview/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, Optional


class ChatRole(str, Enum):
    """The role of the author of a message within a chat conversation."""

    ASSISTANT = "assistant"
    USER = "user"
    SYSTEM = "system"
    FUNCTION = "function"


@dataclass
class ChatMessage:
    """
    A single message exchanged in a LLM chat conversation.

    :param content: The text content of the message.
    :param role: The role of the entity sending the message.
    :param name: The name of the function being called (only applicable for role FUNCTION).
    :param metadata: Additional metadata associated with the message.
    """

    content: str
    role: ChatRole
    name: Optional[str]
    metadata: Dict[str, Any] = field(default_factory=dict, hash=False)

    def is_from(self, role: ChatRole) -> bool:
        """
        Tell whether this message was produced by the given role.

        :param role: The role to check against.
        :return: True if the message is from the specified role, False otherwise.
        """
        return role == self.role

    @classmethod
    def from_assistant(cls, content: str) -> "ChatMessage":
        """
        Build a message authored by the assistant.

        :param content: The text content of the message.
        :return: A new ChatMessage instance.
        """
        return cls(content=content, role=ChatRole.ASSISTANT, name=None)

    @classmethod
    def from_user(cls, content: str) -> "ChatMessage":
        """
        Build a message authored by the user.

        :param content: The text content of the message.
        :return: A new ChatMessage instance.
        """
        return cls(content=content, role=ChatRole.USER, name=None)

    @classmethod
    def from_system(cls, content: str) -> "ChatMessage":
        """
        Build a system message.

        :param content: The text content of the message.
        :return: A new ChatMessage instance.
        """
        return cls(content=content, role=ChatRole.SYSTEM, name=None)

    @classmethod
    def from_function(cls, content: str, name: str) -> "ChatMessage":
        """
        Build a message carrying the result of a function call.

        :param content: The text content of the message.
        :param name: The name of the function being called.
        :return: A new ChatMessage instance.
        """
        return cls(content=content, role=ChatRole.FUNCTION, name=name)
Loading

0 comments on commit f91c4ff

Please sign in to comment.