refactor: Migrate RemoteWhisperTranscriber to OpenAI SDK. #6149

Merged
183 changes: 94 additions & 89 deletions haystack/preview/components/audio/whisper_remote.py
@@ -1,20 +1,17 @@
from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence

import os
import json
import io
import logging
from pathlib import Path

from haystack.preview.utils import request_with_retry
from haystack.preview import component, Document, default_to_dict
import os
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)
import openai

from haystack.preview import Document, component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import ByteStream

OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 600))
logger = logging.getLogger(__name__)


WhisperRemoteModel = Literal["whisper-1"]
API_BASE_URL = "https://api.openai.com/v1"


@component
@@ -30,108 +27,116 @@ class RemoteWhisperTranscriber:

def __init__(
self,
api_key: str,
model_name: WhisperRemoteModel = "whisper-1",
api_base: str = "https://api.openai.com/v1",
whisper_params: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
model_name: str = "whisper-1",
organization: Optional[str] = None,
api_base_url: str = API_BASE_URL,
**kwargs,
):
"""
Transcribes a list of audio files into a list of Documents.

:param api_key: OpenAI API key.
:param model_name: Name of the model to use. It now accepts only `whisper-1`.
:param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
:param api_base_url: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
:param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more details.
Some of the supported parameters:
- `language`: The language of the input audio.
Supplying the input language in ISO-639-1 format
will improve accuracy and latency.
- `prompt`: An optional text to guide the model's
style or continue a previous audio segment.
The prompt should match the audio language.
- `response_format`: The format of the transcript
output, in one of these options: json, text, srt,
verbose_json, or vtt. Defaults to "json". Currently only "json" is supported.
- `temperature`: The sampling temperature, between 0
and 1. Higher values like 0.8 will make the output more
random, while lower values like 0.2 will make it more
focused and deterministic. If set to 0, the model will
use log probability to automatically increase the
temperature until certain thresholds are hit.
"""
if model_name not in get_args(WhisperRemoteModel):
raise ValueError(
f"Model name not recognized. Choose one among: " f"{', '.join(get_args(WhisperRemoteModel))}."
)
if not api_key:
raise ValueError("API key is None.")

# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"RemoteWhisperTranscriber expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key

self.organization = organization
self.model_name = model_name
self.api_key = api_key
self.api_base = api_base
self.whisper_params = whisper_params or {}
self.api_base_url = api_base_url

@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""
Transcribe the audio files into a list of Documents, one for each input file.
# Only response_format = "json" is supported
whisper_params = kwargs
if whisper_params.get("response_format") != "json":
logger.warning(
"RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
)
whisper_params["response_format"] = "json"
self.whisper_params = whisper_params

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
if organization is not None:
openai.organization = organization

:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of Documents, one for each file. The content of the document is the transcription text,
while the document's metadata contains all the other values returned by the Whisper model, such as the
alignment data. Another key called `audio_file` contains the path to the audio file used for the
transcription.
def to_dict(self) -> Dict[str, Any]:
"""
if whisper_params is None:
whisper_params = self.whisper_params

documents = self.transcribe(audio_files, **whisper_params)
return {"documents": documents}

def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
Serialize this component to a dictionary.
This method overrides the default serializer in order to
avoid leaking the `api_key` value passed to the constructor.
"""
Transcribe the audio files into a list of Documents, one for each input file.

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).
return default_to_dict(
self,
model_name=self.model_name,
organization=self.organization,
api_base_url=self.api_base_url,
**self.whisper_params,
)

:param audio_files: a list of paths or binary streams to transcribe
:returns: a list of transcriptions.
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
"""
transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
documents = []
for audio, transcript in zip(audio_files, transcriptions):
content = transcript.pop("text")
if not isinstance(audio, (str, Path)):
audio = "<<binary stream>>"
doc = Document(text=content, metadata={"audio_file": audio, **transcript})
documents.append(doc)
return documents
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)

def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
@component.output_types(documents=List[Document])
def run(self, streams: List[ByteStream]):
"""
Transcribe the given audio files. Returns a list of strings.
Transcribe the audio files into a list of Documents, one for each input file.

For the supported audio formats, languages, and other parameters, see the
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
[github repo](https://github.com/openai/whisper).

:param audio_files: a list of paths or binary streams to transcribe.
:param kwargs: any other parameters that Whisper API can understand.
:returns: a list of transcriptions as they are produced by the Whisper API (JSON).
:param streams: a list of ByteStream objects to transcribe.
:returns: a list of Documents, one for each file. The content of the document is the transcription text.
"""
translate = kwargs.pop("translate", False)
url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
data = {"model": self.model_name, **kwargs}
headers = {"Authorization": f"Bearer {self.api_key}"}

transcriptions = []
for audio_file in audio_files:
if isinstance(audio_file, (str, Path)):
audio_file = open(audio_file, "rb")

request_files = ("file", (audio_file.name, audio_file, "application/octet-stream"))
response = request_with_retry(
method="post", url=url, data=data, headers=headers, files=[request_files], timeout=OPENAI_TIMEOUT
)
transcription = json.loads(response.content)
documents = []

transcriptions.append(transcription)
return transcriptions
for stream in streams:
try:
file = io.BytesIO(stream.data)
file.name = stream.metadata["file_path"]
Member

@awinml yes, let's check here whether stream.metadata["file_path"] is present. If it is, use it. If not, just use a placeholder name, e.g. audio_input; perhaps an extension is not even needed. Please check.

Contributor Author (awinml, Oct 26, 2023)

I tried the API without an extension; it does not allow that. I updated the example notebook with a test.

I think we can use audio_input.wav as the name if stream.metadata["file_path"] is not present. Should we also log a warning if we do this?

Something like this:

for stream in streams:
    file = io.BytesIO(stream.data)
    try:
        file.name = stream.metadata["file_path"]
    except KeyError as e:
        file.name = "audio_input.wav"
        warning_msg = "Did not find 'file_path' in the metadata (%s); setting the file name to 'audio_input.wav'."
        logger.warning(warning_msg, e)

Member

I wouldn't. That file name is not important at all, as far as I can tell, and a warning would needlessly scare users. What do you think, @ZanSara @awinml?

Contributor Author

Makes sense, I will push the changes without the warning then.
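
For reference, the outcome of this thread (use `file_path` when it is present, otherwise fall back to `audio_input.wav` without logging anything) could look roughly like the sketch below. `to_named_buffer` and `DEFAULT_AUDIO_NAME` are hypothetical names that do not appear in the PR, and a plain dict stands in for `ByteStream.metadata`.

```python
import io
from typing import Any, Dict

# Assumption: placeholder name taken from the discussion above. The extension
# is kept because the API was reported to reject names without one.
DEFAULT_AUDIO_NAME = "audio_input.wav"


def to_named_buffer(data: bytes, metadata: Dict[str, Any]) -> io.BytesIO:
    """Wrap raw audio bytes in a file-like object that carries a name."""
    buffer = io.BytesIO(data)
    # Use the original path when the metadata has one; otherwise fall back silently.
    buffer.name = str(metadata.get("file_path", DEFAULT_AUDIO_NAME))
    return buffer


named = to_named_buffer(b"fake-audio", {"file_path": "speech.mp3"})
unnamed = to_named_buffer(b"fake-audio", {})
print(named.name, unnamed.name)
```

Using `dict.get` with a default keeps the loop body free of `try`/`except`, so a stream without `file_path` is still transcribed instead of being skipped.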

except KeyError as e:
logger.warning(
"Could not read audio file. Skipping it. Make sure the 'file_path' is present in the metadata. Error message: %s",
e,
)
continue

content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
doc = Document(text=content["text"], metadata=stream.metadata)
documents.append(doc)

def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
)
return {"documents": documents}
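
The constructor logic in this diff resolves the API key in a fixed order (explicit argument, then the key already set on the `openai` module, then the `OPENAI_API_KEY` environment variable) and forces `response_format` to `"json"`. A minimal standalone sketch of that behavior follows. It does not import `openai`: the `module_key` argument stands in for `openai.api_key`, and both function names are hypothetical. Unlike the check in the diff, this variant only warns when the caller actually supplied a different `response_format`, which is an assumption about the intended behavior.

```python
import logging
import os
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def resolve_api_key(explicit_key: Optional[str], module_key: Optional[str] = None) -> str:
    """Return the first available key: explicit argument, module-level key, environment."""
    api_key = explicit_key or module_key
    if api_key is None:
        try:
            api_key = os.environ["OPENAI_API_KEY"]
        except KeyError as e:
            raise ValueError(
                "Set the OPENAI_API_KEY environment variable or pass the key explicitly."
            ) from e
    return api_key


def normalize_whisper_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the parameters with response_format forced to 'json'."""
    normalized = dict(params)
    requested = normalized.get("response_format")
    # Assumption: stay quiet when the caller did not set response_format at all.
    if requested is not None and requested != "json":
        logger.warning("Only 'response_format: json' is supported; overriding %r.", requested)
    normalized["response_format"] = "json"
    return normalized
```

Copying the dict before modifying it leaves the caller's `kwargs` untouched, which keeps the override visible only inside the component.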
@@ -0,0 +1,4 @@
---
preview:
- |
Migrate RemoteWhisperTranscriber to OpenAI SDK.