main.py

import gc
import time
import torch
import pyaudio
import multiprocessing
import sounddevice as sd
from preprocessor import Preprocessor
from streaming_buffer import StreamBuffer
from stt_llm_tts_model import STT_LLM_TTS
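
# Third-party dependencies: torch, pyaudio and sounddevice (e.g. installed via
# `pip install torch pyaudio sounddevice`); preprocessor, streaming_buffer and
# stt_llm_tts_model are presumably the local modules of this repository.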


def record(audio_buffer, start_recording):
    """Record an audio stream from the microphone in a separate process.

    Args:
        audio_buffer: multiprocessing queue to store the recorded audio data
        start_recording: multiprocessing value to start and stop the recording
    """
    RATE = 16000
    CHUNK = 2048
    # open the audio input stream
    audio = pyaudio.PyAudio()
    streamIn = audio.open(format=pyaudio.paFloat32, channels=1,
                          rate=RATE, input=True, input_device_index=0,
                          frames_per_buffer=CHUNK)
    while True:
        try:
            # start_recording is set to 1 in the main loop to start the recording
            if start_recording.value == 0:
                time.sleep(0.1)
                continue
            # read a chunk of fixed size from the input stream and add it to the input buffer
            data = streamIn.read(CHUNK, exception_on_overflow=False)
            audio_buffer.put(data)
        except KeyboardInterrupt:
            return
        except Exception as e:
            raise e
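

# With RATE = 16000 and CHUNK = 2048, each item put on the queue above covers
# 2048 / 16000 = 0.128 s of audio. record() also hard-codes input_device_index=0;
# the helper below is a hypothetical addition (not used by the pipeline) that
# lists the available input devices so that index can be chosen explicitly.
def list_input_devices():
    """Print the index, name and channel count of every audio input device."""
    p = pyaudio.PyAudio()
    try:
        for i in range(p.get_device_count()):
            info = p.get_device_info_by_index(i)
            if info.get("maxInputChannels", 0) > 0:
                print(i, info.get("name"), info.get("maxInputChannels"))
    finally:
        p.terminate()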


def play_audio(audio_output_buffer):
    """Play synthesized audio data in a separate process.

    Args:
        audio_output_buffer: multiprocessing queue to receive audio data
    """
    fs = 24000
    while True:
        # get the next audio data
        wav = audio_output_buffer.get()
        # play the audio and wait until it is finished (only this subprocess is blocked, not the main loop)
        sd.play(wav, fs, blocking=True)
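

# A minimal sketch (not wired into main()) of the interrupt behavior mentioned in
# the TODO inside main_loop() below: a hypothetical multiprocessing.Event, set by
# the main loop when the model reports an interrupt, could stop the current
# playback via sounddevice's stop().
def play_audio_interruptible(audio_output_buffer, stop_event):
    """Hypothetical variant of play_audio that can be interrupted mid-playback."""
    fs = 24000
    while True:
        wav = audio_output_buffer.get()
        sd.play(wav, fs, blocking=False)
        # poll until the playback has finished or an interrupt was requested
        while sd.get_stream().active:
            if stop_event.is_set():
                sd.stop()
                stop_event.clear()
                break
            time.sleep(0.01)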


def flush():
    """Flush the CUDA cache to prevent side effects and slowdowns."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def main_loop(streaming_buffer, model, audio_input_buffer, audio_output_buffer, start_recording):
    """Wait for audio input, call the voice assistant model and play the synthesized speech.

    Args:
        streaming_buffer: streaming buffer instance to store preprocessed audio chunks
        model: instance of the STT_LLM_TTS model
        audio_input_buffer: multiprocessing queue for audio input
        audio_output_buffer: multiprocessing queue for audio output
        start_recording: multiprocessing value to start the recording of audio chunks
    """
    # init the preprocessor and the streaming iterator
    preprocessor = Preprocessor()
    streaming_buffer_iter = iter(streaming_buffer)
    # send a signal to the recording process to start the recording
    start_recording.value = 1
    # controls the buffer stream id for the first chunk
    first = True
    # start the main loop
    while True:
        # get as many audio chunks from the buffer as possible; if the buffer is empty,
        # an exception is thrown and the inner loop breaks
        while True:
            # use stream id -1 for the first chunk, 0 otherwise
            if first:
                stream_id = -1
                first = False
            else:
                stream_id = 0
            # try to get the next audio chunk; if the buffer is empty an exception is thrown
            try:
                # get audio data from the buffer
                data = audio_input_buffer.get(block=False)
                # preprocess the audio data
                t = torch.frombuffer(data, dtype=torch.float32)
                t = torch.unsqueeze(t, 0)
                length = torch.tensor([t.shape[1]], dtype=torch.float32)
                processed_signal, _ = preprocessor(t, length)
                # add the processed audio chunks to the streaming buffer
                streaming_buffer.append_processed_signal(processed_signal, stream_id=stream_id)
            except Exception:
                # leave the inner loop and process the received data
                break
        # check if enough audio chunks were recorded for a forward pass
        if streaming_buffer.buffer is not None and streaming_buffer.buffer.size(-1) > streaming_buffer.buffer_idx + streaming_buffer.shift_size:
            # --> enough chunks are available
            # get preprocessed audio chunks from the buffer
            data = next(streaming_buffer_iter, None)
            if data is None:
                break
            chunk_audio, chunk_lengths = data
            # call the model and pass the preprocessed audio data
            chunk_audio = chunk_audio.to("cuda")
            chunk_lengths = chunk_lengths.to("cuda")
            text, wav, interrupt = model(chunk_audio, chunk_lengths)
        else:
            # --> not enough chunks; call the model with empty input to generate text
            text, wav, interrupt = model(None, None)
        # TODO: Implement interrupt behavior to stop the audio process when the user starts speaking
        # the model returns None except when a new sentence has been generated and synthesized
        if text is not None:
            # --> a new sentence is finished
            print(text.replace("\n", ""))
            # put the synthesized audio into the output buffer, which is played by the play-audio process
            audio_output_buffer.put(wav)
        time.sleep(0.001)  # TODO: Is this really needed?


def main():
    """Start the processes for recording and audio output, initialize the voice
    assistant model and start the main loop.
    """
    # !! Make sure to start multiprocessing before using any PyTorch tensors to prevent GPU memory problems !!
    # start the subprocess for sound input
    audio_buffer = multiprocessing.Queue()
    start_recording = multiprocessing.Value('i', 0)
    record_process = multiprocessing.Process(target=record, args=(audio_buffer, start_recording))
    record_process.start()
    # start the subprocess for sound output
    audio_output_buffer = multiprocessing.Queue()
    play_audio_process = multiprocessing.Process(target=play_audio, args=(audio_output_buffer,))
    play_audio_process.start()
    # initialize the buffer for processed audio input
    streaming_buffer = StreamBuffer(chunk_size=16, shift_size=16)
    # get the device
    if torch.cuda.is_available():
        device = 'cuda'
        # flush GPU memory
        flush()
    else:
        device = 'cpu'
    # init the STT-LLM-TTS pipeline
    model = STT_LLM_TTS(device=device)
    # start inference
    main_loop(streaming_buffer, model, audio_buffer, audio_output_buffer, start_recording)


if __name__ == "__main__":
    main()
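
# Usage notes (assumed setup): run with `python main.py`; a microphone is required
# and a CUDA GPU is used when available, otherwise the CPU (note that main_loop()
# still moves chunks to "cuda" unconditionally). The helper processes are never
# joined; a possible shutdown path would wrap the main_loop() call in try/finally
# and terminate record_process and play_audio_process there.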