-
Notifications
You must be signed in to change notification settings - Fork 13
/
app.py
173 lines (145 loc) · 5.41 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import argparse
import html
import io
from typing import Any, Dict, List, Tuple, Union

import moviepy.editor as mp
import numpy as np
import streamlit as st
import torch
import torchaudio
import torchaudio.transforms as T
from scipy.io import wavfile
from streamlit_mic_recorder import mic_recorder
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
def parse_arguments() -> argparse.Namespace:
    """Read the options this app was launched with.

    Returns:
        The parsed namespace; its ``model_id`` attribute carries the
        required ``--model_id`` value.
    """
    cli = argparse.ArgumentParser(
        description="Streamlit app for speech transcription.",
    )
    cli.add_argument(
        "--model_id",
        help="Path to the model directory",
        required=True,
        type=str,
    )
    namespace = cli.parse_args()
    return namespace
# Load model and processor from the specified path
@st.cache_resource  # type: ignore
def load_model_and_processor(
    model_id: str,
) -> Tuple[AutoModelForSpeechSeq2Seq, AutoProcessor]:
    """Load the speech-to-text model and its processor from ``model_id``.

    The model is created with the module-level ``torch_dtype`` and moved to
    the module-level ``device``; both must be defined before the first call.

    Args:
        model_id: Identifier or path handed to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` pair.
    """
    load_options = {
        "torch_dtype": torch_dtype,
        "low_cpu_mem_usage": True,
        "use_safetensors": True,
    }
    speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, **load_options)
    speech_model.to(device)
    # NOTE(review): presumably tunes the smoothing applied when deriving
    # word-level timestamps -- confirm against the transformers docs.
    speech_model.generation_config.median_filter_width = 3

    speech_processor = AutoProcessor.from_pretrained(model_id)
    return speech_model, speech_processor
# Setup the pipeline
@st.cache_resource  # type: ignore
def setup_pipeline(
    _model: AutoModelForSpeechSeq2Seq, _processor: AutoProcessor
) -> AutomaticSpeechRecognitionPipeline:
    """Assemble the automatic-speech-recognition pipeline.

    The pipeline uses the tokenizer and feature extractor carried by
    ``_processor``, processes audio in 30-second chunks with batch size 1,
    returns timestamps, and runs with the module-level ``torch_dtype`` and
    ``device``.

    Args:
        _model: The loaded speech model.
        _processor: The processor matching ``_model``.

    Returns:
        The configured pipeline object.
    """
    pipeline_options = {
        "model": _model,
        "tokenizer": _processor.tokenizer,
        "feature_extractor": _processor.feature_extractor,
        "chunk_length_s": 30,
        "batch_size": 1,
        "return_timestamps": True,
        "torch_dtype": torch_dtype,
        "device": device,
    }
    asr_pipeline = pipeline("automatic-speech-recognition", **pipeline_options)
    return asr_pipeline
def wav_to_black_mp4(wav_path: str, output_path: str, fps: int = 25) -> None:
    """Convert WAV file to a black-screen MP4 with the same audio.

    The video is a 256x250 black frame whose duration equals the length of
    the audio (number of samples divided by the sample rate).

    Args:
        wav_path: Path of the WAV file to read.
        output_path: Path of the MP4 file to write.
        fps: Frame rate of the generated video.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    duration: float = waveform.shape[1] / sample_rate

    audio = mp.AudioFileClip(wav_path)
    black_clip = None
    try:
        black_clip = mp.ColorClip((256, 250), color=(0, 0, 0), duration=duration)
        final_clip = black_clip.set_audio(audio)
        final_clip.write_videofile(output_path, fps=fps)
    finally:
        # Release the clip resources even when encoding fails; this function
        # runs on every Streamlit rerun, so unclosed clips would pile up.
        if black_clip is not None:
            black_clip.close()
        audio.close()
def _format_vtt_timestamp(seconds: float) -> str:
    """Render ``seconds`` as a WebVTT ``hh:mm:ss.ttt`` timestamp."""
    # Work in whole milliseconds so that rounding can never yield "60.000"
    # in the seconds field (e.g. 59.9996 s becomes 00:01:00.000).
    total_millis = max(0, int(round(seconds * 1000)))
    total_seconds, millis = divmod(total_millis, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def timestamps_to_vtt(timestamps: List[Dict[str, Union[str, Any]]]) -> str:
    """Convert timestamps to VTT format.

    Args:
        timestamps: Chunks, each a mapping with a ``"timestamp"`` entry
            holding a ``(start, end)`` pair in seconds and a ``"text"``
            entry holding the cue text.

    Returns:
        A WebVTT document: the ``WEBVTT`` header followed by one cue per
        chunk. Chunks whose start or end time is ``None`` are skipped, since
        a cue cannot be written without both bounds.
    """
    cues: List[str] = []
    for word in timestamps:
        start_time, end_time = word["timestamp"]
        if start_time is None or end_time is None:
            continue
        start_time_str = _format_vtt_timestamp(start_time)
        end_time_str = _format_vtt_timestamp(end_time)
        cues.append(f"{start_time_str} --> {end_time_str}\n{word['text']}\n\n")
    return "WEBVTT\n\n" + "".join(cues)
def process_audio_bytes(audio_bytes: bytes) -> torch.Tensor:
    """Process audio bytes to the required format.

    The bytes are decoded as a WAV file, converted to float32, down-mixed to
    a single channel, standardised (zero mean, unit variance), scaled by 1/8
    and resampled to 16 kHz. The result is also written to ``sample.wav`` in
    the working directory.

    Args:
        audio_bytes: Contents of a WAV file.

    Returns:
        The processed waveform with a leading channel dimension of size 1.
    """
    audio_stream = io.BytesIO(audio_bytes)
    sr, y = wavfile.read(audio_stream)
    y = y.astype(np.float32)

    # wavfile.read yields a (samples, channels) array for multi-channel
    # files; average the channels so the resampler sees a 1-D signal.
    if y.ndim > 1:
        y = y.mean(axis=1)

    y_centered = y - np.mean(y)
    y_std = np.std(y)
    # A constant (e.g. silent) signal has zero spread; skip the division
    # instead of filling the waveform with NaN/inf.
    y_normalized = y_centered / y_std if y_std > 0 else y_centered

    transform = T.Resample(sr, 16000)
    waveform = transform(torch.unsqueeze(torch.tensor(y_normalized / 8), 0))
    torchaudio.save("sample.wav", waveform, sample_rate=16000)
    return waveform
def transcribe(audio_bytes: bytes) -> Dict[str, Any]:
    """Run the module-level ``pipe`` on raw audio with word-level timestamps.

    Args:
        audio_bytes: Audio file contents, forwarded to
            ``process_audio_bytes``.

    Returns:
        Whatever ``pipe`` returns for the first channel of the processed
        waveform when called with ``return_timestamps="word"``.
    """
    first_channel = process_audio_bytes(audio_bytes)[0, :]
    samples = first_channel.numpy()
    return pipe(samples, return_timestamps="word")
# Script body: Streamlit re-executes everything below on each interaction.
args = parse_arguments()
model_id = args.model_id

# Set up device and data type for processing
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model, processor = load_model_and_processor(model_id)
pipe = setup_pipeline(model, processor)

# Streamlit app interface
st.title("CrisperWhisper++ 🦻")
st.subheader("Caution when using. Make sure you can handle the crispness. ⚠️")
st.write("🎙️ Record an audio to transcribe or 📁 upload an audio file.")

# Audio recorder component
audio = mic_recorder(
    start_prompt="Start recording",
    stop_prompt="Stop recording",
    just_once=False,
    use_container_width=False,
    format="wav",
    callback=None,
    args=(),
    kwargs={},
    key=None,
)
audio_bytes: Union[bytes, None] = audio["bytes"] if audio else None

# Audio file upload handling (an uploaded file takes precedence over a recording)
# NOTE(review): process_audio_bytes decodes with scipy's WAV reader, so mp3/ogg
# uploads presumably end up in the error branch below -- confirm intended.
audio_file = st.file_uploader("Or upload an audio file", type=["wav", "mp3", "ogg"])
if audio_file is not None:
    audio_bytes = audio_file.getvalue()

if audio_bytes:
    try:
        transcription = transcribe(audio_bytes)
        vtt = timestamps_to_vtt(transcription["chunks"])
        # WebVTT files must be UTF-8; the platform default encoding would
        # corrupt non-ASCII transcripts on some systems.
        with open("subtitles.vtt", "w", encoding="utf-8") as vtt_file:
            vtt_file.write(vtt)
        # "sample.wav" is written by process_audio_bytes during transcribe().
        wav_to_black_mp4("sample.wav", "video.mp4")
        st.video("video.mp4", subtitles="subtitles.vtt")
        st.subheader("Transcription:")
        # The text is injected into raw HTML, so escape it first to keep any
        # markup-like characters in the transcript from being rendered.
        safe_text = html.escape(transcription["text"])
        st.markdown(
            f"""
<div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px;">
<p style="font-size: 16px; color: #333;">{safe_text}</p>
</div>
""",
            unsafe_allow_html=True,
        )
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # crashing the app.
        st.error(f"An error occurred during transcription: {e}")

# Footer
st.markdown(
    """
<hr>
<footer>
<p style="text-align: center;">© 2024 nyra health GmbH</p>
</footer>
""",
    unsafe_allow_html=True,
)