
Extends input types of RemoteWhisperTranscriber #6218

Merged · 19 commits · Nov 22, 2023
19 changes: 13 additions & 6 deletions haystack/preview/components/audio/whisper_remote.py
@@ -1,7 +1,8 @@
import io
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pathlib import Path

import openai

@@ -111,7 +112,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, streams: List[ByteStream]):
def run(self, sources: List[Union[str, Path, ByteStream]]):
"""
Transcribe the audio files into a list of Documents, one for each input file.

@@ -124,11 +125,17 @@ def run(self, streams: List[ByteStream]):
"""
documents = []

for stream in streams:
file = io.BytesIO(stream.data)
file.name = stream.metadata.get("file_path", "audio_input.wav") # default name if `file_path` not found
for source in sources:
if not isinstance(source, ByteStream):
path = source
source = ByteStream.from_file_path(Path(source))
source.metadata["file_path"] = path

file = io.BytesIO(source.data)
file.name = str(source.metadata["file_path"]) if "file_path" in source.metadata else "__fallback__.wav"

content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
doc = Document(content=content["text"], meta=stream.metadata)
doc = Document(content=content["text"], meta=source.metadata)
documents.append(doc)

return {"documents": documents}
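For reference, a minimal sketch of calling the extended run() with all three accepted source types. The import paths are assumptions based on this PR's preview package layout; ByteStream.from_file_path is the helper used in the diff above.

from pathlib import Path

from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.dataclasses import ByteStream  # assumed import path

# Assumes OPENAI_API_KEY is set in the environment, or pass api_key=... explicitly.
transcriber = RemoteWhisperTranscriber(model_name="whisper-1")

# str, pathlib.Path, and ByteStream sources can now be mixed in a single call.
result = transcriber.run(
    sources=[
        "audio/meeting.wav",
        Path("audio/interview.wav"),
        ByteStream.from_file_path(Path("audio/answer.wav")),
    ]
)
for doc in result["documents"]:
    print(doc.meta["file_path"], "->", doc.content)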
@@ -0,0 +1,2 @@
preview:
- Extends input types of RemoteWhisperTranscriber from List[ByteStream] to List[Union[str, Path, ByteStream]] to make it possible to connect it to FileTypeRouter.
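The FileTypeRouter hookup this note describes could look roughly like the sketch below. The router's import path, the "audio/x-wav" mime string, and the per-mime-type output socket name are assumptions based on the preview package layout, not part of this PR.

from haystack.preview import Pipeline  # assumed import path
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.components.routers.file_type_router import FileTypeRouter  # assumed import path

pipeline = Pipeline()
pipeline.add_component("router", FileTypeRouter(mime_types=["audio/x-wav"]))
pipeline.add_component("transcriber", RemoteWhisperTranscriber(model_name="whisper-1"))

# The router is assumed to expose one output socket per configured mime type;
# with this PR, its List[Union[str, Path]] output can feed `sources` directly.
pipeline.connect("router.audio/x-wav", "transcriber.sources")

result = pipeline.run({"router": {"sources": ["audio/meeting.wav"]}})
print(result["transcriber"]["documents"])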
53 changes: 34 additions & 19 deletions test/preview/components/audio/test_whisper_remote.py
@@ -1,5 +1,6 @@
import os
from unittest.mock import patch
from pathlib import Path

import openai
import pytest
@@ -182,7 +183,33 @@ def test_from_dict_with_defualt_parameters_no_env_var(self, monkeypatch):
RemoteWhisperTranscriber.from_dict(data)

@pytest.mark.unit
def test_run(self, preview_samples_path):
def test_run_str(self, preview_samples_path):
with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
model = "whisper-1"
file_path = str(preview_samples_path / "audio" / "this is the content of the document.wav")
openai_audio_patch.transcribe.side_effect = mock_openai_response

transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
result = transcriber.run(sources=[file_path])

assert result["documents"][0].content == "test transcription"
assert result["documents"][0].meta["file_path"] == file_path

@pytest.mark.unit
def test_run_path(self, preview_samples_path):
with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
model = "whisper-1"
file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
openai_audio_patch.transcribe.side_effect = mock_openai_response

transcriber = RemoteWhisperTranscriber(api_key="test_api_key", model_name=model, response_format="json")
result = transcriber.run(sources=[file_path])

assert result["documents"][0].content == "test transcription"
assert result["documents"][0].meta["file_path"] == file_path

@pytest.mark.unit
def test_run_bytestream(self, preview_samples_path):
with patch("haystack.preview.components.audio.whisper_remote.openai.Audio") as openai_audio_patch:
model = "whisper-1"
file_path = preview_samples_path / "audio" / "this is the content of the document.wav"
@@ -193,7 +220,7 @@ def test_run(self, preview_samples_path):
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})

result = transcriber.run(streams=[audio_file])
result = transcriber.run(sources=[audio_file])

assert result["documents"][0].content == "test transcription"
assert result["documents"][0].meta["file_path"] == str(file_path.absolute())
@@ -208,32 +235,20 @@ def test_whisper_remote_transcriber(self, preview_samples_path):

paths = [
preview_samples_path / "audio" / "this is the content of the document.wav",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "audio" / "answer.wav",
str(preview_samples_path / "audio" / "the context for this answer is here.wav"),
ByteStream.from_file_path(preview_samples_path / "audio" / "answer.wav"),
]

audio_files = []
for file_path in paths:
with open(file_path, "rb") as audio_stream:
byte_stream = audio_stream.read()
audio_file = ByteStream(byte_stream, metadata={"file_path": str(file_path.absolute())})
audio_files.append(audio_file)

output = transcriber.run(streams=audio_files)
output = transcriber.run(sources=paths)

docs = output["documents"]
assert len(docs) == 3
assert docs[0].content.strip().lower() == "this is the content of the document."
assert (
str((preview_samples_path / "audio" / "this is the content of the document.wav").absolute())
== docs[0].meta["file_path"]
)
assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].meta["file_path"]

assert docs[1].content.strip().lower() == "the context for this answer is here."
assert (
str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
== docs[1].meta["file_path"]
str(preview_samples_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"]
)

assert docs[2].content.strip().lower() == "answer."
assert str((preview_samples_path / "audio" / "answer.wav").absolute()) == docs[2].meta["file_path"]