Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pure Python audio chat app with Multimodal Live API #1551

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions gemini/gradio-voice/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import asyncio
import base64
import json
import os
from threading import Event

import gradio as gr
import numpy as np
import websockets.sync.client
from dotenv import load_dotenv
from gradio_webrtc import StreamHandler, WebRTC

Check warning on line 11 in gemini/gradio-voice/app.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`webrtc` is not a recognized word. (unrecognized-spelling)

load_dotenv()


class GeminiConfig:
def __init__(self, api_key):
self.api_key = api_key
self.host = "generativelanguage.googleapis.com"
self.model = "models/gemini-2.0-flash-exp"
self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"

class AudioProcessor:
@staticmethod
def encode_audio(data, sample_rate):
encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
return {
"realtimeInput": {
"mediaChunks": [
{
"mimeType": f"audio/pcm;rate={sample_rate}",
"data": encoded,
}
],
},
}

@staticmethod
def process_audio_response(data):
audio_data = base64.b64decode(data)
return np.frombuffer(audio_data, dtype=np.int16)


class GeminiHandler(StreamHandler):
def __init__(
self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
) -> None:
super().__init__(
expected_layout,
output_sample_rate,
output_frame_size,
input_sample_rate=24000,
)
self.config = None
self.ws = None
self.all_output_data = None
self.audio_processor = AudioProcessor()
self.args_set = Event()

def copy(self):
return GeminiHandler(
expected_layout=self.expected_layout,
output_sample_rate=self.output_sample_rate,
output_frame_size=self.output_frame_size,
)

def _initialize_websocket(self):
assert self.config, "Config not set"
try:
self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
initial_request = {
"setup": {
"model": self.config.model,
}
}
self.ws.send(json.dumps(initial_request))
setup_response = json.loads(self.ws.recv())
print(f"Setup response: {setup_response}")
except websockets.exceptions.WebSocketException as e:
print(f"WebSocket connection failed: {str(e)}")
self.ws = None
except Exception as e:
print(f"Setup failed: {str(e)}")
self.ws = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error handling in _initialize_websocket could be improved. Currently, both WebSocketException and generic exceptions are caught and printed, but the function doesn't return any indication of failure. This can make it difficult for calling functions to handle connection errors appropriately. Consider raising the caught exceptions after printing the error message, or returning an error status. Additionally, it's a good practice to log the exception details for debugging purposes. How would you modify the code to propagate or handle these errors more effectively?

Suggested change
assert self.config, "Config not set"
try:
self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
initial_request = {
"setup": {
"model": self.config.model,
}
}
self.ws.send(json.dumps(initial_request))
setup_response = json.loads(self.ws.recv())
print(f"Setup response: {setup_response}")
except websockets.exceptions.WebSocketException as e:
print(f"WebSocket connection failed: {str(e)}")
self.ws = None
except Exception as e:
print(f"Setup failed: {str(e)}")
self.ws = None
def _initialize_websocket(self):
assert self.config, "Config not set"
try:
self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
initial_request = {
"setup": {
"model": self.config.model,
}
}
self.ws.send(json.dumps(initial_request))
setup_response = json.loads(self.ws.recv())
print(f"Setup response: {setup_response}")
return setup_response
except websockets.exceptions.WebSocketException as e:
print(f"WebSocket connection failed: {str(e)}")
raise
except Exception as e:
print(f"Setup failed: {str(e)}")
raise


async def fetch_args(
self,
):
if self.channel:
self.channel.send("tick")

def set_args(self, args):
super().set_args(args)
self.args_set.set()

def receive(self, frame: tuple[int, np.ndarray]) -> None:
if not self.channel:
return
if not self.config:
asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
self.args_set.wait()
print("api_key", self.latest_args[-1])
self.config = GeminiConfig(self.latest_args[-1])
try:
if not self.ws:
self._initialize_websocket()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _initialize_websocket function now returns the setup response, or raises an exception if the connection or setup fails. It's important to handle potential errors when calling this function. How would you handle a ConnectionError or ValueError raised by _initialize_websocket?


_, array = frame
array = array.squeeze()
audio_message = self.audio_processor.encode_audio(
array, self.output_sample_rate
)
self.ws.send(json.dumps(audio_message))
except Exception as e:
print(f"Error in receive: {str(e)}")
if self.ws:
self.ws.close()
self.ws = None

def _process_server_content(self, content):
for part in content.get("parts", []):
data = part.get("inlineData", {}).get("data", "")
if data:
audio_array = self.audio_processor.process_audio_response(data)
if self.all_output_data is None:
self.all_output_data = audio_array
else:
self.all_output_data = np.concatenate(
(self.all_output_data, audio_array)
)

while self.all_output_data.shape[-1] >= self.output_frame_size:
yield (
self.output_sample_rate,
self.all_output_data[: self.output_frame_size].reshape(1, -1),
)
self.all_output_data = self.all_output_data[
self.output_frame_size :
]

def generator(self):
while True:
if not self.ws or not self.config:
print("WebSocket not connected")
yield None
continue

try:
message = self.ws.recv(timeout=5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ws.recv() call could potentially block indefinitely if the server doesn't send a message. Consider adding a timeout to prevent this. How would you handle a timeout error?

msg = json.loads(message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The json.loads() function can raise a JSONDecodeError if the received message is not valid JSON. It's good practice to handle this exception to prevent unexpected crashes. How would you handle a JSONDecodeError?


if "serverContent" in msg:
content = msg["serverContent"].get("modelTurn", {})
yield from self._process_server_content(content)
except TimeoutError:
print("Timeout waiting for server response")
yield None
except Exception as e:
print(f"Error in generator: {str(e)}")
yield None

def emit(self) -> tuple[int, np.ndarray] | None:
if not self.ws:
return None
if not hasattr(self, "_generator"):
self._generator = self.generator()
try:
return next(self._generator)
except StopIteration:
self.reset()
return None

def reset(self) -> None:
if hasattr(self, "_generator"):
delattr(self, "_generator")
self.all_output_data = None

def shutdown(self) -> None:
if self.ws:
self.ws.close()

def check_connection(self):
try:
if not self.ws or self.ws.closed:
self._initialize_websocket()
return True
except Exception as e:
print(f"Connection check failed: {str(e)}")
return False


class GeminiVoiceChat:
def __init__(self):
self.demo = self._create_interface()

def _create_interface(self):
with gr.Blocks() as demo:
gr.HTML("""
<div style='text-align: center'>
<h1>Gemini 2.0 Voice Chat</h1>
<p>Speak with Gemini using real-time audio streaming</p>
<p>Get a Gemini API key from <a href="https://ai.google.dev/gemini-api/docs/api-key">Google</a></p>
</div>
""")

with gr.Row(visible=True) as api_key_row:
api_key = gr.Textbox(
label="Gemini API Key",
placeholder="Enter your Gemini API Key",
type="password",
)
with gr.Row(visible=False) as row:
webrtc = WebRTC(

Check warning on line 213 in gemini/gradio-voice/app.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`webrtc` is not a recognized word. (unrecognized-spelling)
label="Conversation",
modality="audio",
mode="send-receive",
# See for changes needed to deploy behind a firewall
# https://freddyaboulton.github.io/gradio-webrtc/deployment/
rtc_configuration=None,
)

webrtc.stream(

Check warning on line 222 in gemini/gradio-voice/app.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`webrtc` is not a recognized word. (unrecognized-spelling)
GeminiHandler(),
inputs=[webrtc, api_key],

Check warning on line 224 in gemini/gradio-voice/app.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`webrtc` is not a recognized word. (unrecognized-spelling)
outputs=[webrtc],

Check warning on line 225 in gemini/gradio-voice/app.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`webrtc` is not a recognized word. (unrecognized-spelling)
time_limit=90,
concurrency_limit=2,
)
api_key.submit(
lambda: (gr.update(visible=False), gr.update(visible=True)),
None,
[api_key_row, row],
)
return demo

def launch(self):
self.demo.launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
ssl_verify=False,
ssl_keyfile=None,
ssl_certfile=None,
)


if __name__ == "__main__":
app = GeminiVoiceChat()
app.launch()
3 changes: 3 additions & 0 deletions gemini/gradio-voice/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gradio_webrtc==0.0.23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Pinning the gradio_webrtc version is a good practice for reproducibility. However, it's generally recommended to specify a version range rather than a single version to allow for bug fixes and minor updates. Consider using a compatible version range like gradio-webrtc>=0.0.23,<0.1.0 to allow for updates while avoiding potentially breaking changes.

Suggested change
gradio_webrtc==0.0.23
gradio_webrtc>=0.0.23,<0.1.0

librosa
python-dotenv
Loading