Skip to content

Commit

Permalink
Fixed #90
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinfrlch committed Dec 22, 2024
1 parent 0610466 commit 147821b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 53 deletions.
74 changes: 46 additions & 28 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ async def _remember(hass, call, start, response) -> None:
else:
camera_name = "Unknown"

camera_name = camera_name.replace("camera.", "").replace("image.", "").capitalize()
camera_name = camera_name.replace(
"camera.", "").replace("image.", "").capitalize()

await semantic_index.remember(
start=start,
Expand All @@ -185,13 +186,37 @@ async def _remember(hass, call, start, response) -> None:
)


async def _update_sensor(hass, sensor_entity, new_value) -> None:
async def _update_sensor(hass, sensor_entity: str, new_value: str | int, type: str) -> None:
"""Update the value of a sensor entity."""
if sensor_entity:
# Attempt to parse the response
if type == "boolean" and new_value.lower() not in ["on", "off"]:
if new_value.lower() in ["true", "false"]:
new_value = "on" if new_value.lower() == "true" else "off"
elif new_value.split(" ")[0].replace(",", "").lower() == "yes":
new_value = "on"
elif new_value.split(" ")[0].replace(",", "").lower() == "no":
new_value = "off"
else:
raise ServiceValidationError(
"Response could not be parsed. Please check your prompt.")
elif type == "number":
try:
new_value = float(new_value)
except ValueError:
raise ServiceValidationError(
"Response could not be parsed. Please check your prompt.")
elif type == "option":
options = hass.states.get(sensor_entity).attributes["options"]
if new_value not in options:
raise ServiceValidationError(
"Response could not be parsed. Please check your prompt.")
# Set the value
if new_value:
_LOGGER.info(
f"Updating sensor {sensor_entity} with new value: {new_value}")
try:
hass.states.async_set(sensor_entity, new_value)
current_attributes = hass.states.get(sensor_entity).attributes.copy()
hass.states.async_set(sensor_entity, new_value, current_attributes)
except Exception as e:
_LOGGER.error(f"Failed to update sensor {sensor_entity}: {e}")
raise
Expand Down Expand Up @@ -309,40 +334,33 @@ async def stream_analyzer(data_call):

async def data_analyzer(data_call):
"""Handle the service call to analyze visual data"""
def is_number(s):
"""Helper function to check if string can be parsed as number"""
try:
float(s)
return True
except ValueError:
return False

start = dt_util.now()
call = ServiceCallData(data_call).get_service_call_data()
sensor_entity = data_call.data.get("sensor_entity")
_LOGGER.info(f"Sensor entity: {sensor_entity}")

# get current value to determine data type
state = hass.states.get(sensor_entity).state
sensor_type = sensor_entity.split(".")[0]
_LOGGER.info(f"Current state: {state}")
if state == "unavailable":
raise ServiceValidationError("Sensor entity is unavailable")
if state == "on" or state == "off":
data_type = "'on' or 'off' (lowercase)"
elif is_number(state):
data_type = "number"
if sensor_type == "input_boolean" or sensor_type == "binary_sensor" or sensor_type == "switch" or sensor_type == "boolean":
data_type = "one of: ['on', 'off']"
type = "boolean"
elif sensor_type == "input_number" or sensor_type == "number" or sensor_type == "sensor":
data_type = "a number"
type = "number"
elif sensor_type == "input_select":
options = hass.states.get(sensor_entity).attributes["options"]
data_type = "one of these options: " + \
", ".join([f"'{option}'" for option in options])
type = "option"
else:
if "options" in hass.states.get(sensor_entity).attributes:
data_type = "one of these options: " + \
", ".join([f"'{option}'" for option in hass.states.get(
sensor_entity).attributes["options"]])
else:
data_type = "string"

message = f"Your job is to extract data from images. Return a {data_type} only. No additional text or other options allowed!. If unsure, choose the option that best matches. Follow these instructions: " + call.message
_LOGGER.info(f"Message: {message}")
raise ServiceValidationError("Unsupported sensor entity type")

call.message = f"Your job is to extract data from images. You can only respond with {data_type}. You must respond with one of the options! If unsure, choose the option that best matches. Answer the following question with the options provided: " + call.message
request = Request(hass,
message=message,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
Expand All @@ -355,7 +373,7 @@ def is_number(s):
response = await request.call(call)
_LOGGER.info(f"Response: {response}")
# update sensor in data_call.data.get("sensor_entity")
await _update_sensor(hass, sensor_entity, response["response_text"])
await _update_sensor(hass, sensor_entity, response["response_text"], type)
return response

# Register services
Expand Down
2 changes: 2 additions & 0 deletions custom_components/llmvision/media_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import time
import asyncio
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from functools import partial
from PIL import Image, UnidentifiedImageError
import numpy as np
Expand All @@ -19,6 +20,7 @@
class MediaProcessor:
def __init__(self, hass, client):
self.hass = hass
self.session = async_get_clientsession(self.hass)
self.client = client
self.base64_images = []
self.filenames = []
Expand Down
70 changes: 45 additions & 25 deletions custom_components/llmvision/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ def get_provider(hass, provider_uid):

return None

def validate(self, call):
def validate(self, call) -> None | ServiceValidationError:
"""Validate call data"""
# Check image input
if not call.base64_images:
raise ServiceValidationError(ERROR_NO_IMAGE_INPUT)
# Check if single image is provided for Groq
if len(call.base64_images) > 1 and self.get_provider(call.provider) == 'Groq':
if len(call.base64_images) > 1 and self.get_provider(self.hass, call.provider) == 'Groq':
raise ServiceValidationError(ERROR_GROQ_MULTIPLE_IMAGES)
# Check provider is configured
if not call.provider:
Expand All @@ -105,13 +105,25 @@ def default_model(provider): return {
"Azure": "gpt-4o-mini",
"Custom OpenAI": "gpt-4o-mini",
"Google": "gemini-1.5-flash-latest",
"Groq": "llava-v1.5-7b-4096-preview",
"Groq": "llama-3.2-11b-vision-preview ",
"LocalAI": "gpt-4-vision-preview",
"Ollama": "llava-phi3:latest",
"Ollama": "minicpm-v",
"OpenAI": "gpt-4o-mini"
}.get(provider, "gpt-4o-mini") # Default value

async def call(self, call):
"""
Forwards a request to the specified provider and optionally generates a title.
Args:
call (object): The call object containing request details.
Raises:
ServiceValidationError: If the provider is invalid.
Returns:
dict: A dictionary containing the generated title (if any) and the response text.
"""
entry_id = call.provider
config = self.hass.data.get(DOMAIN).get(entry_id)

Expand All @@ -121,6 +133,8 @@ async def call(self, call):
call.base64_images = self.base64_images
call.filenames = self.filenames

self.validate(call)

gen_title_prompt = "Your job is to generate a title in the form '<object> seen' for texts. Do not mention the time, do not speculate. Generate a title for this text: {response}"

if provider == 'OpenAI':
Expand Down Expand Up @@ -249,6 +263,14 @@ async def _resolve_error(self, response, provider):


class Provider(ABC):
"""
Abstract base class for providers
Args:
hass (object): Home Assistant instance
api_key (str, optional): API key for the provider, defaults to ""
endpoint (dict, optional): Endpoint configuration for the provider
"""
def __init__(self,
hass,
api_key="",
Expand Down Expand Up @@ -337,16 +359,15 @@ async def _resolve_error(self, response, provider) -> str:
class OpenAI(Provider):
def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_OPENAI}):
super().__init__(hass, api_key, endpoint=endpoint)
self.default_model = "gpt-4o-mini"

def _generate_headers(self) -> dict:
return {'Content-type': 'application/json',
'Authorization': 'Bearer ' + self.api_key}

async def _make_request(self, data) -> str:
headers = self._generate_headers()

response = await self._post(url=self.endpoint.get('base_url'), headers=headers, data=data)

response_text = response.get(
"choices")[0].get("message").get("content")
return response_text
Expand Down Expand Up @@ -394,6 +415,7 @@ async def validate(self) -> None | ServiceValidationError:
class AzureOpenAI(Provider):
def __init__(self, hass, api_key="", endpoint={'base_url': "", 'deployment': "", 'api_version': ""}):
super().__init__(hass, api_key, endpoint)
self.default_model = "gpt-4o-mini"

def _generate_headers(self) -> dict:
return {'Content-type': 'application/json',
Expand All @@ -407,9 +429,7 @@ async def _make_request(self, data) -> str:
api_version=self.endpoint.get("api_version")
)

response = await self._post(
url=endpoint, headers=headers, data=data)

response = await self._post(url=endpoint, headers=headers, data=data)
response_text = response.get(
"choices")[0].get("message").get("content")
return response_text
Expand Down Expand Up @@ -459,6 +479,7 @@ async def validate(self) -> None | ServiceValidationError:
class Anthropic(Provider):
def __init__(self, hass, api_key=""):
super().__init__(hass, api_key)
self.default_model = "claude-3-5-sonnet-latest"

def _generate_headers(self) -> dict:
return {
Expand All @@ -468,8 +489,6 @@ def _generate_headers(self) -> dict:
}

async def _make_request(self, data) -> str:
api_key = self.api_key

headers = self._generate_headers()
response = await self._post(url=ENDPOINT_ANTHROPIC, headers=headers, data=data)
response_text = response.get("content")[0].get("text")
Expand All @@ -490,7 +509,7 @@ def _prepare_vision_data(self, call) -> dict:
data["messages"][0]["content"].append({"type": "image", "source": {
"type": "base64", "media_type": "image/jpeg", "data": f"{image}"}})
data["messages"][0]["content"].append(
{"type": "text", "text": self.message})
{"type": "text", "text": call.message})
return data

def _prepare_text_data(self, call) -> dict:
Expand Down Expand Up @@ -520,6 +539,7 @@ async def validate(self) -> None | ServiceValidationError:
class Google(Provider):
def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_GOOGLE, 'model': "gemini-1.5-flash-latest"}):
super().__init__(hass, api_key, endpoint)
self.default_model = "gemini-1.5-flash-latest"

def _generate_headers(self) -> dict:
return {'content-type': 'application/json'}
Expand Down Expand Up @@ -567,13 +587,12 @@ async def validate(self) -> None | ServiceValidationError:
class Groq(Provider):
def __init__(self, hass, api_key=""):
super().__init__(hass, api_key)
self.default_model = "llama-3.2-11b-vision-preview"

def _generate_headers(self) -> dict:
return {'Content-type': 'application/json', 'Authorization': 'Bearer ' + self.api_key}

async def _make_request(self, data) -> str:
api_key = self.api_key

headers = self._generate_headers()
response = await self._post(url=ENDPOINT_GROQ, headers=headers, data=data)
response_text = response.get(
Expand Down Expand Up @@ -615,18 +634,19 @@ async def validate(self) -> None | ServiceValidationError:
raise ServiceValidationError("empty_api_key")
headers = self._generate_headers()
data = {
"contents": [{
"parts": [
{"text": "Hello"}
]}
]
"model": "llama3-8b-8192",
"messages": [{
"role": "user",
"content": "Hi"
}]
}
await self._post(url=ENDPOINT_GROQ, headers=headers, data=data)


class LocalAI(Provider):
def __init__(self, hass, api_key="", endpoint={'ip_address': "", 'port': "", 'https': False}):
super().__init__(hass, api_key, endpoint)
self.default_model = "gpt-4-vision-preview"

async def _make_request(self, data) -> str:
endpoint = ENDPOINT_LOCALAI.format(
Expand Down Expand Up @@ -682,22 +702,19 @@ async def validate(self) -> None | ServiceValidationError:
class Ollama(Provider):
def __init__(self, hass, api_key="", endpoint={'ip_address': "0.0.0.0", 'port': "11434", 'https': False}):
super().__init__(hass, api_key, endpoint)
self.default_model = "minicpm-v"

async def _make_request(self, data) -> str:
https = self.endpoint.get("https")
ip_address = self.endpoint.get("ip_address")
port = self.endpoint.get("port")
protocol = "https" if https else "http"

endpoint = ENDPOINT_OLLAMA.format(
ip_address=ip_address,
port=port,
protocol=protocol
)

_LOGGER.info(
f"endpoint: {endpoint} https: {https} ip_address: {ip_address} port: {port}")

response = await self._post(url=endpoint, headers={}, data=data)
response_text = response.get("message").get("content")
return response_text
Expand Down Expand Up @@ -732,8 +749,11 @@ async def validate(self) -> None | ServiceValidationError:
protocol = "https" if self.endpoint.get("https") else "http"

try:
response = await session.get(f"{protocol}://{ip_address}:{port}/api/tags")
_LOGGER.info(f"Checking connection to {protocol}://{ip_address}:{port}")
response = await session.get(f"{protocol}://{ip_address}:{port}/api/tags", headers={})
_LOGGER.info(f"Response: {response}")
if response.status != 200:
raise ServiceValidationError('handshake_failed')
except Exception:
except Exception as e:
_LOGGER.error(f"Error: {e}")
raise ServiceValidationError('handshake_failed')

0 comments on commit 147821b

Please sign in to comment.