Skip to content

Commit

Permalink
Merge pull request #67 from valentinfrlch/dev
Browse files Browse the repository at this point in the history
Added new provider: Groq
  • Loading branch information
valentinfrlch authored Sep 27, 2024
2 parents c0b0573 + a101dc5 commit 245552e
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 131 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
</p>
<p align=center>
<img src=https://img.shields.io/badge/HACS-Custom-orange.svg?style=for-the-badg>
<img src=https://img.shields.io/badge/version-1.1.1-blue>
<img src=https://img.shields.io/badge/version-1.1.3-blue>
<a href="https://github.com/valentinfrlch/ha-llmvision/issues">
<img src="https://img.shields.io/maintenance/yes/2024.svg">
<img alt="Issues" src="https://img.shields.io/github/issues/valentinfrlch/ha-llmvision?color=0088ff"/>
Expand Down
13 changes: 7 additions & 6 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CONF_OPENAI_API_KEY,
CONF_ANTHROPIC_API_KEY,
CONF_GOOGLE_API_KEY,
CONF_GROQ_API_KEY,
CONF_LOCALAI_IP_ADDRESS,
CONF_LOCALAI_PORT,
CONF_LOCALAI_HTTPS,
Expand All @@ -28,19 +29,15 @@
)
from .request_handlers import RequestHandler
from .media_handlers import MediaProcessor
import logging
from homeassistant.core import SupportsResponse
from homeassistant.exceptions import ServiceValidationError

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(hass, entry):
"""Save config entry to hass.data"""
# Get all entries from config flow
openai_api_key = entry.data.get(CONF_OPENAI_API_KEY)
anthropic_api_key = entry.data.get(CONF_ANTHROPIC_API_KEY)
google_api_key = entry.data.get(CONF_GOOGLE_API_KEY)
groq_api_key = entry.data.get(CONF_GROQ_API_KEY)
localai_ip_address = entry.data.get(CONF_LOCALAI_IP_ADDRESS)
localai_port = entry.data.get(CONF_LOCALAI_PORT)
localai_https = entry.data.get(CONF_LOCALAI_HTTPS)
Expand All @@ -61,6 +58,7 @@ async def async_setup_entry(hass, entry):
CONF_OPENAI_API_KEY: openai_api_key,
CONF_ANTHROPIC_API_KEY: anthropic_api_key,
CONF_GOOGLE_API_KEY: google_api_key,
CONF_GROQ_API_KEY: groq_api_key,
CONF_LOCALAI_IP_ADDRESS: localai_ip_address,
CONF_LOCALAI_PORT: localai_port,
CONF_LOCALAI_HTTPS: localai_https,
Expand Down Expand Up @@ -106,10 +104,12 @@ def _default_model(self, provider):
return "claude-3-5-sonnet-20240620"
elif provider == "Google":
return "gemini-1.5-flash-latest"
elif provider == "Groq":
return "llava-v1.5-7b-4096-preview"
elif provider == "LocalAI":
return "gpt-4-vision-preview"
elif provider == "Ollama":
return "llava"
return "llava-phi3:latest"
elif provider == "Custom OpenAI":
return "gpt-4o-mini"

Expand Down Expand Up @@ -160,3 +160,4 @@ async def video_analyzer(data_call):
)

return True

79 changes: 59 additions & 20 deletions custom_components/llmvision/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from homeassistant.helpers.selector import selector
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import urllib.parse
from .const import (
DOMAIN,
CONF_OPENAI_API_KEY,
CONF_ANTHROPIC_API_KEY,
CONF_GOOGLE_API_KEY,
CONF_GROQ_API_KEY,
CONF_LOCALAI_IP_ADDRESS,
CONF_LOCALAI_PORT,
CONF_LOCALAI_HTTPS,
Expand Down Expand Up @@ -63,11 +65,22 @@ async def _validate_api_key(self, api_key):
payload = {
"contents": [{
"parts": [
{"text": "Hello, world!"}
{"text": "Hello"}
]}
]
}
method = "POST"
elif self.user_input["provider"] == "Groq":
header = {
'Authorization': 'Bearer ' + api_key,
'Content-Type': 'application/json'
}
base_url = "api.groq.com"
endpoint = "/openai/v1/chat/completions"
payload = {"messages": [
{"role": "user", "content": "Hello"}], "model": "gemma-7b-it"}
method = "POST"

return await self._handshake(base_url=base_url, endpoint=endpoint, protocol="https", header=header, payload=payload, expected_status=200, method=method)

def _validate_provider(self):
Expand Down Expand Up @@ -124,26 +137,16 @@ async def openai(self):

async def custom_openai(self):
self._validate_provider()
_LOGGER.debug(f"Splits: {len(self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(":"))}")
# URL with port
try:
if len(self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(":")) > 2:
protocol = self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(
"://")[0]
base_url = self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(
"://")[1].split("/")[0]
port = ":" + self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(":")[
1].split("/")[0]
# URL without port
else:
protocol = self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(
"://")[0]
base_url = self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT].split(
"://")[1].split("/")[0]
port = ""
url = urllib.parse.urlparse(
self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT])
protocol = url.scheme
base_url = url.hostname
port = ":" + str(url.port) if url.port else ""

endpoint = "/v1/models"
header = {'Content-type': 'application/json',
'Authorization': 'Bearer ' + self.user_input[CONF_CUSTOM_OPENAI_API_KEY]}
'Authorization': 'Bearer ' + self.user_input[CONF_CUSTOM_OPENAI_API_KEY]} if CONF_CUSTOM_OPENAI_API_KEY in self.user_input else {}
except Exception as e:
_LOGGER.error(f"Could not parse endpoint: {e}")
raise ServiceValidationError("endpoint_parse_failed")
Expand All @@ -167,6 +170,12 @@ async def google(self):
_LOGGER.error("Could not connect to Google server.")
raise ServiceValidationError("handshake_failed")

async def groq(self):
self._validate_provider()
if not await self._validate_api_key(self.user_input[CONF_GROQ_API_KEY]):
_LOGGER.error("Could not connect to Groq server.")
raise ServiceValidationError("handshake_failed")

def get_configured_providers(self):
providers = []
try:
Expand All @@ -186,6 +195,8 @@ def get_configured_providers(self):
providers.append("Ollama")
if CONF_CUSTOM_OPENAI_ENDPOINT in self.hass.data[DOMAIN]:
providers.append("Custom OpenAI")
if CONF_GROQ_API_KEY in self.hass.data[DOMAIN]:
providers.append("Groq")
return providers


Expand All @@ -202,6 +213,7 @@ async def handle_provider(self, provider, configured_providers):
"OpenAI": self.async_step_openai,
"Anthropic": self.async_step_anthropic,
"Google": self.async_step_google,
"Groq": self.async_step_groq,
"Ollama": self.async_step_ollama,
"LocalAI": self.async_step_localai,
"Custom OpenAI": self.async_step_custom_openai,
Expand All @@ -218,7 +230,7 @@ async def async_step_user(self, user_input=None):
data_schema = vol.Schema({
vol.Required("provider", default="OpenAI"): selector({
"select": {
"options": ["OpenAI", "Anthropic", "Google", "Ollama", "LocalAI", "Custom OpenAI"],
"options": ["OpenAI", "Anthropic", "Google", "Groq", "Ollama", "LocalAI", "Custom OpenAI"],
"mode": "dropdown",
"sort": False,
"custom_value": False
Expand Down Expand Up @@ -378,6 +390,33 @@ async def async_step_google(self, user_input=None):
data_schema=data_schema,
)

async def async_step_groq(self, user_input=None):
data_schema = vol.Schema({
vol.Required(CONF_GROQ_API_KEY): str,
})

if user_input is not None:
# save provider to user_input
user_input["provider"] = self.init_info["provider"]
validator = Validator(self.hass, user_input)
try:
await validator.groq()
# add the mode to user_input
user_input["provider"] = self.init_info["provider"]
return self.async_create_entry(title="LLM Vision Groq", data=user_input)
except ServiceValidationError as e:
_LOGGER.error(f"Validation failed: {e}")
return self.async_show_form(
step_id="groq",
data_schema=data_schema,
errors={"base": "handshake_failed"}
)

return self.async_show_form(
step_id="groq",
data_schema=data_schema,
)

async def async_step_custom_openai(self, user_input=None):
data_schema = vol.Schema({
vol.Required(CONF_CUSTOM_OPENAI_ENDPOINT): str,
Expand All @@ -404,4 +443,4 @@ async def async_step_custom_openai(self, user_input=None):
return self.async_show_form(
step_id="custom_openai",
data_schema=data_schema,
)
)
3 changes: 3 additions & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CONF_OPENAI_API_KEY = 'openai_api_key'
CONF_ANTHROPIC_API_KEY = 'anthropic_api_key'
CONF_GOOGLE_API_KEY = 'google_api_key'
CONF_GROQ_API_KEY = 'groq_api_key'
CONF_LOCALAI_IP_ADDRESS = 'localai_ip'
CONF_LOCALAI_PORT = 'localai_port'
CONF_LOCALAI_HTTPS = 'localai_https'
Expand Down Expand Up @@ -35,6 +36,7 @@
ERROR_OPENAI_NOT_CONFIGURED = "OpenAI provider is not configured"
ERROR_ANTHROPIC_NOT_CONFIGURED = "Anthropic provider is not configured"
ERROR_GOOGLE_NOT_CONFIGURED = "Google provider is not configured"
ERROR_GROQ_NOT_CONFIGURED = "Groq provider is not configured"
ERROR_LOCALAI_NOT_CONFIGURED = "LocalAI provider is not configured"
ERROR_OLLAMA_NOT_CONFIGURED = "Ollama provider is not configured"
ERROR_CUSTOM_OPENAI_NOT_CONFIGURED = "Custom OpenAI provider is not configured"
Expand All @@ -49,5 +51,6 @@
ENDPOINT_OPENAI = "https://api.openai.com/v1/chat/completions"
ENDPOINT_ANTHROPIC = "https://api.anthropic.com/v1/messages"
ENDPOINT_GOOGLE = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}"
ENDPOINT_GROQ = "https://api.groq.com/openai/v1/chat/completions"
ENDPOINT_LOCALAI = "{protocol}://{ip_address}:{port}/v1/chat/completions"
ENDPOINT_OLLAMA = "{protocol}://{ip_address}:{port}/api/chat"
2 changes: 1 addition & 1 deletion custom_components/llmvision/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
"documentation": "https://github.com/valentinfrlch/ha-llmvision",
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/valentinfrlch/ha-llmvision/issues",
"version": "1.1.1"
"version": "1.1.3"
}
Loading

0 comments on commit 245552e

Please sign in to comment.