add gemini support
SmartManoj committed Dec 7, 2024
1 parent 19cb66a commit 98c1872
Showing 7 changed files with 103 additions and 16 deletions.
32 changes: 25 additions & 7 deletions run_voice_assistant.py
@@ -1,6 +1,7 @@
# voice_assistant/main.py

import logging
import os
import time
from colorama import Fore, init
from voice_assistant.audio import record_audio, play_audio
@@ -18,22 +19,39 @@
init(autoreset=True)

import threading


# os.environ["GEMINI_MODEL"] = "gemini/gemini-exp-1206"
# os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-pro-002"
os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-flash-002"
import winsound

def generate_beep():
    # tick sound
    winsound.Beep(200, 500)

system_content = '''
You are the best true astrologer in the world.
You will give responses with sources.
You can explain everything without hiding anything.
Respond to people in a friendly manner.
You know everything.
Speak in colloquial கொங்கு தமிழ் (Kongu Tamil).
Your answers are short and concise.
'''
def main():
    """
    Main function to run the voice assistant.
    """
    chat_history = [
-        {"role": "system", "content": """ You are a helpful Assistant called Verbi.
-        You are friendly and fun and you will help the users with their requests.
-        Your answers are short and concise. """}
+        {"role": "system", "content": system_content
+        }
    ]

    while True:
        try:
            # Record audio from the microphone and save it as 'test.wav'
            record_audio(Config.INPUT_AUDIO)
            generate_beep()

            # Get the API key for transcription
            transcription_api_key = get_transcription_api_key()
@@ -42,7 +60,7 @@ def main():
            user_input = transcribe_audio(Config.TRANSCRIPTION_MODEL, transcription_api_key, Config.INPUT_AUDIO, Config.LOCAL_MODEL_PATH)

            # Check if the transcription is empty and restart the recording if it is. This check will avoid empty requests if vad_filter is used in the fastwhisperapi.
-            if not user_input:
+            if not user_input or "<empty>" == user_input.lower().strip():
                logging.info("No transcription was returned. Starting recording again.")
                continue
            logging.info(Fore.GREEN + "You said: " + user_input + Fore.RESET)
@@ -88,7 +106,7 @@ def main():

        except Exception as e:
            logging.error(Fore.RED + f"An error occurred: {e}" + Fore.RESET)
-            delete_file(Config.INPUT_AUDIO)
+            # delete_file(Config.INPUT_AUDIO)
            if 'output_file' in locals():
                delete_file(output_file)
            time.sleep(1)
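Note: winsound is a Windows-only standard-library module, so this script (and voice_assistant/audio.py, which defines the same helper) will not import on macOS or Linux. A minimal cross-platform sketch of the same tick-sound idea; the non-Windows fallback is an assumption, not part of this commit:

import sys

def generate_beep():
    """Play a short tick to signal that recording has started."""
    if sys.platform == "win32":
        import winsound
        winsound.Beep(200, 500)  # 200 Hz tone for 500 ms, as in the commit
    else:
        # Assumed fallback: ASCII BEL, which many terminals map to a system sound.
        print("\a", end="", flush=True)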
3 changes: 2 additions & 1 deletion voice_assistant/api_key_manager.py
@@ -6,7 +6,8 @@
    "transcription":{
        "openai": Config.OPENAI_API_KEY,
        "groq": Config.GROQ_API_KEY,
-        "deepgram": Config.DEEPGRAM_API_KEY
+        "deepgram": Config.DEEPGRAM_API_KEY,
+        "gemini": Config.GEMINI_API_KEY,
    },
    "response":{
        "openai":Config.OPENAI_API_KEY,
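The accessor that consumes this mapping sits outside the visible hunk. A sketch of how get_transcription_api_key (the function run_voice_assistant.py imports) presumably resolves the key; the API_KEYS name and dictionary shape are assumptions based on the hunk above:

from voice_assistant.config import Config

# Assumed shape of the mapping shown in the hunk above.
API_KEYS = {
    "transcription": {
        "groq": Config.GROQ_API_KEY,
        "gemini": Config.GEMINI_API_KEY,
    },
}

def get_transcription_api_key():
    return API_KEYS["transcription"].get(Config.TRANSCRIPTION_MODEL)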
4 changes: 4 additions & 0 deletions voice_assistant/audio.py
@@ -8,6 +8,9 @@
from io import BytesIO
from pydub import AudioSegment
from functools import lru_cache
import winsound
def generate_beep():
    winsound.Beep(200, 500)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -48,6 +51,7 @@ def record_audio(file_path, timeout=10, phrase_time_limit=None, retries=3, energ
            logging.info("Calibrating for ambient noise...")
            recognizer.adjust_for_ambient_noise(source, duration=calibration_duration)
            logging.info("Recording started")
            generate_beep()
            # Listen for the first phrase and extract it into audio data
            audio_data = recognizer.listen(source, timeout=timeout, phrase_time_limit=phrase_time_limit)
            logging.info("Recording complete")
7 changes: 4 additions & 3 deletions voice_assistant/config.py
@@ -21,9 +21,9 @@ class Config:
        LOCAL_MODEL_PATH (str): Path to the local model.
    """
    # Model selection
-    TRANSCRIPTION_MODEL = 'deepgram' # possible values: openai, groq, deepgram, fastwhisperapi
-    RESPONSE_MODEL = 'openai' # possible values: openai, groq, ollama
-    TTS_MODEL = 'openai' # possible values: openai, deepgram, elevenlabs, melotts, cartesia
+    TRANSCRIPTION_MODEL = 'groq' # possible values: openai, groq, deepgram, fastwhisperapi, gemini
+    RESPONSE_MODEL = 'gemini' # possible values: openai, groq, ollama, gemini
+    TTS_MODEL = 'elevenlabs' # possible values: openai, deepgram, elevenlabs, melotts, cartesia

    # currently using MeloTTS for local models; here is how to get started:
    # https://github.com/myshell-ai/MeloTTS/blob/main/docs/install.md#linux-and-macos-install
@@ -38,6 +38,7 @@ class Config:
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")
    DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY")
    ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
    LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH")
    CARTESIA_API_KEY = os.getenv("CARTESIA_API_KEY")

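With the new defaults (groq transcription, gemini responses, elevenlabs TTS), the environment needs at least these variables; the repo loads them via python-dotenv, as the __main__ block in transcription.py below suggests. The key names come straight from this class, the values are placeholders:

GROQ_API_KEY=your_groq_key
GEMINI_API_KEY=your_google_ai_studio_key
ELEVENLABS_API_KEY=your_elevenlabs_key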
22 changes: 21 additions & 1 deletion voice_assistant/response_generation.py
@@ -1,7 +1,9 @@
# voice_assistant/response_generation.py

import logging
import os

import litellm
from openai import OpenAI
from groq import Groq
import ollama
@@ -25,6 +27,8 @@ def generate_response(model:str, api_key:str, chat_history:list, local_model_pat
    try:
        if model == 'openai':
            return _generate_openai_response(api_key, chat_history)
        elif model == 'gemini':
            return _generate_gemini_response(chat_history)
        elif model == 'groq':
            return _generate_groq_response(api_key, chat_history)
        elif model == 'ollama':
@@ -61,4 +65,20 @@ def _generate_ollama_response(chat_history):
        model=Config.OLLAMA_LLM,
        messages=chat_history,
    )
-    return response['message']['content']
\ No newline at end of file
+    return response['message']['content']


def _generate_gemini_response(chat_history):
    model = os.environ["GEMINI_MODEL"]

    response = litellm.completion(
        model=model,
        messages=chat_history,
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    msg = 'நாபத்திரண்டும் எம்பத்திரண்டும் எத்தனை?'  # Tamil: "How much are forty-two and eighty-two?"
    os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-flash-002"
    print(_generate_gemini_response([{"role": "user", "content": msg}]))
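A usage sketch of the new branch through the public dispatcher. The chat content is illustrative, the local_model_path keyword is inferred from the truncated signature above, and litellm is assumed to read GEMINI_API_KEY from the environment (the Gemini branch never touches the api_key argument):

import os
from voice_assistant.response_generation import generate_response

os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-flash-002"
reply = generate_response(
    model="gemini",
    api_key=None,  # unused by the Gemini branch
    chat_history=[{"role": "user", "content": "Hello!"}],
    local_model_path=None,
)
print(reply)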
4 changes: 2 additions & 2 deletions voice_assistant/text_to_speech.py
@@ -51,9 +51,9 @@ def text_to_speech(model: str, api_key:str, text:str, output_file_path:str, loca
        client = ElevenLabs(api_key=api_key)
        audio = client.generate(
            text=text,
-            voice="Paul J.",
+            voice="Nila - Warm & Expressive Tamil Voice",
            output_format="mp3_22050_32",
-            model="eleven_turbo_v2"
+            model="eleven_turbo_v2_5"
        )
        elevenlabs.save(audio, output_file_path)

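The voice and model switch presumably serves the Tamil pipeline: eleven_turbo_v2 is English-only, while eleven_turbo_v2_5 handles multilingual text, and the new voice name points to a Tamil voice from the ElevenLabs voice library. A standalone sketch of the same call, with a placeholder key and sample text (the import path is assumed from the client usage above):

import elevenlabs
from elevenlabs.client import ElevenLabs

client = ElevenLabs(api_key="YOUR_ELEVENLABS_API_KEY")  # placeholder
audio = client.generate(
    text="வணக்கம்!",  # "Hello!" in Tamil
    voice="Nila - Warm & Expressive Tamil Voice",
    output_format="mp3_22050_32",
    model="eleven_turbo_v2_5",
)
elevenlabs.save(audio, "hello.mp3")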
47 changes: 45 additions & 2 deletions voice_assistant/transcription.py
@@ -1,5 +1,7 @@
# voice_assistant/transcription.py

import os
import traceback
import json
import logging
import requests
@@ -46,6 +48,8 @@ def transcribe_audio(model, api_key, audio_file_path, local_model_path=None):
            return _transcribe_with_groq(api_key, audio_file_path)
        elif model == 'deepgram':
            return _transcribe_with_deepgram(api_key, audio_file_path)
        elif model == 'gemini':
            return _transcribe_with_gemini(api_key, audio_file_path)
        elif model == 'fastwhisperapi':
            return _transcribe_with_fastwhisperapi(audio_file_path)
        elif model == 'local':
@@ -55,6 +59,7 @@ def transcribe_audio(model, api_key, audio_file_path, local_model_path=None):
            raise ValueError("Unsupported transcription model")
    except Exception as e:
        logging.error(f"{Fore.RED}Failed to transcribe audio: {e}{Fore.RESET}")
        traceback.print_exc()
        raise Exception("Error in transcribing audio")

def _transcribe_with_openai(api_key, audio_file_path):
@@ -74,7 +79,8 @@ def _transcribe_with_groq(api_key, audio_file_path):
        transcription = client.audio.transcriptions.create(
            model="whisper-large-v3",
            file=audio_file,
-            language='en'
+            # language='en'
+            language='ta'
        )
    return transcription.text

@@ -112,4 +118,41 @@ def _transcribe_with_fastwhisperapi(audio_file_path):

    response = requests.post(endpoint, files=files, data=data, headers=headers)
    response_json = response.json()
-    return response_json.get('text', 'No text found in the response.')
\ No newline at end of file
+    return response_json.get('text', 'No text found in the response.')


from pathlib import Path
import base64
import litellm


def _transcribe_with_gemini(api_key, audio_file_path):
    model = os.environ["GEMINI_MODEL"]
    audio_bytes = Path(audio_file_path).read_bytes()
    encoded_data = base64.b64encode(audio_bytes).decode("utf-8")

    response = litellm.completion(
        model=model,
        api_key=api_key,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Just transcribe the Tamil audio. If no audio is detected, just say '<empty>'."},
                    {
                        "type": "image_url",
                        "image_url": "data:audio/mp3;base64,{}".format(encoded_data),  # 👈 SET MIME_TYPE + DATA
                    },
                ],
            }
        ],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    from dotenv import load_dotenv
    load_dotenv()
    os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-flash-002"
    # print(_transcribe_with_gemini(os.environ["GEMINI_API_KEY"], "test_cropped.mp3"))
    print(_transcribe_with_groq(os.environ["GROQ_API_KEY"], "test_cropped.mp3"))
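End to end, the new path runs through transcribe_audio. One caveat worth flagging: _transcribe_with_gemini hardcodes the data:audio/mp3 MIME prefix even though the assistant records WAV (the main loop saves Config.INPUT_AUDIO as 'test.wav'); if Gemini rejects the mismatch, deriving the MIME type from the file extension would fix it. A minimal usage sketch, assuming GEMINI_API_KEY is set and the audio file exists (the path is a placeholder):

import os
from voice_assistant.transcription import transcribe_audio

os.environ["GEMINI_MODEL"] = "gemini/gemini-1.5-flash-002"
text = transcribe_audio("gemini", os.environ["GEMINI_API_KEY"], "test.wav", None)
print(text)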
