Skip to content

Commit

Permalink
Merge branch 'main' into 851-bugapi-cannot-index-files-to-vector-stor…
Browse files Browse the repository at this point in the history
…e-with-api-key-auth
  • Loading branch information
gphorvath authored Jul 31, 2024
2 parents 2d2fe6b + 4e8092a commit 65a3c40
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
26 changes: 24 additions & 2 deletions packages/whisper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,32 @@ def make_transcribe_request(filename, task, language, temperature, prompt):
device = "cuda" if GPU_ENABLED else "cpu"
model = WhisperModel(model_path, device=device, compute_type="float32")

segments, info = model.transcribe(filename, task=task, beam_size=5)
# Prepare kwargs with non-None values
kwargs = {}
if task:
if task in ["transcribe", "translate"]:
kwargs["task"] = task
else:
logger.error(f"Task {task} is not supported")
return {"text": ""}
if language:
if language in model.supported_languages:
kwargs["language"] = language
else:
logger.error(f"Language {language} is not supported")
if temperature:
kwargs["temperature"] = temperature
if prompt:
kwargs["initial_prompt"] = prompt

try:
# Call transcribe with only non-None parameters
segments, info = model.transcribe(filename, beam_size=5, **kwargs)
except Exception as e:
logger.error(f"Error transcribing audio: {e}")
return {"text": ""}

output = ""

for segment in segments:
output += segment.text

Expand Down
2 changes: 1 addition & 1 deletion packages/whisper/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ version = "0.9.2"
# x-release-please-end

dependencies = [
"faster-whisper == 0.10.0",
"faster-whisper == 1.0.3",
"leapfrogai-sdk",
]
requires-python = "~=3.11"
Expand Down
Empty file added tests/e2e/__init__.py
Empty file.
10 changes: 5 additions & 5 deletions tests/e2e/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_transcriptions():
timestamp_granularities=["word", "segment"],
)

assert len(transcription.text) > 0 # The transcription should not be empty
assert len(transcription.text) < 500 # The transcription should not be too long
assert len(transcription.text) > 0, "The transcription should not be empty"
assert len(transcription.text) < 500, "The transcription should not be too long"


def test_translations():
Expand All @@ -65,8 +65,8 @@ def test_translations():
temperature=0.3,
)

assert len(translation.text) > 0 # The translation should not be empty
assert len(translation.text) < 500 # The translation should not be too long
assert len(translation.text) > 0, "The translation should not be empty"
assert len(translation.text) < 500, "The translation should not be too long"

def is_english_or_punctuation(c):
if c in string.punctuation or c.isspace():
Expand All @@ -78,4 +78,4 @@ def is_english_or_punctuation(c):

english_chars = [is_english_or_punctuation(c) for c in translation.text]

assert all(english_chars) # Check that only English characters are returned
assert all(english_chars), "Non-English characters have been returned"

0 comments on commit 65a3c40

Please sign in to comment.