diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 384683dd..956c2215 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -32,6 +32,7 @@ jobs: run: | python -m pip install 'pocketsphinx<5' python -m pip install git+https://github.com/openai/whisper.git soundfile + python -m pip install openai python -m pip install . - name: Test with unittest run: | diff --git a/tests/recognizers/test_whisper.py b/tests/recognizers/test_whisper.py new file mode 100644 index 00000000..f2c8e7fd --- /dev/null +++ b/tests/recognizers/test_whisper.py @@ -0,0 +1,42 @@ +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from speech_recognition import AudioData, Recognizer +from speech_recognition.recognizers import whisper + + +@patch("speech_recognition.recognizers.whisper.os.environ") +@patch("speech_recognition.recognizers.whisper.BytesIO") +@patch("openai.OpenAI") +class RecognizeWhisperApiTestCase(TestCase): + def test_recognize_default_arguments(self, OpenAI, BytesIO, environ): + client = OpenAI.return_value + transcript = client.audio.transcriptions.create.return_value + + recognizer = MagicMock(spec=Recognizer) + audio_data = MagicMock(spec=AudioData) + + actual = whisper.recognize_whisper_api(recognizer, audio_data) + + self.assertEqual(actual, transcript.text) + audio_data.get_wav_data.assert_called_once_with() + BytesIO.assert_called_once_with(audio_data.get_wav_data.return_value) + OpenAI.assert_called_once_with(api_key=None) + client.audio.transcriptions.create.assert_called_once_with( + file=BytesIO.return_value, model="whisper-1" + ) + + def test_recognize_pass_arguments(self, OpenAI, BytesIO, environ): + client = OpenAI.return_value + + recognizer = MagicMock(spec=Recognizer) + audio_data = MagicMock(spec=AudioData) + + actual = whisper.recognize_whisper_api( + recognizer, audio_data, model="x-whisper", api_key="OPENAI_API_KEY" + ) + + OpenAI.assert_called_once_with(api_key="OPENAI_API_KEY") + client.audio.transcriptions.create.assert_called_once_with( + file=BytesIO.return_value, model="x-whisper" + )