From 8372d0be5a649f06e3db4cc916d9a7582918665a Mon Sep 17 00:00:00 2001
From: Sumit Paul
Date: Wed, 28 Aug 2024 00:11:09 +0530
Subject: [PATCH 1/2] add gemini support

---
 README.md                              |  7 +++++--
 example.env                            |  1 +
 requirements.txt                       |  1 +
 voice_assistant/api_key_manager.py     |  2 ++
 voice_assistant/config.py              | 14 ++++++++-----
 voice_assistant/response_generation.py | 27 ++++++++++++++++++++++++++
 6 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index bf834a5..d6093c0 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ Welcome to the Voice Assistant project! 🎙️ Our goal is to create a modular
 ## Features 🧰
 
 - **Modular Design**: Easily switch between different models for transcription, response generation, and TTS.
-- **Support for Multiple APIs**: Integrates with OpenAI, Groq, and Deepgram APIs, along with placeholders for local models.
+- **Support for Multiple APIs**: Integrates with OpenAI, Groq, Gemini, and Deepgram APIs, along with placeholders for local models.
 - **Audio Recording and Playback**: Record audio from the microphone and play generated speech.
 - **Configuration Management**: Centralized configuration in `config.py` for easy setup and management.
 
@@ -79,6 +79,7 @@ Create a `.env` file in the root directory and add your API keys:
    ```shell
    OPENAI_API_KEY=your_openai_api_key
    GROQ_API_KEY=your_groq_api_key
+   GEMINI_API_KEY=your_gemini_api_key
    DEEPGRAM_API_KEY=your_deepgram_api_key
    LOCAL_MODEL_PATH=path/to/local/model
    ```
@@ -90,12 +91,13 @@ Edit config.py to select the models you want to use:
 class Config:
     # Model selection
     TRANSCRIPTION_MODEL = 'groq'  # Options: 'openai', 'groq', 'deepgram', 'fastwhisperapi' 'local'
-    RESPONSE_MODEL = 'groq'  # Options: 'openai', 'groq', 'ollama', 'local'
+    RESPONSE_MODEL = 'groq'  # Options: 'openai', 'groq', 'ollama', 'gemini', 'local'
     TTS_MODEL = 'deepgram'  # Options: 'openai', 'deepgram', 'elevenlabs', 'local', 'melotts'
 
     # API keys and paths
     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
     GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
     DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY")
     LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH")
 ```
@@ -180,6 +182,7 @@ If you are running LLM locally via [Ollama](https://ollama.com/), make sure the
 
 - **OpenAI**: Uses OpenAI's GPT-4 model.
 - **Groq**: Uses Groq's LLaMA model.
+- **Gemini**: Uses Gemini's 1.5 Flash model.
 - **Ollama**: Uses any model served via Ollama.
 - **Local**: Placeholder for a local language model.
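A quick way to verify the setup described in the README changes above is a standalone smoke test. The following is a sketch, not part of the patch: it assumes `python-dotenv` is installed and `.env` contains a valid `GEMINI_API_KEY`, and it uses the same model name the patch adds as `GEMINI_LLM` in `config.py` below.

```python
# Smoke test for the Gemini setup described above (sketch, not part of
# the patch). Assumes python-dotenv is installed and .env holds GEMINI_API_KEY.
import os

import google.generativeai as genai
from dotenv import load_dotenv

load_dotenv()
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash")  # same value as Config.GEMINI_LLM
print(model.generate_content("Reply with one word: ready").text)
```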
diff --git a/example.env b/example.env
index 819de71..d61e862 100644
--- a/example.env
+++ b/example.env
@@ -1,4 +1,5 @@
 OPENAI_API_KEY="OPENAI_API_KEY"
+GEMINI_API_KEY="GEMINI_API_KEY"
 GROQ_API_KEY="GROQ_API_KEY"
 DEEPGRAM_API_KEY="DEEPGRAM_API_KEY"
 ELEVENLABS_API_KEY="ELEVENLABS_API_KEY"

diff --git a/requirements.txt b/requirements.txt
index d533fbb..aa428a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,7 @@ SpeechRecognition==3.10.4
 tqdm==4.66.4
 typing_extensions==4.11.0
 urllib3==2.2.1
+google-generativeai  # provides the google.generativeai module imported in response_generation.py
 colorama
 deepgram-sdk
 groq

diff --git a/voice_assistant/api_key_manager.py b/voice_assistant/api_key_manager.py
index 176ccfd..faa3269 100644
--- a/voice_assistant/api_key_manager.py
+++ b/voice_assistant/api_key_manager.py
@@ -28,6 +28,8 @@ def get_response_api_key():
         return Config.OPENAI_API_KEY
     elif Config.RESPONSE_MODEL == 'groq':
         return Config.GROQ_API_KEY
+    elif Config.RESPONSE_MODEL == 'gemini':
+        return Config.GEMINI_API_KEY
     return None
 
 def get_tts_api_key():

diff --git a/voice_assistant/config.py b/voice_assistant/config.py
index 3ae4d9a..c4ab363 100644
--- a/voice_assistant/config.py
+++ b/voice_assistant/config.py
@@ -12,17 +12,18 @@ class Config:
     Attributes:
         TRANSCRIPTION_MODEL (str): The model to use for transcription ('openai', 'groq', 'deepgram', 'fastwhisperapi', 'local').
-        RESPONSE_MODEL (str): The model to use for response generation ('openai', 'groq', 'local').
+        RESPONSE_MODEL (str): The model to use for response generation ('openai', 'groq', 'gemini', 'local').
         TTS_MODEL (str): The model to use for text-to-speech ('openai', 'deepgram', 'elevenlabs', 'local').
         OPENAI_API_KEY (str): API key for OpenAI services.
         GROQ_API_KEY (str): API key for Groq services.
+        GEMINI_API_KEY (str): API key for Gemini services.
         DEEPGRAM_API_KEY (str): API key for Deepgram services.
         ELEVENLABS_API_KEY (str): API key for ElevenLabs services.
         LOCAL_MODEL_PATH (str): Path to the local model.
     """
     # Model selection
     TRANSCRIPTION_MODEL = 'openai'      # possible values: openai, groq, deepgram, fastwhisperapi
-    RESPONSE_MODEL = 'openai'           # possible values: openai, groq, ollama
+    RESPONSE_MODEL = 'openai'           # possible values: openai, groq, gemini, ollama
     TTS_MODEL = 'elevenlabs'            # possible values: openai, deepgram, elevenlabs, melotts, cartesia
 
     # currently using the MeloTTS for local models. here is how to get started:
@@ -32,10 +33,12 @@ class Config:
     OLLAMA_LLM="llama3:8b"
     GROQ_LLM="llama3-8b-8192"
     OPENAI_LLM="gpt-4o"
+    GEMINI_LLM="gemini-1.5-flash"
 
     # API keys and paths
     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
     GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
     DEEPGRAM_API_KEY = os.getenv("DEEPGRAM_API_KEY")
     ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
     LOCAL_MODEL_PATH = os.getenv("LOCAL_MODEL_PATH")
@@ -57,8 +60,8 @@ def validate_config():
     """
     if Config.TRANSCRIPTION_MODEL not in ['openai', 'groq', 'deepgram', 'fastwhisperapi', 'local']:
         raise ValueError("Invalid TRANSCRIPTION_MODEL. Must be one of ['openai', 'groq', 'deepgram', 'fastwhisperapi', 'local']")
-    if Config.RESPONSE_MODEL not in ['openai', 'groq', 'ollama', 'local']:
-        raise ValueError("Invalid RESPONSE_MODEL. Must be one of ['openai', 'groq', 'local']")
+    if Config.RESPONSE_MODEL not in ['openai', 'groq', 'ollama', 'gemini', 'local']:
+        raise ValueError("Invalid RESPONSE_MODEL. Must be one of ['openai', 'groq', 'ollama', 'gemini', 'local']")
     if Config.TTS_MODEL not in ['openai', 'deepgram', 'elevenlabs', 'melotts', 'cartesia', 'local']:
         raise ValueError("Invalid TTS_MODEL. Must be one of ['openai', 'deepgram', 'elevenlabs', 'melotts', 'cartesia', 'local']")
@@ -73,7 +76,8 @@ def validate_config():
         raise ValueError("OPENAI_API_KEY is required for OpenAI models")
     if Config.RESPONSE_MODEL == 'groq' and not Config.GROQ_API_KEY:
         raise ValueError("GROQ_API_KEY is required for Groq models")
-
+    if Config.RESPONSE_MODEL == 'gemini' and not Config.GEMINI_API_KEY:
+        raise ValueError("GEMINI_API_KEY is required for Gemini models")
     if Config.TTS_MODEL == 'openai' and not Config.OPENAI_API_KEY:
         raise ValueError("OPENAI_API_KEY is required for OpenAI models")
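With the key-manager and config hunks in place, the new `'gemini'` option resolves its key the same way as the existing options. A hypothetical check, for illustration only:

```python
# Illustration only: how the new 'gemini' option resolves its API key
# (assumes the environment variables above are loaded).
from voice_assistant.api_key_manager import get_response_api_key
from voice_assistant.config import Config

Config.RESPONSE_MODEL = 'gemini'
assert get_response_api_key() == Config.GEMINI_API_KEY

Config.RESPONSE_MODEL = 'ollama'  # ollama needs no key, so the helper falls through
assert get_response_api_key() is None
```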
diff --git a/voice_assistant/response_generation.py b/voice_assistant/response_generation.py
index 5294128..d6788b7 100644
--- a/voice_assistant/response_generation.py
+++ b/voice_assistant/response_generation.py
@@ -4,6 +4,7 @@
 from groq import Groq
 import ollama
 import logging
+import google.generativeai as genai
 
 from voice_assistant.config import Config
 
@@ -42,6 +43,32 @@ def generate_response(model, api_key, chat_history, local_model_path=None):
             # stream=True,
         )
         return response['message']['content']
+    elif model == 'gemini':
+        genai.configure(api_key=api_key)
+        gemini_model = genai.GenerativeModel(Config.GEMINI_LLM)
+        # Convert chat history to the required format
+        # The current chat history structure is not compatible with the Gemini model
+        # It expects the chat history to be in the format [{"role": "model", "parts": ""}] and [{"role": "user", "parts": ""}]
+        # However, the current chat history is in the format [{"role": "system", "content": ""}] and [{"role": "user", "content": ""}]
+        # To make it compatible, we need to convert the chat history by replacing "content" with "parts"
+        # Iterate over each message in the chat history
+        converted_chat_history = [
+            {"role": "model" if message["role"] == "system" else message["role"], "parts": message["content"]}
+            for message in chat_history
+        ]
+        # Extract and remove the last user message; pop by index so an earlier
+        # message with identical content is never removed by mistake
+        user_text = ""
+        for i in range(len(converted_chat_history) - 1, -1, -1):
+            if converted_chat_history[i]["role"] == "user":
+                user_text = converted_chat_history.pop(i)["parts"]
+                break
+        # Start a new chat and generate a response
+        chat = gemini_model.start_chat(
+            history=converted_chat_history
+        )
+        response = chat.send_message(user_text)
+        return response.text
     elif model == 'local':
         # Placeholder for local LLM response generation
         return "Generated response from local model"
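To summarize what the new branch does to the message list, here is the conversion-and-extraction step pulled out as a standalone sketch; the helper name `prepare_gemini_turn` is illustrative and does not exist in the patch.

```python
# Standalone sketch of the history conversion in the gemini branch above.
# prepare_gemini_turn is an illustrative name, not part of the patch.
def prepare_gemini_turn(chat_history):
    # Gemini has no "system" role, so system prompts become "model" turns,
    # and the "content" key is renamed to "parts".
    converted = [
        {"role": "model" if m["role"] == "system" else m["role"], "parts": m["content"]}
        for m in chat_history
    ]
    # The newest user message is sent via send_message, not as history,
    # so it is popped off the end of the converted list.
    user_text = ""
    for i in range(len(converted) - 1, -1, -1):
        if converted[i]["role"] == "user":
            user_text = converted.pop(i)["parts"]
            break
    return converted, user_text

history = [
    {"role": "system", "content": "Be concise."},
    {"role": "user", "content": "What is the capital of France?"},
]
print(prepare_gemini_turn(history))
# ([{'role': 'model', 'parts': 'Be concise.'}], 'What is the capital of France?')
```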
== "assistant") else message["role"], "parts": message["content"]} for message in chat_history ] # Extract and remove the last user message