diff --git a/benchmark_visualization/benchmark_data.csv b/benchmark_visualization/benchmark_data.csv index 08a008f..b71a35f 100644 --- a/benchmark_visualization/benchmark_data.csv +++ b/benchmark_visualization/benchmark_data.csv @@ -1,4 +1,6 @@ Model,Size,Date,Overall,Art & Design,Business,Science,Health & Medicine,Human. & Social Sci.,Tech & Eng.,Cost +Llama 3.2 90B, 90B, 2024-09-25, 60.3, -, -, -, -, -, -, 0.00000001 +Llama 3.2 11B, 11B, 2024-09-25, 50.7, -, -, -, -, -, -, 0.00000001 GPT-4o,-,2024-05-27,69.1,-,-,-,-,-,-,5 GPT-4o mini,-,2024-05-27,59.4,-,-,-,-,-,-,0.15 Gemini 1.5 Pro,-,2024-05-31,65.8,-,-,-,-,-,-,3.5 @@ -33,9 +35,10 @@ Yi-VL-6B*,6B,2024-01-23,39.1,52.5,30.7,31.3,38,53.3,35.7, InternVL-Chat-V1.1*,-,2024-01-27,39.1,56.7,34.7,31.3,39.3,57.5,27.1, Bunny-3B*,3B,2024-02-13,38.2,49.2,30.7,30.7,40.7,45,37.1, SVIT*,-,2023-12-26,38,52.5,27.3,28,42,51.7,33.8, -MiniCPM-V*,-,2024-02-07,37.2,55.8,33.3,28,32.7,58.3,27.1, -MiniCPM-V-2*,-,2024-04-16,37.1,63.3,28.7,30,30,56.7,27.1, -LLaVA-1.5-13B,13B,2023-11-27,36.4,51.7,22.7,29.3,38.7,53.3,31.4, +MiniCPM-V*,-,2024-02-07,37.2,55.8,33.3,28,32.7,58.3,27.1,0.00000001 +MiniCPM-V-2*,-,2024-04-16,37.1,63.3,28.7,30,30,56.7,27.1,0.00000001 +MiniCPM-V-2.6,-,2024-04-16,49.8,63.3,28.7,30,30,56.7,27.1,0.00000001 +LLaVA-1.5-13B,13B,2023-11-27,36.4,51.7,22.7,29.3,38.7,53.3,31.4,0.00000001 Emu2-Chat*,-,2023-12-24,36.3,55,30,28.7,28.7,46.7,35.2, Qwen-VL-7B-Chat,-,2023-11-27,35.9,51.7,29.3,29.3,33.3,45,32.9, InstructBLIP-T5-XXL,-,2023-11-27,35.7,44.2,24,30.7,35.3,49.2,35.2, diff --git a/benchmark_visualization/model_benchmark_visualizer.py b/benchmark_visualization/model_benchmark_visualizer.py index 7eb7593..724c3e7 100644 --- a/benchmark_visualization/model_benchmark_visualizer.py +++ b/benchmark_visualization/model_benchmark_visualizer.py @@ -24,6 +24,8 @@ def category_name(model_name): return 'Anthropic Claude 3' elif "Gemini 1.5" in model_name: return 'Google Gemini 1.5' + elif "Llama 3.2" in model_name: + return 'Meta Llama 3.2' return 'Other' @@ -47,7 +49,7 @@ def categorize_model(model_name): # Set order for legend category_order = ['GPT-4', 'Claude 3', - 'Claude 3.5', 'Gemini 1.5'] # Add 'Gemini 1.5' + 'Claude 3.5', 'Gemini 1.5'] df['Category'] = pd.Categorical( df['Category'], categories=category_order, ordered=True) df = df.sort_values('Category') @@ -63,16 +65,20 @@ def categorize_model(model_name): x = group_df['Cost'].astype(float) y = group_df['Overall'].astype(float) - # Fit a polynomial - z = np.polyfit(x, y, 2) - p = np.poly1d(z) - - x_poly = np.linspace(x.min(), x.max(), 100) - y_poly = p(x_poly) - - # Plot the poly line - fig.add_trace(go.Scatter(x=x_poly, y=y_poly, mode='lines', - name=f"{category_name(category)}", line=dict(color=colors[category], width=6, dash='dash'), opacity=0.5)) + try: + # Fit a polynomial + z = np.polyfit(x, y, 2) + p = np.poly1d(z) + + x_poly = np.linspace(x.min(), x.max(), 100) + y_poly = p(x_poly) + + # Plot the poly line + fig.add_trace(go.Scatter(x=x_poly, y=y_poly, mode='lines', + name=f"{category_name(category)}", line=dict(color=colors[category], width=6, dash='dash'), opacity=0.5)) + except np.linalg.LinAlgError: + print(f"LinAlgError: SVD did not converge for category {category}") + continue # plot the actual datapoints for index, model in df.iterrows(): @@ -102,7 +108,101 @@ def categorize_model(model_name): fig.write_image("benchmark_visualization/benchmark_visualization.jpg", width=1920, height=1080, scale=1) + # Create a second visualization for open source models + fig_open_source = go.Figure() + 
+ def categorize_open_source_model(model_name): + """Categorizes open source models based on name""" + if "Llama" in model_name: + return 'Llama' + elif "LLaVA" in model_name: + return 'LLaVA' + elif "MiniCPM" in model_name: + return 'MiniCPM' + return 'Other' + + # Categorize each open source model in the DataFrame + df['OpenSourceCategory'] = df['Model'].apply(categorize_open_source_model) + + # Set order for legend + open_source_category_order = ['Llama', 'LLaVA', 'MiniCPM'] + df['OpenSourceCategory'] = pd.Categorical( + df['OpenSourceCategory'], categories=open_source_category_order, ordered=True) + df = df.sort_values('OpenSourceCategory') + + # Set colors for different open source models + open_source_colors = {'Llama': '#0081fb', + 'LLaVA': '#ff7f0e', 'MiniCPM': '#2ca02c', 'Other': 'gray'} + + # Convert 'Size' column to float + def convert_size_to_float(size_str): + """Convert model size string to float""" + size_str = size_str.strip() # Remove leading/trailing spaces + if size_str == "-": + return 0 + if size_str.endswith('B'): + return float(size_str[:-1]) * 1e9 + elif size_str.endswith('M'): + return float(size_str[:-1]) * 1e6 + return float(size_str) + + df['Size'] = df['Size'].apply(convert_size_to_float) + + for category, group_df in df.groupby('OpenSourceCategory'): + if category not in ['Llama', 'LLaVA', 'MiniCPM']: + continue + + x = group_df['Size'].astype(float) + y = group_df['Overall'].astype(float) + + if len(x) == 0 or len(y) == 0: + continue + + try: + # Fit a polynomial + z = np.polyfit(x, y, 2) + p = np.poly1d(z) + + x_poly = np.linspace(x.min(), x.max(), 100) + y_poly = p(x_poly) + + # Plot the poly line (category_name() expects full model names and would label every open source trace 'Other', so use the category key directly) + fig_open_source.add_trace(go.Scatter(x=x_poly, y=y_poly, mode='lines', + name=category, line=dict(color=open_source_colors[category], width=6, dash='dash'), opacity=0.5)) + except np.linalg.LinAlgError: + print(f"LinAlgError: SVD did not converge for category {category}") + continue + + # plot the actual datapoints + for index, model in df.iterrows(): + try: + fig_open_source.add_trace(go.Scatter(x=[model['Size']], y=[model['Overall']], + mode='markers', # Removed 'text' from mode + name=model['Model'], marker=dict(size=20, color=open_source_colors[model['OpenSourceCategory']]))) + + except Exception as e: + continue + + # Add model name + fig_open_source.add_annotation(x=model['Size'], y=model['Overall'], + text=model['Model'], + showarrow=False, + yshift=-35) + + fig_open_source.update_layout(title={'text': 'Performance vs Model Size of Open Source Models in LLM Vision', + 'font': {'size': 50}}, + xaxis_title='Model Size (Parameters)', yaxis_title='MMMU Score Average', + paper_bgcolor='#0d1117', plot_bgcolor='#161b22', + font=dict(color='white', family='Product Sans', size=25), + xaxis=dict(color='white', linecolor='grey', + showgrid=False, zeroline=False), + yaxis=dict(color='white', linecolor='grey', + showgrid=False, zeroline=False)) + # Save the plot as an image + fig_open_source.write_image("benchmark_visualization/open_source_benchmark_visualization.jpg", + width=1920, height=1080, scale=1) + if __name__ == "__main__": df = read_benchmark_data() - create_benchmark_visualization(df) + create_benchmark_visualization(df) \ No newline at end of file diff --git a/benchmark_visualization/open_source_benchmark_visualization.jpg b/benchmark_visualization/open_source_benchmark_visualization.jpg new file mode 100644 index 0000000..3bd5847 Binary files /dev/null and b/benchmark_visualization/open_source_benchmark_visualization.jpg differ
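Note: np.polyfit(x, y, 2) is only well-posed with at least three distinct x values per category; the two Llama 3.2 rows added above sit right at that limit, and NumPy's underlying SVD can also fail outright, which the LinAlgError guard covers. A minimal standalone sketch of that guard (hypothetical helper, not part of this patch):

```python
import numpy as np

def safe_quadratic_fit(x, y):
    """Return a fitted quadratic, or None when the fit is ill-posed."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    if len(np.unique(x)) < 3 or not (np.isfinite(x).all() and np.isfinite(y).all()):
        return None  # fewer distinct points than coefficients, or NaN/inf input
    try:
        return np.poly1d(np.polyfit(x, y, 2))
    except np.linalg.LinAlgError:  # SVD did not converge
        return None

print(safe_quadratic_fit([11e9, 90e9], [50.7, 60.3]))             # None: only two sizes
print(safe_quadratic_fit([3e9, 11e9, 90e9], [38.2, 50.7, 60.3]))  # exact quadratic
```

diff 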
--git a/blueprints/event_summary.yaml b/blueprints/event_summary.yaml index ea5304e..e2051af 100644 --- a/blueprints/event_summary.yaml +++ b/blueprints/event_summary.yaml @@ -431,4 +431,4 @@ action: group: "{{group}}" interruption-level: passive - - delay: '00:{{cooldown|int}}:00' + - delay: '00:{{cooldown|int}}:00' \ No newline at end of file diff --git a/custom_components/llmvision/__init__.py b/custom_components/llmvision/__init__.py index ef72029..07aece5 100644 --- a/custom_components/llmvision/__init__.py +++ b/custom_components/llmvision/__init__.py @@ -2,6 +2,10 @@ from .const import ( DOMAIN, CONF_OPENAI_API_KEY, + CONF_AZURE_API_KEY, + CONF_AZURE_VERSION, + CONF_AZURE_BASE_URL, + CONF_AZURE_DEPLOYMENT, CONF_ANTHROPIC_API_KEY, CONF_GOOGLE_API_KEY, CONF_GROQ_API_KEY, @@ -30,20 +34,22 @@ DURATION, MAX_FRAMES, TEMPERATURE, - DETAIL, INCLUDE_FILENAME, EXPOSE_IMAGES, EXPOSE_IMAGES_PERSIST, + GENERATE_TITLE, SENSOR_ENTITY, ) from .calendar import SemanticIndex +from .providers import Request +from .media_handlers import MediaProcessor +import os from datetime import timedelta from homeassistant.util import dt as dt_util from homeassistant.config_entries import ConfigEntry -from .request_handlers import RequestHandler -from .media_handlers import MediaProcessor from homeassistant.core import SupportsResponse from homeassistant.exceptions import ServiceValidationError +from functools import partial import logging _LOGGER = logging.getLogger(__name__) @@ -56,6 +62,10 @@ async def async_setup_entry(hass, entry): # Get all entries from config flow openai_api_key = entry.data.get(CONF_OPENAI_API_KEY) + azure_api_key = entry.data.get(CONF_AZURE_API_KEY) + azure_base_url = entry.data.get(CONF_AZURE_BASE_URL) + azure_deployment = entry.data.get(CONF_AZURE_DEPLOYMENT) + azure_version = entry.data.get(CONF_AZURE_VERSION) anthropic_api_key = entry.data.get(CONF_ANTHROPIC_API_KEY) google_api_key = entry.data.get(CONF_GOOGLE_API_KEY) groq_api_key = entry.data.get(CONF_GROQ_API_KEY) @@ -76,6 +86,10 @@ async def async_setup_entry(hass, entry): # Create a dictionary for the entry data entry_data = { CONF_OPENAI_API_KEY: openai_api_key, + CONF_AZURE_API_KEY: azure_api_key, + CONF_AZURE_BASE_URL: azure_base_url, + CONF_AZURE_DEPLOYMENT: azure_deployment, + CONF_AZURE_VERSION: azure_version, CONF_ANTHROPIC_API_KEY: anthropic_api_key, CONF_GOOGLE_API_KEY: google_api_key, CONF_GROQ_API_KEY: groq_api_key, @@ -99,6 +113,8 @@ async def async_setup_entry(hass, entry): # check if the entry is the calendar entry (has entry retention_time) if filtered_entry_data.get(CONF_RETENTION_TIME) is not None: + # make sure 'llmvision' directory exists (calendar.py stores events.json under hass.config.path("llmvision"), not the filesystem root) + await hass.loop.run_in_executor(None, partial(os.makedirs, hass.config.path("llmvision"), exist_ok=True)) # forward the calendar entity to the platform await hass.config_entries.async_forward_entry_setups(entry, ["calendar"]) @@ -113,7 +129,7 @@ async def async_remove_entry(hass, entry): if entry_uid in hass.data[DOMAIN]: # Remove the entry from hass.data _LOGGER.info(f"Removing {entry.title} from hass.data") - async_unload_entry(hass, entry) + await async_unload_entry(hass, entry) hass.data[DOMAIN].pop(entry_uid) else: _LOGGER.warning( @@ -134,12 +150,11 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry) -> bool: return False -async def _remember(hass, call, start, response): +async def _remember(hass, call, start, response) -> None: if call.remember: # Find semantic index config config_entry = None for entry in hass.config_entries.async_entries(DOMAIN): - 
_LOGGER.info(f"Entry: {entry.data}") # Check if the config entry is empty if entry.data["provider"] == "Event Calendar": config_entry = entry @@ -150,35 +165,8 @@ async def _remember(hass, call, start, response): f"'Event Calendar' config not found") semantic_index = SemanticIndex(hass, config_entry) - # Define a mapping of keywords to labels - keyword_to_label = { - "person": "Person", - "man": "Person", - "woman": "Person", - "individual": "Person", - "delivery": "Delivery", - "courier": "Courier", - "package": "Package", - "car": "Car", - "vehicle": "Car", - "bike": "Bike", - "bicycle": "Bike", - "bus": "Bus", - "truck": "Truck", - "motorcycle": "Motorcycle", - "bicycle": "Bicycle", - "dog": "Dog", - "cat": "Cat", - } - - # Default label - label = "Unknown object" - - # Check each keyword in the response text and update the label accordingly - for keyword, mapped_label in keyword_to_label.items(): - if keyword in response["response_text"].lower(): - label = mapped_label - break + + title = response.get("title", "Unknown object seen") if call.image_entities and len(call.image_entities) > 0: camera_name = call.image_entities[0] @@ -188,23 +176,49 @@ async def _remember(hass, call, start, response): else: camera_name = "Unknown" - camera_name = camera_name.replace("camera.", "").replace("image.", "") + camera_name = camera_name.replace( + "camera.", "").replace("image.", "").capitalize() await semantic_index.remember( start=start, end=dt_util.now() + timedelta(minutes=1), - label=label + " seen", - camera_name=camera_name, + label=title + " near " + camera_name if camera_name != "Unknown" else title, + camera_name=camera_name if camera_name != "Unknown" else "Image Input", summary=response["response_text"] ) -async def _update_sensor(hass, sensor_entity, new_value): +async def _update_sensor(hass, sensor_entity: str, new_value: str | int, type: str) -> None: """Update the value of a sensor entity.""" - if sensor_entity: - _LOGGER.info(f"Updating sensor {sensor_entity} with new value: {new_value}") + # Attempt to parse the response + if type == "boolean" and new_value.lower() not in ["on", "off"]: + if new_value.lower() in ["true", "false"]: + new_value = "on" if new_value.lower() == "true" else "off" + elif new_value.split(" ")[0].replace(",", "").lower() == "yes": + new_value = "on" + elif new_value.split(" ")[0].replace(",", "").lower() == "no": + new_value = "off" + else: + raise ServiceValidationError( + "Response could not be parsed. Please check your prompt.") + elif type == "number": + try: + new_value = float(new_value) + except ValueError: + raise ServiceValidationError( + "Response could not be parsed. Please check your prompt.") + elif type == "option": + options = hass.states.get(sensor_entity).attributes["options"] + if new_value not in options: + raise ServiceValidationError( + "Response could not be parsed. 
Please check your prompt.") + # Set the value + if new_value: + _LOGGER.info( + f"Updating sensor {sensor_entity} with new value: {new_value}") try: - hass.states.async_set(sensor_entity, new_value) + current_attributes = hass.states.get(sensor_entity).attributes.copy() + hass.states.async_set(sensor_entity, new_value, current_attributes) except Exception as e: _LOGGER.error(f"Failed to update sensor {sensor_entity}: {e}") raise @@ -236,11 +250,14 @@ def __init__(self, data_call): self.target_width = data_call.data.get(TARGET_WIDTH, 3840) self.temperature = float(data_call.data.get(TEMPERATURE, 0.3)) self.max_tokens = int(data_call.data.get(MAXTOKENS, 100)) - self.detail = str(data_call.data.get(DETAIL, "auto")) self.include_filename = data_call.data.get(INCLUDE_FILENAME, False) self.expose_images = data_call.data.get(EXPOSE_IMAGES, False) self.expose_images_persist = data_call.data.get(EXPOSE_IMAGES_PERSIST, False) + self.generate_title = data_call.data.get(GENERATE_TITLE, False) self.sensor_entity = data_call.data.get(SENSOR_ENTITY) + # ------------ Added during call ------------ + # self.base64_images : List[str] = [] + # self.filenames : List[str] = [] def get_service_call_data(self): return self @@ -254,24 +271,24 @@ async def image_analyzer(data_call): # Initialize call object with service call data call = ServiceCallData(data_call).get_service_call_data() # Initialize the RequestHandler client - client = RequestHandler(hass, - message=call.message, - max_tokens=call.max_tokens, - temperature=call.temperature, - detail=call.detail) + request = Request(hass=hass, + message=call.message, + max_tokens=call.max_tokens, + temperature=call.temperature, + ) # Fetch and preprocess images - processor = MediaProcessor(hass, client) + processor = MediaProcessor(hass, request) # Send images to RequestHandler client - client = await processor.add_images(image_entities=call.image_entities, - image_paths=call.image_paths, - target_width=call.target_width, - include_filename=call.include_filename, - expose_images=call.expose_images - ) + request = await processor.add_images(image_entities=call.image_entities, + image_paths=call.image_paths, + target_width=call.target_width, + include_filename=call.include_filename, + expose_images=call.expose_images + ) # Validate configuration, input data and make the call - response = await client.make_request(call) + response = await request.call(call) await _remember(hass, call, start, response) return response @@ -280,23 +297,24 @@ async def video_analyzer(data_call): start = dt_util.now() call = ServiceCallData(data_call).get_service_call_data() call.message = "The attached images are frames from a video. 
" + call.message - client = RequestHandler(hass, - message=call.message, - max_tokens=call.max_tokens, - temperature=call.temperature, - detail=call.detail) - processor = MediaProcessor(hass, client) - client = await processor.add_videos(video_paths=call.video_paths, - event_ids=call.event_id, - max_frames=call.max_frames, - target_width=call.target_width, - include_filename=call.include_filename, - expose_images=call.expose_images, - expose_images_persist=call.expose_images_persist, - frigate_retry_attempts=call.frigate_retry_attempts, - frigate_retry_seconds=call.frigate_retry_seconds - ) - response = await client.make_request(call) + + request = Request(hass, + message=call.message, + max_tokens=call.max_tokens, + temperature=call.temperature, + ) + processor = MediaProcessor(hass, request) + request = await processor.add_videos(video_paths=call.video_paths, + event_ids=call.event_id, + max_frames=call.max_frames, + target_width=call.target_width, + include_filename=call.include_filename, + expose_images=call.expose_images, + expose_images_persist=call.expose_images_persist, + frigate_retry_attempts=call.frigate_retry_attempts, + frigate_retry_seconds=call.frigate_retry_seconds + ) + response = await request.call(call) await _remember(hass, call, start, response) return response @@ -305,71 +323,66 @@ async def stream_analyzer(data_call): start = dt_util.now() call = ServiceCallData(data_call).get_service_call_data() call.message = "The attached images are frames from a live camera feed. " + call.message - client = RequestHandler(hass, - message=call.message, - max_tokens=call.max_tokens, - temperature=call.temperature, - detail=call.detail) - processor = MediaProcessor(hass, client) - client = await processor.add_streams(image_entities=call.image_entities, - duration=call.duration, - max_frames=call.max_frames, - target_width=call.target_width, - include_filename=call.include_filename, - expose_images=call.expose_images - ) - - response = await client.make_request(call) + request = Request(hass, + message=call.message, + max_tokens=call.max_tokens, + temperature=call.temperature, + ) + processor = MediaProcessor(hass, request) + request = await processor.add_streams(image_entities=call.image_entities, + duration=call.duration, + max_frames=call.max_frames, + target_width=call.target_width, + include_filename=call.include_filename, + expose_images=call.expose_images + ) + + response = await request.call(call) await _remember(hass, call, start, response) return response async def data_analyzer(data_call): """Handle the service call to analyze visual data""" - def is_number(s): - """Helper function to check if string can be parsed as number""" - try: - float(s) - return True - except ValueError: - return False - - start = dt_util.now() call = ServiceCallData(data_call).get_service_call_data() sensor_entity = data_call.data.get("sensor_entity") _LOGGER.info(f"Sensor entity: {sensor_entity}") - + # get current value to determine data type state = hass.states.get(sensor_entity).state + sensor_type = sensor_entity.split(".")[0] _LOGGER.info(f"Current state: {state}") if state == "unavailable": raise ServiceValidationError("Sensor entity is unavailable") - if state == "on" or state == "off": - data_type = "'on' or 'off' (lowercase)" - elif is_number(state): - data_type = "number" + if sensor_type == "input_boolean" or sensor_type == "binary_sensor" or sensor_type == "switch" or sensor_type == "boolean": + data_type = "one of: ['on', 'off']" + type = "boolean" + elif sensor_type == 
"input_number" or sensor_type == "number" or sensor_type == "sensor": + data_type = "a number" + type = "number" + elif sensor_type == "input_select": + options = hass.states.get(sensor_entity).attributes["options"] + data_type = "one of these options: " + \ + ", ".join([f"'{option}'" for option in options]) + type = "option" else: - if "options" in hass.states.get(sensor_entity).attributes: - data_type = "one of these options: " + ", ".join([f"'{option}'" for option in hass.states.get(sensor_entity).attributes["options"]]) - else: - data_type = "string" - - message = f"Your job is to extract data from images. Return a {data_type} only. No additional text or other options allowed!. If unsure, choose the option that best matches. Follow these instructions: " + call.message - _LOGGER.info(f"Message: {message}") - client = RequestHandler(hass, - message=message, - max_tokens=call.max_tokens, - temperature=call.temperature, - detail=call.detail) - processor = MediaProcessor(hass, client) - client = await processor.add_visual_data(image_entities=call.image_entities, - image_paths=call.image_paths, - target_width=call.target_width, - include_filename=call.include_filename - ) - response = await client.make_request(call) + raise ServiceValidationError("Unsupported sensor entity type") + + call.message = f"Your job is to extract data from images. You can only respond with {data_type}. You must respond with one of the options! If unsure, choose the option that best matches. Answer the following question with the options provided: " + call.message + request = Request(hass, + message=call.message, + max_tokens=call.max_tokens, + temperature=call.temperature, + ) + processor = MediaProcessor(hass, request) + request = await processor.add_visual_data(image_entities=call.image_entities, + image_paths=call.image_paths, + target_width=call.target_width, + include_filename=call.include_filename + ) + response = await request.call(call) _LOGGER.info(f"Response: {response}") - # udpate sensor in data_call.data.get("sensor_entity") - await _update_sensor(hass, sensor_entity, response["response_text"]) + # update sensor in data_call.data.get("sensor_entity") + await _update_sensor(hass, sensor_entity, response["response_text"], type) return response # Register services diff --git a/custom_components/llmvision/calendar.py b/custom_components/llmvision/calendar.py index 9276cb8..70a2eaa 100644 --- a/custom_components/llmvision/calendar.py +++ b/custom_components/llmvision/calendar.py @@ -37,10 +37,10 @@ def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry): self._attr_supported_features = (CalendarEntityFeature.DELETE_EVENT) # Path to the JSON file where events are stored self._file_path = os.path.join( - self.hass.config.path("custom_components/llmvision"), "events.json" + self.hass.config.path("llmvision"), "events.json" ) self.hass.loop.create_task(self.async_update()) - + def _ensure_datetime(self, dt): """Ensure the input is a datetime.datetime object.""" if isinstance(dt, datetime.date) and not isinstance(dt, datetime.datetime): diff --git a/custom_components/llmvision/config_flow.py b/custom_components/llmvision/config_flow.py index abb4e68..4458c26 100644 --- a/custom_components/llmvision/config_flow.py +++ b/custom_components/llmvision/config_flow.py @@ -1,11 +1,23 @@ from homeassistant import config_entries from homeassistant.helpers.selector import selector from homeassistant.exceptions import ServiceValidationError -from homeassistant.helpers.aiohttp_client import async_get_clientsession 
-import urllib.parse +from .providers import ( + OpenAI, + AzureOpenAI, + Anthropic, + Google, + Groq, + LocalAI, + Ollama, +) from .const import ( DOMAIN, CONF_OPENAI_API_KEY, + CONF_AZURE_API_KEY, + CONF_AZURE_BASE_URL, + CONF_AZURE_DEPLOYMENT, + CONF_AZURE_VERSION, + ENDPOINT_AZURE, CONF_ANTHROPIC_API_KEY, CONF_GOOGLE_API_KEY, CONF_GROQ_API_KEY, @@ -17,7 +29,6 @@ CONF_OLLAMA_HTTPS, CONF_CUSTOM_OPENAI_API_KEY, CONF_CUSTOM_OPENAI_ENDPOINT, - VERSION_ANTHROPIC, CONF_RETENTION_TIME, ) import voluptuous as vol @@ -26,203 +37,21 @@ _LOGGER = logging.getLogger(__name__) -class Validator: - def __init__(self, hass, user_input): - self.hass = hass - self.user_input = user_input - - async def _validate_api_key(self, api_key): - if not api_key or api_key == "": - _LOGGER.error("You need to provide a valid API key.") - raise ServiceValidationError("empty_api_key") - elif self.user_input["provider"] == "OpenAI": - header = {'Content-type': 'application/json', - 'Authorization': 'Bearer ' + api_key} - base_url = "api.openai.com" - endpoint = "/v1/models" - payload = {} - method = "GET" - elif self.user_input["provider"] == "Anthropic": - header = { - 'x-api-key': api_key, - 'content-type': 'application/json', - 'anthropic-version': VERSION_ANTHROPIC - } - payload = { - "model": "claude-3-haiku-20240307", - "messages": [ - {"role": "user", "content": "Hello, world"} - ], - "max_tokens": 50, - "temperature": 0.5 - } - base_url = "api.anthropic.com" - endpoint = "/v1/messages" - method = "POST" - elif self.user_input["provider"] == "Google": - header = {"content-type": "application/json"} - base_url = "generativelanguage.googleapis.com" - endpoint = f"/v1beta/models/gemini-1.5-flash-latest:generateContent?key={api_key}" - payload = { - "contents": [{ - "parts": [ - {"text": "Hello"} - ]} - ] - } - method = "POST" - elif self.user_input["provider"] == "Groq": - header = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } - base_url = "api.groq.com" - endpoint = "/openai/v1/chat/completions" - payload = {"messages": [ - {"role": "user", "content": "Hello"}], "model": "gemma-7b-it"} - method = "POST" - - return await self._handshake(base_url=base_url, endpoint=endpoint, protocol="https", header=header, payload=payload, expected_status=200, method=method) - - def _validate_provider(self): - if not self.user_input["provider"]: - raise ServiceValidationError("empty_mode") - - async def _handshake(self, base_url, endpoint, protocol="http", port="", header={}, payload={}, expected_status=200, method="GET"): - _LOGGER.debug( - f"Connecting to {protocol}://{base_url}{port}{endpoint}") - session = async_get_clientsession(self.hass) - url = f'{protocol}://{base_url}{port}{endpoint}' - try: - if method == "GET": - response = await session.get(url, headers=header) - elif method == "POST": - response = await session.post(url, headers=header, json=payload) - if response.status == expected_status: - return True - else: - _LOGGER.error( - f"Handshake failed with status: {response.status}") - return False - except Exception as e: - _LOGGER.error(f"Could not connect to {url}: {e}") - return False - - async def localai(self): - self._validate_provider() - if not self.user_input[CONF_LOCALAI_IP_ADDRESS]: - raise ServiceValidationError("empty_ip_address") - if not self.user_input[CONF_LOCALAI_PORT]: - raise ServiceValidationError("empty_port") - protocol = "https" if self.user_input[CONF_LOCALAI_HTTPS] else "http" - if not await self._handshake(base_url=self.user_input[CONF_LOCALAI_IP_ADDRESS], 
port=":"+str(self.user_input[CONF_LOCALAI_PORT]), protocol=protocol, endpoint="/readyz"): - _LOGGER.error("Could not connect to LocalAI server.") - raise ServiceValidationError("handshake_failed") - - async def ollama(self): - self._validate_provider() - if not self.user_input[CONF_OLLAMA_IP_ADDRESS]: - raise ServiceValidationError("empty_ip_address") - if not self.user_input[CONF_OLLAMA_PORT]: - raise ServiceValidationError("empty_port") - protocol = "https" if self.user_input[CONF_OLLAMA_HTTPS] else "http" - if not await self._handshake(base_url=self.user_input[CONF_OLLAMA_IP_ADDRESS], port=":"+str(self.user_input[CONF_OLLAMA_PORT]), protocol=protocol, endpoint="/api/tags"): - _LOGGER.error("Could not connect to Ollama server.") - raise ServiceValidationError("handshake_failed") - - async def openai(self): - self._validate_provider() - if not await self._validate_api_key(self.user_input[CONF_OPENAI_API_KEY]): - _LOGGER.error("Could not connect to OpenAI server.") - raise ServiceValidationError("handshake_failed") - - async def custom_openai(self): - self._validate_provider() - try: - url = urllib.parse.urlparse( - self.user_input[CONF_CUSTOM_OPENAI_ENDPOINT]) - protocol = url.scheme - base_url = url.hostname - path = url.path if url.path else "" - port = ":" + str(url.port) if url.port else "" - - endpoint = path + "/v1/models" - header = {'Content-type': 'application/json', - 'Authorization': 'Bearer ' + self.user_input[CONF_CUSTOM_OPENAI_API_KEY]} if CONF_CUSTOM_OPENAI_API_KEY in self.user_input else {} - except Exception as e: - _LOGGER.error(f"Could not parse endpoint: {e}") - raise ServiceValidationError("endpoint_parse_failed") - - _LOGGER.debug( - f"Connecting to: [protocol: {protocol}, base_url: {base_url}, port: {port}, endpoint: {endpoint}]") - - if not await self._handshake(base_url=base_url, port=port, protocol=protocol, endpoint=endpoint, header=header): - _LOGGER.error("Could not connect to Custom OpenAI server.") - raise ServiceValidationError("handshake_failed") - - async def anthropic(self): - self._validate_provider() - if not await self._validate_api_key(self.user_input[CONF_ANTHROPIC_API_KEY]): - _LOGGER.error("Could not connect to Anthropic server.") - raise ServiceValidationError("handshake_failed") - - async def google(self): - self._validate_provider() - if not await self._validate_api_key(self.user_input[CONF_GOOGLE_API_KEY]): - _LOGGER.error("Could not connect to Google server.") - raise ServiceValidationError("handshake_failed") - - async def groq(self): - self._validate_provider() - if not await self._validate_api_key(self.user_input[CONF_GROQ_API_KEY]): - _LOGGER.error("Could not connect to Groq server.") - raise ServiceValidationError("handshake_failed") - - async def semantic_index(self) -> bool: - # check if semantic_index is already configured - for uid in self.hass.data[DOMAIN]: - if 'retention_time' in self.hass.data[DOMAIN][uid]: - return False - return True - - def get_configured_providers(self): - providers = [] - try: - if self.hass.data[DOMAIN] is None: - return providers - except KeyError: - return providers - if CONF_OPENAI_API_KEY in self.hass.data[DOMAIN]: - providers.append("OpenAI") - if CONF_ANTHROPIC_API_KEY in self.hass.data[DOMAIN]: - providers.append("Anthropic") - if CONF_GOOGLE_API_KEY in self.hass.data[DOMAIN]: - providers.append("Google") - if CONF_LOCALAI_IP_ADDRESS in self.hass.data[DOMAIN] and CONF_LOCALAI_PORT in self.hass.data[DOMAIN]: - providers.append("LocalAI") - if CONF_OLLAMA_IP_ADDRESS in self.hass.data[DOMAIN] and 
CONF_OLLAMA_PORT in self.hass.data[DOMAIN]: - providers.append("Ollama") - if CONF_CUSTOM_OPENAI_ENDPOINT in self.hass.data[DOMAIN]: - providers.append("Custom OpenAI") - if CONF_GROQ_API_KEY in self.hass.data[DOMAIN]: - providers.append("Groq") - return providers - - class llmvisionConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 2 async def handle_provider(self, provider): provider_steps = { - "Event Calendar": self.async_step_semantic_index, - "OpenAI": self.async_step_openai, "Anthropic": self.async_step_anthropic, + "Azure": self.async_step_azure, + "Custom OpenAI": self.async_step_custom_openai, + "Event Calendar": self.async_step_semantic_index, "Google": self.async_step_google, "Groq": self.async_step_groq, - "Ollama": self.async_step_ollama, "LocalAI": self.async_step_localai, - "Custom OpenAI": self.async_step_custom_openai, + "Ollama": self.async_step_ollama, + "OpenAI": self.async_step_openai, } step_method = provider_steps.get(provider) @@ -236,7 +65,7 @@ async def async_step_user(self, user_input=None): data_schema = vol.Schema({ vol.Required("provider", default="Event Calendar"): selector({ "select": { - "options": ["Event Calendar", "OpenAI", "Anthropic", "Google", "Groq", "Ollama", "LocalAI", "Custom OpenAI"], + "options": ["Anthropic", "Azure", "Google", "Groq", "LocalAI", "Ollama", "OpenAI", "Custom OpenAI", "Event Calendar"], "mode": "dropdown", "sort": False, "custom_value": False @@ -265,9 +94,13 @@ async def async_step_localai(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.localai() + localai = LocalAI(self.hass, endpoint={ + 'ip_address': user_input[CONF_LOCALAI_IP_ADDRESS], + 'port': user_input[CONF_LOCALAI_PORT], + 'https': user_input[CONF_LOCALAI_HTTPS] + }) + await localai.validate() # add the mode to user_input return self.async_create_entry(title=f"LocalAI ({user_input[CONF_LOCALAI_IP_ADDRESS]})", data=user_input) except ServiceValidationError as e: @@ -293,9 +126,13 @@ async def async_step_ollama(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.ollama() + ollama = Ollama(self.hass, endpoint={ + 'ip_address': user_input[CONF_OLLAMA_IP_ADDRESS], + 'port': user_input[CONF_OLLAMA_PORT], + 'https': user_input[CONF_OLLAMA_HTTPS] + }) + await ollama.validate() # add the mode to user_input return self.async_create_entry(title=f"Ollama ({user_input[CONF_OLLAMA_IP_ADDRESS]})", data=user_input) except ServiceValidationError as e: @@ -319,9 +156,10 @@ async def async_step_openai(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.openai() + openai = OpenAI( + self.hass, api_key=user_input[CONF_OPENAI_API_KEY]) + await openai.validate() # add the mode to user_input user_input["provider"] = self.init_info["provider"] return self.async_create_entry(title="OpenAI", data=user_input) @@ -338,6 +176,41 @@ async def async_step_openai(self, user_input=None): data_schema=data_schema, ) + async def async_step_azure(self, user_input=None): + data_schema = vol.Schema({ + vol.Required(CONF_AZURE_API_KEY): str, + vol.Required(CONF_AZURE_BASE_URL, default="https://domain.openai.azure.com/"): str, + 
vol.Required(CONF_AZURE_DEPLOYMENT, default="deployment"): str, + vol.Required(CONF_AZURE_VERSION, default="2024-10-01-preview"): str, + }) + + if user_input is not None: + # save provider to user_input + user_input["provider"] = self.init_info["provider"] + try: + azure = AzureOpenAI(self.hass, api_key=user_input[CONF_AZURE_API_KEY], endpoint={ + 'base_url': ENDPOINT_AZURE, + 'endpoint': user_input[CONF_AZURE_BASE_URL], + 'deployment': user_input[CONF_AZURE_DEPLOYMENT], + 'api_version': user_input[CONF_AZURE_VERSION] + }) + await azure.validate() + # add the mode to user_input + user_input["provider"] = self.init_info["provider"] + return self.async_create_entry(title="Azure", data=user_input) + except ServiceValidationError as e: + _LOGGER.error(f"Validation failed: {e}") + return self.async_show_form( + step_id="azure", + data_schema=data_schema, + errors={"base": "handshake_failed"} + ) + + return self.async_show_form( + step_id="azure", + data_schema=data_schema, + ) + async def async_step_anthropic(self, user_input=None): data_schema = vol.Schema({ vol.Required(CONF_ANTHROPIC_API_KEY): str, @@ -346,9 +219,10 @@ async def async_step_anthropic(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.anthropic() + anthropic = Anthropic( + self.hass, api_key=user_input[CONF_ANTHROPIC_API_KEY]) + await anthropic.validate() # add the mode to user_input user_input["provider"] = self.init_info["provider"] return self.async_create_entry(title="Anthropic Claude", data=user_input) @@ -373,9 +247,10 @@ async def async_step_google(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.google() + google = Google( + self.hass, api_key=user_input[CONF_GOOGLE_API_KEY]) + await google.validate() # add the mode to user_input user_input["provider"] = self.init_info["provider"] return self.async_create_entry(title="Google Gemini", data=user_input) @@ -400,9 +275,9 @@ async def async_step_groq(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.groq() + groq = Groq(self.hass, api_key=user_input[CONF_GROQ_API_KEY]) + await groq.validate() # add the mode to user_input user_input["provider"] = self.init_info["provider"] return self.async_create_entry(title="Groq", data=user_input) @@ -428,9 +303,11 @@ async def async_step_custom_openai(self, user_input=None): if user_input is not None: # save provider to user_input user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) try: - await validator.custom_openai() + custom_openai = OpenAI(self.hass, api_key=user_input[CONF_CUSTOM_OPENAI_API_KEY], endpoint={ + 'base_url': user_input[CONF_CUSTOM_OPENAI_ENDPOINT] + }) + await custom_openai.validate() # add the mode to user_input user_input["provider"] = self.init_info["provider"] return self.async_create_entry(title="Custom OpenAI compatible Provider", data=user_input) @@ -453,19 +330,12 @@ async def async_step_semantic_index(self, user_input=None): }) if user_input is not None: user_input["provider"] = self.init_info["provider"] - validator = Validator(self.hass, user_input) - try: - if not await validator.semantic_index(): - return 
self.async_abort(reason="already_configured") + + for uid in self.hass.data[DOMAIN]: + if 'retention_time' in self.hass.data[DOMAIN][uid]: + return self.async_abort(reason="already_configured") # add the mode to user_input - return self.async_create_entry(title="LLM Vision Events", data=user_input) - except ServiceValidationError as e: - _LOGGER.error(f"Validation failed: {e}") - return self.async_show_form( - step_id="semantic_index", - data_schema=data_schema, - errors={"base": "handshake_failed"} - ) + return self.async_create_entry(title="LLM Vision Events", data=user_input) return self.async_show_form( step_id="semantic_index", diff --git a/custom_components/llmvision/const.py b/custom_components/llmvision/const.py index 6e0a61f..2cc05bf 100644 --- a/custom_components/llmvision/const.py +++ b/custom_components/llmvision/const.py @@ -5,6 +5,10 @@ # Configuration values from setup CONF_OPENAI_API_KEY = 'openai_api_key' +CONF_AZURE_API_KEY = 'azure_api_key' +CONF_AZURE_BASE_URL = 'azure_base_url' +CONF_AZURE_DEPLOYMENT = 'azure_deployment' +CONF_AZURE_VERSION = 'azure_version' CONF_ANTHROPIC_API_KEY = 'anthropic_api_key' CONF_GOOGLE_API_KEY = 'google_api_key' CONF_GROQ_API_KEY = 'groq_api_key' @@ -34,22 +38,16 @@ FRIGATE_RETRY_ATTEMPTS = 'frigate_retry_attempts' FRIGATE_RETRY_SECONDS = 'frigate_retry_seconds' MAX_FRAMES = 'max_frames' -DETAIL = 'detail' TEMPERATURE = 'temperature' INCLUDE_FILENAME = 'include_filename' EXPOSE_IMAGES = 'expose_images' EXPOSE_IMAGES_PERSIST = 'expose_images_persist' +GENERATE_TITLE = 'generate_title' SENSOR_ENTITY = 'sensor_entity' # Error messages -ERROR_OPENAI_NOT_CONFIGURED = "OpenAI is not configured" -ERROR_ANTHROPIC_NOT_CONFIGURED = "Anthropic is not configured" -ERROR_GOOGLE_NOT_CONFIGURED = "Google is not configured" -ERROR_GROQ_NOT_CONFIGURED = "Groq is not configured" +ERROR_NOT_CONFIGURED = "{provider} is not configured" ERROR_GROQ_MULTIPLE_IMAGES = "Groq does not support videos or streams" -ERROR_LOCALAI_NOT_CONFIGURED = "LocalAI is not configured" -ERROR_OLLAMA_NOT_CONFIGURED = "Ollama is not configured" -ERROR_CUSTOM_OPENAI_NOT_CONFIGURED = "Custom OpenAI provider is not configured" ERROR_NO_IMAGE_INPUT = "No image input provided" ERROR_HANDSHAKE_FAILED = "Connection could not be established" @@ -63,4 +61,5 @@ ENDPOINT_GOOGLE = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={api_key}" ENDPOINT_GROQ = "https://api.groq.com/openai/v1/chat/completions" ENDPOINT_LOCALAI = "{protocol}://{ip_address}:{port}/v1/chat/completions" -ENDPOINT_OLLAMA = "{protocol}://{ip_address}:{port}/api/chat" \ No newline at end of file +ENDPOINT_OLLAMA = "{protocol}://{ip_address}:{port}/api/chat" +ENDPOINT_AZURE = "{base_url}openai/deployments/{deployment}/chat/completions?api-version={api_version}" diff --git a/custom_components/llmvision/manifest.json b/custom_components/llmvision/manifest.json index aca2eb5..f7001e4 100644 --- a/custom_components/llmvision/manifest.json +++ b/custom_components/llmvision/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://github.com/valentinfrlch/ha-llmvision", "iot_class": "cloud_polling", "issue_tracker": "https://github.com/valentinfrlch/ha-llmvision/issues", - "version": "1.3.1" + "version": "1.3.5" } \ No newline at end of file diff --git a/custom_components/llmvision/media_handlers.py b/custom_components/llmvision/media_handlers.py index f58f436..d71511e 100644 --- a/custom_components/llmvision/media_handlers.py +++ b/custom_components/llmvision/media_handlers.py @@ -5,6 +5,7 @@
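Note: the new ENDPOINT_AZURE template in const.py is concatenated with the configured base URL directly, so the base URL must keep its trailing slash (the config-flow default "https://domain.openai.azure.com/" does). A hypothetical expansion:

```python
ENDPOINT_AZURE = "{base_url}openai/deployments/{deployment}/chat/completions?api-version={api_version}"
url = ENDPOINT_AZURE.format(
    base_url="https://example.openai.azure.com/",  # trailing slash required
    deployment="gpt-4o",
    api_version="2024-10-01-preview",
)
# https://example.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-01-preview
```

import 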
logging import time import asyncio +from homeassistant.helpers.aiohttp_client import async_get_clientsession from functools import partial from PIL import Image, UnidentifiedImageError import numpy as np @@ -19,6 +20,7 @@ class MediaProcessor: def __init__(self, hass, client): self.hass = hass + self.session = async_get_clientsession(self.hass) self.client = client self.base64_images = [] self.filenames = [] @@ -135,6 +137,28 @@ async def resize_image(self, target_width, image_path=None, image_data=None, img base64_image = await self._encode_image(img) return base64_image + + async def _fetch(self, url, max_retries=2, retry_delay=1): + """Fetch image from url and return image data""" + retries = 0 + while retries < max_retries: + _LOGGER.info( + f"Fetching {url} (attempt {retries + 1}/{max_retries})") + try: + response = await self.session.get(url) + if response.status != 200: + _LOGGER.warning( + f"Couldn't fetch frame (status code: {response.status})") + retries += 1 + await asyncio.sleep(retry_delay) + continue + data = await response.read() + return data + except Exception as e: + _LOGGER.error(f"Fetch failed: {e}") + retries += 1 + await asyncio.sleep(retry_delay) + _LOGGER.warning(f"Failed to fetch {url} after {max_retries} retries") async def record(self, image_entities, duration, max_frames, target_width, include_filename, expose_images): """Wrapper for client.add_frame with integrated recorder @@ -162,7 +186,7 @@ async def record_camera(image_entity, camera_number): frame_url = base_url + \ self.hass.states.get(image_entity).attributes.get( 'entity_picture') - frame_data = await self.client._fetch(frame_url) + frame_data = await self._fetch(frame_url) # Skip frame if fetch failed if not frame_data: @@ -251,7 +275,7 @@ async def add_images(self, image_entities, image_paths, target_width, include_fi image_url = base_url + \ self.hass.states.get(image_entity).attributes.get( 'entity_picture') - image_data = await self.client._fetch(image_url) + image_data = await self._fetch(image_url) # Skip frame if fetch failed if not image_data: @@ -307,7 +331,8 @@ async def add_videos(self, video_paths, event_ids, max_frames, target_width, inc try: base_url = get_url(self.hass) frigate_url = base_url + "/api/frigate/notifications/" + event_id + "/clip.mp4" - clip_data = await self.client._fetch(frigate_url, max_retries=frigate_retry_attempts, retry_delay=frigate_retry_seconds) + + clip_data = await self._fetch(frigate_url, max_retries=frigate_retry_attempts, retry_delay=frigate_retry_seconds) if not clip_data: raise ServiceValidationError( diff --git a/custom_components/llmvision/providers.py b/custom_components/llmvision/providers.py new file mode 100644 index 0000000..3e939eb --- /dev/null +++ b/custom_components/llmvision/providers.py @@ -0,0 +1,733 @@ +from abc import ABC, abstractmethod +from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers.aiohttp_client import async_get_clientsession +import logging +import inspect +from .const import ( + DOMAIN, + CONF_OPENAI_API_KEY, + CONF_AZURE_API_KEY, + CONF_AZURE_BASE_URL, + CONF_AZURE_DEPLOYMENT, + CONF_AZURE_VERSION, + CONF_ANTHROPIC_API_KEY, + CONF_GOOGLE_API_KEY, + CONF_GROQ_API_KEY, + CONF_LOCALAI_IP_ADDRESS, + CONF_LOCALAI_PORT, + CONF_LOCALAI_HTTPS, + CONF_OLLAMA_IP_ADDRESS, + CONF_OLLAMA_PORT, + CONF_OLLAMA_HTTPS, + CONF_CUSTOM_OPENAI_ENDPOINT, + CONF_CUSTOM_OPENAI_API_KEY, + VERSION_ANTHROPIC, + ENDPOINT_OPENAI, + ENDPOINT_AZURE, + ENDPOINT_ANTHROPIC, + ENDPOINT_GOOGLE, + ENDPOINT_LOCALAI, + 
ENDPOINT_OLLAMA, + ENDPOINT_GROQ, + ERROR_NOT_CONFIGURED, + ERROR_GROQ_MULTIPLE_IMAGES, + ERROR_NO_IMAGE_INPUT, +) + +_LOGGER = logging.getLogger(__name__) + + +class Request: + def __init__(self, hass, message, max_tokens, temperature): + self.session = async_get_clientsession(hass) + self.hass = hass + self.message = message + self.max_tokens = max_tokens + self.temperature = temperature + self.base64_images = [] + self.filenames = [] + + @staticmethod + def sanitize_data(data): + """Remove long string data from request data to reduce log size""" + if isinstance(data, dict): + return {key: Request.sanitize_data(value) for key, value in data.items()} + elif isinstance(data, list): + return [Request.sanitize_data(item) for item in data] + elif isinstance(data, str) and len(data) > 400 and data.count(' ') < 50: + return '' + else: + return data + + @staticmethod + def get_provider(hass, provider_uid): + """Translate UID of the config entry into provider name""" + if DOMAIN not in hass.data: + return None + + entry_data = hass.data[DOMAIN].get(provider_uid) + if not entry_data: + return None + + if CONF_ANTHROPIC_API_KEY in entry_data: + return "Anthropic" + elif CONF_AZURE_API_KEY in entry_data: + return "Azure" + elif CONF_CUSTOM_OPENAI_API_KEY in entry_data: + return "Custom OpenAI" + elif CONF_GOOGLE_API_KEY in entry_data: + return "Google" + elif CONF_GROQ_API_KEY in entry_data: + return "Groq" + elif CONF_LOCALAI_IP_ADDRESS in entry_data: + return "LocalAI" + elif CONF_OLLAMA_IP_ADDRESS in entry_data: + return "Ollama" + elif CONF_OPENAI_API_KEY in entry_data: + return "OpenAI" + + return None + + def validate(self, call) -> None | ServiceValidationError: + """Validate call data""" + # Check image input + if not call.base64_images: + raise ServiceValidationError(ERROR_NO_IMAGE_INPUT) + # Check if single image is provided for Groq + if len(call.base64_images) > 1 and self.get_provider(self.hass, call.provider) == 'Groq': + raise ServiceValidationError(ERROR_GROQ_MULTIPLE_IMAGES) + # Check provider is configured + if not call.provider: + raise ServiceValidationError(ERROR_NOT_CONFIGURED) + + async def call(self, call): + """ + Forwards a request to the specified provider and optionally generates a title. + + Args: + call (object): The call object containing request details. + + Raises: + ServiceValidationError: If the provider is invalid. + + Returns: + dict: A dictionary containing the generated title (if any) and the response text. + """ + entry_id = call.provider + config = self.hass.data.get(DOMAIN).get(entry_id) + + provider = Request.get_provider(self.hass, entry_id) + call.base64_images = self.base64_images + call.filenames = self.filenames + + self.validate(call) + + gen_title_prompt = "Your job is to generate a title in the form '<object> seen' for texts. Do not mention the time, do not speculate. Generate a title for this text: {response}"
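+ # Example (hypothetical): for a response like "A courier leaves a package at the front door.", the title request below should come back as something like "Package seen".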
+ + if provider == 'OpenAI': + api_key = config.get(CONF_OPENAI_API_KEY) + provider_instance = OpenAI(hass=self.hass, api_key=api_key) + + elif provider == 'Azure': + api_key = config.get(CONF_AZURE_API_KEY) + endpoint = config.get(CONF_AZURE_BASE_URL) + deployment = config.get(CONF_AZURE_DEPLOYMENT) + version = config.get(CONF_AZURE_VERSION) + + provider_instance = AzureOpenAI(self.hass, + api_key=api_key, + endpoint={ + 'base_url': ENDPOINT_AZURE, + 'endpoint': endpoint, + 'deployment': deployment, + 'api_version': version + }) + + elif provider == 'Anthropic': + api_key = config.get(CONF_ANTHROPIC_API_KEY) + + provider_instance = Anthropic(self.hass, api_key=api_key) + + elif provider == 'Google': + api_key = config.get(CONF_GOOGLE_API_KEY) + + provider_instance = Google(self.hass, api_key=api_key, endpoint={ + 'base_url': ENDPOINT_GOOGLE, 'model': call.model + }) + + elif provider == 'Groq': + api_key = config.get(CONF_GROQ_API_KEY) + + provider_instance = Groq(self.hass, api_key=api_key) + + elif provider == 'LocalAI': + ip_address = config.get(CONF_LOCALAI_IP_ADDRESS) + port = config.get(CONF_LOCALAI_PORT) + https = config.get(CONF_LOCALAI_HTTPS, False) + + provider_instance = LocalAI(self.hass, endpoint={ + 'ip_address': ip_address, + 'port': port, + 'https': https + }) + + elif provider == 'Ollama': + ip_address = config.get(CONF_OLLAMA_IP_ADDRESS) + port = config.get(CONF_OLLAMA_PORT) + https = config.get(CONF_OLLAMA_HTTPS, False) + + provider_instance = Ollama(self.hass, endpoint={ + 'ip_address': ip_address, + 'port': port, + 'https': https + }) + + elif provider == 'Custom OpenAI': + api_key = config.get(CONF_CUSTOM_OPENAI_API_KEY, "") + endpoint = config.get( + CONF_CUSTOM_OPENAI_ENDPOINT) + "/v1/chat/completions" + provider_instance = OpenAI( + self.hass, api_key=api_key, endpoint={'base_url': endpoint}) + + else: + raise ServiceValidationError("invalid_provider") + + # Make call to provider + call.model = call.model if call.model and call.model != 'None' else provider_instance.default_model + response_text = await provider_instance.vision_request(call) + + if call.generate_title: + call.message = gen_title_prompt.format(response=response_text) + gen_title = await provider_instance.title_request(call) + + return {"title": gen_title.replace(".", "").replace("'", ""), "response_text": response_text} + else: + return {"response_text": response_text} + + def add_frame(self, base64_image, filename): + self.base64_images.append(base64_image) + self.filenames.append(filename) + + async def _resolve_error(self, response, provider): + """Translate response status to error message""" + import json + full_response_text = await response.text() + _LOGGER.info(f"[INFO] Full Response: {full_response_text}") + + try: + response_json = json.loads(full_response_text) + if provider == 'anthropic': + error_info = response_json.get('error', {}) + error_message = f"{error_info.get('type', 'Unknown error')}: {error_info.get('message', 'Unknown error')}" + elif provider == 'ollama': + error_message = response_json.get('error', 'Unknown error') + else: + error_info = response_json.get('error', {}) + error_message = error_info.get('message', 'Unknown error') + except json.JSONDecodeError: + error_message = 'Unknown error' + + return error_message + +
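Note: every concrete provider below implements the same four hooks of the Provider base class (_make_request, _prepare_vision_data, _prepare_text_data, validate). A minimal sketch of a provider written against that contract (hypothetical test double, not part of this patch):

```python
class EchoProvider(Provider):
    """Test double: returns the prompt text instead of calling a remote API."""

    def __init__(self, hass):
        super().__init__(hass)
        self.default_model = "echo"

    async def _make_request(self, data) -> str:
        # data is whatever _prepare_*_data produced; the prompt is the last text part
        return data["messages"][0]["content"][-1]["text"]

    def _prepare_vision_data(self, call) -> dict:
        return {"messages": [{"role": "user", "content": [
            {"type": "text", "text": call.message}]}]}

    def _prepare_text_data(self, call) -> dict:
        return self._prepare_vision_data(call)

    async def validate(self) -> None:
        return None  # nothing to hand-shake with
```

+class 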
Provider(ABC): + """ + Abstract base class for providers + + Args: + hass (object): Home Assistant instance + api_key (str, optional): API key for the provider, defaults to "" + endpoint (dict, optional): Endpoint configuration for the provider + """ + + def __init__(self, + hass, + api_key="", + endpoint={ + 'base_url': "", + 'deployment': "", + 'api_version': "", + 'ip_address': "", + 'port': "", + 'https': False + } + ): + self.hass = hass + self.session = async_get_clientsession(hass) + self.api_key = api_key + self.endpoint = endpoint + + @abstractmethod + async def _make_request(self, data) -> str: + pass + + @abstractmethod + def _prepare_vision_data(self, call) -> dict: + pass + + @abstractmethod + def _prepare_text_data(self, call) -> dict: + pass + + @abstractmethod + async def validate(self) -> None | ServiceValidationError: + pass + + async def vision_request(self, call) -> str: + data = self._prepare_vision_data(call) + return await self._make_request(data) + + async def title_request(self, call) -> str: + call.temperature = 0.1 + call.max_tokens = 3 + data = self._prepare_text_data(call) + return await self._make_request(data) + + async def _post(self, url, headers, data) -> dict: + """Post data to url and return response data""" + _LOGGER.info(f"Request data: {Request.sanitize_data(data)}") + + try: + _LOGGER.info(f"Posting to {url}") + response = await self.session.post(url, headers=headers, json=data) + except Exception as e: + raise ServiceValidationError(f"Request failed: {e}") + + if response.status != 200: + frame = inspect.stack()[1] + provider = frame.frame.f_locals["self"].__class__.__name__.lower() + parsed_response = await self._resolve_error(response, provider) + raise ServiceValidationError(parsed_response) + else: + response_data = await response.json() + _LOGGER.info(f"Response data: {response_data}") + return response_data + + async def _resolve_error(self, response, provider) -> str: + """Translate response status to error message""" + import json + full_response_text = await response.text() + _LOGGER.info(f"[INFO] Full Response: {full_response_text}") + + try: + response_json = json.loads(full_response_text) + if provider == 'anthropic': + error_info = response_json.get('error', {}) + error_message = f"{error_info.get('type', 'Unknown error')}: {error_info.get('message', 'Unknown error')}" + elif provider == 'ollama': + error_message = response_json.get('error', 'Unknown error') + else: + error_info = response_json.get('error', {}) + error_message = error_info.get('message', 'Unknown error') + except json.JSONDecodeError: + error_message = 'Unknown error' + + return error_message + + +class OpenAI(Provider): + def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_OPENAI}): + super().__init__(hass, api_key, endpoint=endpoint) + self.default_model = "gpt-4o-mini" + + def _generate_headers(self) -> dict: + return {'Content-type': 'application/json', + 'Authorization': 'Bearer ' + self.api_key} + + async def _make_request(self, data) -> str: + headers = self._generate_headers() + response = await self._post(url=self.endpoint.get('base_url'), headers=headers, data=data) + response_text = response.get( + "choices")[0].get("message").get("content") + return response_text +
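Note: for two images plus a prompt, the _prepare_vision_data method below builds an OpenAI-style content array of this shape (illustrative values only):

```python
{
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": [
        {"type": "text", "text": "Image 1:"},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<b64>"}},
        {"type": "text", "text": "Image 2:"},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<b64>"}},
        {"type": "text", "text": "Describe what happens"},
    ]}],
    "max_tokens": 100,
    "temperature": 0.3,
}
```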
("Image " + str(call.base64_images.index(image) + 1) + ) if filename == "" else filename + payload["messages"][0]["content"].append( + {"type": "text", "text": tag + ":"}) + payload["messages"][0]["content"].append({"type": "image_url", "image_url": { + "url": f"data:image/jpeg;base64,{image}"}}) + payload["messages"][0]["content"].append( + {"type": "text", "text": call.message}) + return payload + + def _prepare_text_data(self, call) -> list: + return { + "model": call.model, + "messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}], + "max_tokens": call.max_tokens, + "temperature": call.temperature + } + + async def validate(self) -> None | ServiceValidationError: + if self.api_key: + headers = self._generate_headers() + data = { + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}], + "max_tokens": 1, + "temperature": 0.5 + } + await self._post(url=self.endpoint.get('base_url'), headers=headers, data=data) + else: + raise ServiceValidationError("empty_api_key") + + +class AzureOpenAI(Provider): + def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_AZURE, 'endpoint': "", 'deployment': "", 'api_version': ""}): + super().__init__(hass, api_key, endpoint) + self.default_model = "gpt-4o-mini" + + def _generate_headers(self) -> dict: + return {'Content-type': 'application/json', + 'api-key': self.api_key} + + async def _make_request(self, data) -> str: + headers = self._generate_headers() + endpoint = self.endpoint.get("base_url").format( + base_url=self.endpoint.get("endpoint"), + deployment=self.endpoint.get("deployment"), + api_version=self.endpoint.get("api_version") + ) + + response = await self._post(url=endpoint, headers=headers, data=data) + response_text = response.get( + "choices")[0].get("message").get("content") + return response_text + + def _prepare_vision_data(self, call) -> list: + payload = {"messages": [{"role": "user", "content": []}], + "max_tokens": call.max_tokens, + "temperature": call.temperature, + "stream": False + } + for image, filename in zip(call.base64_images, call.filenames): + tag = ("Image " + str(call.base64_images.index(image) + 1) + ) if filename == "" else filename + payload["messages"][0]["content"].append( + {"type": "text", "text": tag + ":"}) + payload["messages"][0]["content"].append({"type": "image_url", "image_url": { + "url": f"data:image/jpeg;base64,{image}"}}) + payload["messages"][0]["content"].append( + {"type": "text", "text": call.message}) + return payload + + def _prepare_text_data(self, call) -> list: + return {"messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}], + "max_tokens": call.max_tokens, + "temperature": call.temperature, + "stream": False + } + + async def validate(self) -> None | ServiceValidationError: + if not self.api_key: + raise ServiceValidationError("empty_api_key") + + endpoint = self.endpoint.get("base_url").format( + base_url=self.endpoint.get("endpoint"), + deployment=self.endpoint.get("deployment"), + api_version=self.endpoint.get("api_version") + ) + headers = self._generate_headers() + data = {"messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}], + "max_tokens": 1, + "temperature": 0.5, + "stream": False + } + await self._post(url=endpoint, headers=headers, data=data) + + +class Anthropic(Provider): + def __init__(self, hass, api_key=""): + super().__init__(hass, api_key) + self.default_model = "claude-3-5-sonnet-latest" + + def _generate_headers(self) -> dict: + 
+class AzureOpenAI(Provider):
+    def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_AZURE, 'endpoint': "", 'deployment': "", 'api_version': ""}):
+        super().__init__(hass, api_key, endpoint)
+        self.default_model = "gpt-4o-mini"
+
+    def _generate_headers(self) -> dict:
+        return {'Content-type': 'application/json',
+                'api-key': self.api_key}
+
+    async def _make_request(self, data) -> str:
+        headers = self._generate_headers()
+        endpoint = self.endpoint.get("base_url").format(
+            base_url=self.endpoint.get("endpoint"),
+            deployment=self.endpoint.get("deployment"),
+            api_version=self.endpoint.get("api_version")
+        )
+
+        response = await self._post(url=endpoint, headers=headers, data=data)
+        response_text = response.get(
+            "choices")[0].get("message").get("content")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        payload = {"messages": [{"role": "user", "content": []}],
+                   "max_tokens": call.max_tokens,
+                   "temperature": call.temperature,
+                   "stream": False
+                   }
+        for image, filename in zip(call.base64_images, call.filenames):
+            tag = ("Image " + str(call.base64_images.index(image) + 1)
+                   ) if filename == "" else filename
+            payload["messages"][0]["content"].append(
+                {"type": "text", "text": tag + ":"})
+            payload["messages"][0]["content"].append({"type": "image_url", "image_url": {
+                "url": f"data:image/jpeg;base64,{image}"}})
+        payload["messages"][0]["content"].append(
+            {"type": "text", "text": call.message})
+        return payload
+
+    def _prepare_text_data(self, call) -> dict:
+        return {"messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}],
+                "max_tokens": call.max_tokens,
+                "temperature": call.temperature,
+                "stream": False
+                }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.api_key:
+            raise ServiceValidationError("empty_api_key")
+
+        endpoint = self.endpoint.get("base_url").format(
+            base_url=self.endpoint.get("endpoint"),
+            deployment=self.endpoint.get("deployment"),
+            api_version=self.endpoint.get("api_version")
+        )
+        headers = self._generate_headers()
+        data = {"messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],
+                "max_tokens": 1,
+                "temperature": 0.5,
+                "stream": False
+                }
+        await self._post(url=endpoint, headers=headers, data=data)
+
+
+class Anthropic(Provider):
+    def __init__(self, hass, api_key=""):
+        super().__init__(hass, api_key)
+        self.default_model = "claude-3-5-sonnet-latest"
+
+    def _generate_headers(self) -> dict:
+        return {
+            'content-type': 'application/json',
+            'x-api-key': self.api_key,
+            'anthropic-version': VERSION_ANTHROPIC
+        }
+
+    async def _make_request(self, data) -> str:
+        headers = self._generate_headers()
+        response = await self._post(url=ENDPOINT_ANTHROPIC, headers=headers, data=data)
+        response_text = response.get("content")[0].get("text")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        data = {
+            "model": call.model,
+            "messages": [{"role": "user", "content": []}],
+            "max_tokens": call.max_tokens,
+            "temperature": call.temperature
+        }
+        for image, filename in zip(call.base64_images, call.filenames):
+            tag = ("Image " + str(call.base64_images.index(image) + 1)
+                   ) if filename == "" else filename
+            data["messages"][0]["content"].append(
+                {"type": "text", "text": tag + ":"})
+            data["messages"][0]["content"].append({"type": "image", "source": {
+                "type": "base64", "media_type": "image/jpeg", "data": f"{image}"}})
+        data["messages"][0]["content"].append(
+            {"type": "text", "text": call.message})
+        return data
+
+    def _prepare_text_data(self, call) -> dict:
+        return {
+            "model": call.model,
+            "messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}],
+            "max_tokens": call.max_tokens,
+            "temperature": call.temperature
+        }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.api_key:
+            raise ServiceValidationError("empty_api_key")
+
+        header = self._generate_headers()
+        payload = {
+            "model": "claude-3-haiku-20240307",
+            "messages": [
+                {"role": "user", "content": "Hi"}
+            ],
+            "max_tokens": 1,
+            "temperature": 0.5
+        }
+        await self._post(url=ENDPOINT_ANTHROPIC, headers=header, data=payload)
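+
+
+# Google's Gemini API expects the model name and API key in the URL rather
+# than in headers, so _generate_headers only sets the content type.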
+class Google(Provider):
+    def __init__(self, hass, api_key="", endpoint={'base_url': ENDPOINT_GOOGLE, 'model': "gemini-1.5-flash-latest"}):
+        super().__init__(hass, api_key, endpoint)
+        self.default_model = "gemini-1.5-flash-latest"
+
+    def _generate_headers(self) -> dict:
+        return {'content-type': 'application/json'}
+
+    async def _make_request(self, data) -> str:
+        endpoint = self.endpoint.get('base_url').format(
+            model=self.endpoint.get('model'), api_key=self.api_key)
+
+        headers = self._generate_headers()
+        response = await self._post(url=endpoint, headers=headers, data=data)
+        response_text = response.get("candidates")[0].get(
+            "content").get("parts")[0].get("text")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        data = {"contents": [], "generationConfig": {
+            "maxOutputTokens": call.max_tokens, "temperature": call.temperature}}
+        for image, filename in zip(call.base64_images, call.filenames):
+            tag = ("Image " + str(call.base64_images.index(image) + 1)
+                   ) if filename == "" else filename
+            data["contents"].append({"role": "user", "parts": [
+                {"text": tag + ":"}, {"inline_data": {"mime_type": "image/jpeg", "data": image}}]})
+        data["contents"].append(
+            {"role": "user", "parts": [{"text": call.message}]})
+        return data
+
+    def _prepare_text_data(self, call) -> dict:
+        return {
+            "contents": [{"role": "user", "parts": [{"text": call.message + ":"}]}],
+            "generationConfig": {"maxOutputTokens": call.max_tokens, "temperature": call.temperature}
+        }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.api_key:
+            raise ServiceValidationError("empty_api_key")
+
+        headers = self._generate_headers()
+        data = {
+            "contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
+            "generationConfig": {"maxOutputTokens": 1, "temperature": 0.5}
+        }
+        await self._post(url=self.endpoint.get('base_url').format(model=self.endpoint.get('model'), api_key=self.api_key), headers=headers, data=data)
+
+
+class Groq(Provider):
+    def __init__(self, hass, api_key=""):
+        super().__init__(hass, api_key)
+        self.default_model = "llama-3.2-11b-vision-preview"
+
+    def _generate_headers(self) -> dict:
+        return {'Content-type': 'application/json', 'Authorization': 'Bearer ' + self.api_key}
+
+    async def _make_request(self, data) -> str:
+        headers = self._generate_headers()
+        response = await self._post(url=ENDPOINT_GROQ, headers=headers, data=data)
+        response_text = response.get(
+            "choices")[0].get("message").get("content")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        # Groq's vision endpoint accepts a single image, so only the first
+        # frame is sent.
+        first_image = call.base64_images[0]
+        data = {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": call.message},
+                        {"type": "image_url", "image_url": {
+                            "url": f"data:image/jpeg;base64,{first_image}"}}
+                    ]
+                }
+            ],
+            "model": call.model
+        }
+        return data
+
+    def _prepare_text_data(self, call) -> dict:
+        return {
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": call.message}
+                    ]
+                }
+            ],
+            "model": call.model
+        }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.api_key:
+            raise ServiceValidationError("empty_api_key")
+        headers = self._generate_headers()
+        data = {
+            "model": "llama3-8b-8192",
+            "messages": [{
+                "role": "user",
+                "content": "Hi"
+            }]
+        }
+        await self._post(url=ENDPOINT_GROQ, headers=headers, data=data)
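+
+
+# LocalAI and Ollama are self-hosted: there is no API key, the endpoint is
+# built from ip_address, port and the https flag, and validation is a plain
+# reachability check rather than a billed test request.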
+class LocalAI(Provider):
+    def __init__(self, hass, api_key="", endpoint={'ip_address': "", 'port': "", 'https': False}):
+        super().__init__(hass, api_key, endpoint)
+        self.default_model = "gpt-4-vision-preview"
+
+    async def _make_request(self, data) -> str:
+        endpoint = ENDPOINT_LOCALAI.format(
+            protocol="https" if self.endpoint.get("https") else "http",
+            ip_address=self.endpoint.get("ip_address"),
+            port=self.endpoint.get("port")
+        )
+
+        headers = {}
+        response = await self._post(url=endpoint, headers=headers, data=data)
+        response_text = response.get(
+            "choices")[0].get("message").get("content")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        data = {"model": call.model, "messages": [{"role": "user", "content": [
+        ]}], "max_tokens": call.max_tokens, "temperature": call.temperature}
+        for image, filename in zip(call.base64_images, call.filenames):
+            tag = ("Image " + str(call.base64_images.index(image) + 1)
+                   ) if filename == "" else filename
+            data["messages"][0]["content"].append(
+                {"type": "text", "text": tag + ":"})
+            data["messages"][0]["content"].append(
+                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}})
+        data["messages"][0]["content"].append(
+            {"type": "text", "text": call.message})
+        return data
+
+    def _prepare_text_data(self, call) -> dict:
+        return {
+            "model": call.model,
+            "messages": [{"role": "user", "content": [{"type": "text", "text": call.message}]}],
+            "max_tokens": call.max_tokens,
+            "temperature": call.temperature
+        }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.endpoint.get("ip_address") or not self.endpoint.get("port"):
+            raise ServiceValidationError('handshake_failed')
+        session = async_get_clientsession(self.hass)
+        ip_address = self.endpoint.get("ip_address")
+        port = self.endpoint.get("port")
+        protocol = "https" if self.endpoint.get("https") else "http"
+
+        try:
+            response = await session.get(f"{protocol}://{ip_address}:{port}/readyz")
+            if response.status != 200:
+                raise ServiceValidationError('handshake_failed')
+        except Exception:
+            raise ServiceValidationError('handshake_failed')
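+
+
+# Ollama's chat API nests the sampling parameters under "options"
+# (num_predict/temperature) and returns a single "message" object instead
+# of a "choices" list.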
+class Ollama(Provider):
+    def __init__(self, hass, api_key="", endpoint={'ip_address': "0.0.0.0", 'port': "11434", 'https': False}):
+        super().__init__(hass, api_key, endpoint)
+        self.default_model = "minicpm-v"
+
+    async def _make_request(self, data) -> str:
+        https = self.endpoint.get("https")
+        ip_address = self.endpoint.get("ip_address")
+        port = self.endpoint.get("port")
+        protocol = "https" if https else "http"
+        endpoint = ENDPOINT_OLLAMA.format(
+            ip_address=ip_address,
+            port=port,
+            protocol=protocol
+        )
+
+        response = await self._post(url=endpoint, headers={}, data=data)
+        response_text = response.get("message").get("content")
+        return response_text
+
+    def _prepare_vision_data(self, call) -> dict:
+        data = {"model": call.model, "messages": [], "stream": False, "options": {
+            "num_predict": call.max_tokens, "temperature": call.temperature}}
+        for image, filename in zip(call.base64_images, call.filenames):
+            tag = ("Image " + str(call.base64_images.index(image) + 1)
+                   ) if filename == "" else filename
+            image_message = {"role": "user",
+                             "content": tag + ":", "images": [image]}
+            data["messages"].append(image_message)
+        prompt_message = {"role": "user", "content": call.message}
+        data["messages"].append(prompt_message)
+        return data
+
+    def _prepare_text_data(self, call) -> dict:
+        return {
+            "model": call.model,
+            "messages": [{"role": "user", "content": call.message}],
+            "stream": False,
+            "options": {"num_predict": call.max_tokens, "temperature": call.temperature}
+        }
+
+    async def validate(self) -> None | ServiceValidationError:
+        if not self.endpoint.get("ip_address") or not self.endpoint.get("port"):
+            raise ServiceValidationError('handshake_failed')
+        session = async_get_clientsession(self.hass)
+        ip_address = self.endpoint.get("ip_address")
+        port = self.endpoint.get("port")
+        protocol = "https" if self.endpoint.get("https") else "http"
+
+        try:
+            _LOGGER.info(
+                f"Checking connection to {protocol}://{ip_address}:{port}")
+            response = await session.get(f"{protocol}://{ip_address}:{port}/api/tags", headers={})
+            _LOGGER.info(f"Response: {response}")
+            if response.status != 200:
+                raise ServiceValidationError('handshake_failed')
+        except Exception as e:
+            _LOGGER.error(f"Error: {e}")
+            raise ServiceValidationError('handshake_failed')
diff --git a/custom_components/llmvision/request_handlers.py b/custom_components/llmvision/request_handlers.py
deleted file mode 100644
index 94d9b39..0000000
--- a/custom_components/llmvision/request_handlers.py
+++ /dev/null
@@ -1,502 +0,0 @@
-from homeassistant.exceptions import ServiceValidationError
-from homeassistant.helpers.aiohttp_client import async_get_clientsession
-import logging
-import asyncio
-import inspect
-from .const import (
-    DOMAIN,
-    CONF_OPENAI_API_KEY,
-    CONF_ANTHROPIC_API_KEY,
-    CONF_GOOGLE_API_KEY,
-    CONF_GROQ_API_KEY,
-    CONF_LOCALAI_IP_ADDRESS,
-    CONF_LOCALAI_PORT,
-    CONF_LOCALAI_HTTPS,
-    CONF_OLLAMA_IP_ADDRESS,
-    CONF_OLLAMA_PORT,
-    CONF_OLLAMA_HTTPS,
-    CONF_CUSTOM_OPENAI_ENDPOINT,
-    CONF_CUSTOM_OPENAI_API_KEY,
-    VERSION_ANTHROPIC,
-    ENDPOINT_OPENAI,
-    ENDPOINT_ANTHROPIC,
-    ENDPOINT_GOOGLE,
-    ENDPOINT_LOCALAI,
-    ENDPOINT_OLLAMA,
-    ENDPOINT_GROQ,
-    ERROR_OPENAI_NOT_CONFIGURED,
-    ERROR_ANTHROPIC_NOT_CONFIGURED,
-    ERROR_GOOGLE_NOT_CONFIGURED,
-    ERROR_GROQ_NOT_CONFIGURED,
-    ERROR_GROQ_MULTIPLE_IMAGES,
-    ERROR_LOCALAI_NOT_CONFIGURED,
-    ERROR_OLLAMA_NOT_CONFIGURED,
-    ERROR_NO_IMAGE_INPUT
-)
-
-_LOGGER = logging.getLogger(__name__)
-
-
-def sanitize_data(data):
-    """Remove long string data from request data to reduce log size"""
-    if isinstance(data, dict):
-        return {key: sanitize_data(value) for key, value in data.items()}
-    elif isinstance(data, list):
-        return [sanitize_data(item) for item in data]
-    elif isinstance(data, str) and len(data) > 400 and data.count(' ') < 50:
-        return ''
-    else:
-        return data
-
-
-def get_provider(hass, provider_uid):
-    """Translate UID of the config entry into provider name"""
-    if DOMAIN not in hass.data:
-        return None
-
-    entry_data = hass.data[DOMAIN].get(provider_uid)
-    if not entry_data:
-        return None
-
-    if CONF_OPENAI_API_KEY in entry_data:
-        return "OpenAI"
-    elif CONF_ANTHROPIC_API_KEY in entry_data:
-        return "Anthropic"
-    elif CONF_GOOGLE_API_KEY in entry_data:
-        return "Google"
-    elif CONF_GROQ_API_KEY in entry_data:
-        return "Groq"
-    elif CONF_LOCALAI_IP_ADDRESS in entry_data:
-        return "LocalAI"
-    elif CONF_OLLAMA_IP_ADDRESS in entry_data:
-        return "Ollama"
-    elif CONF_CUSTOM_OPENAI_API_KEY in entry_data:
-        return "Custom OpenAI"
-
-    return None
-
-
-def default_model(provider): return {
-    "OpenAI": "gpt-4o-mini",
-    "Anthropic": "claude-3-5-sonnet-latest",
-    "Google": "gemini-1.5-flash-latest",
-    "Groq": "llava-v1.5-7b-4096-preview",
-    "LocalAI": "gpt-4-vision-preview",
-    "Ollama": "llava-phi3:latest",
-    "Custom OpenAI": "gpt-4o-mini"
-}.get(provider, "gpt-4o-mini")  # Default value if provider is not found
-
-
-class RequestHandler:
-    """class to handle requests to AI providers"""
-
-    def __init__(self, hass, message, max_tokens, temperature, detail):
-        self.session = async_get_clientsession(hass)
-        self.hass = hass
-        self.message = message
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-        self.detail = detail
-        self.base64_images = []
-        self.filenames = []
-
-    async def make_request(self, call):
-        """Forward request to providers"""
-        entry_id = call.provider
-        provider = get_provider(self.hass, entry_id)
-        _LOGGER.info(f"Provider from call: {provider}")
-        model = call.model if call.model != "None" else default_model(provider)
-
-        if provider == 'OpenAI':
-            api_key = self.hass.data.get(DOMAIN).get(
-                entry_id).get(CONF_OPENAI_API_KEY)
-            self._validate_call(provider=provider,
-                                api_key=api_key,
-                                base64_images=self.base64_images)
-            response_text = await self.openai(model=model, api_key=api_key)
-        elif provider == 'Anthropic':
-            api_key = self.hass.data.get(DOMAIN).get(
-                entry_id).get(CONF_ANTHROPIC_API_KEY)
-            self._validate_call(provider=provider,
-                                api_key=api_key,
-                                base64_images=self.base64_images)
-            response_text = await self.anthropic(model=model, api_key=api_key)
-        elif provider == 'Google':
-            api_key = self.hass.data.get(DOMAIN).get(
-                entry_id).get(CONF_GOOGLE_API_KEY)
-            self._validate_call(provider=provider,
-                                api_key=api_key,
-                                base64_images=self.base64_images)
-            response_text = await self.google(model=model, api_key=api_key)
-        elif provider == 'Groq':
-            api_key = self.hass.data.get(DOMAIN).get(
-                entry_id).get(CONF_GROQ_API_KEY)
-            self._validate_call(provider=provider,
-                                api_key=api_key,
-                                base64_images=self.base64_images)
-            response_text = await self.groq(model=model, api_key=api_key)
-        elif provider == 'LocalAI':
-            ip_address = self.hass.data.get(
-                DOMAIN).get(
-                entry_id).get(CONF_LOCALAI_IP_ADDRESS)
-            port = self.hass.data.get(
-                DOMAIN).get(
-                entry_id).get(CONF_LOCALAI_PORT)
-            https = self.hass.data.get(
-                DOMAIN).get(
-                entry_id).get(CONF_LOCALAI_HTTPS, False)
-            self._validate_call(provider=provider,
-                                api_key=None,
-                                base64_images=self.base64_images,
-                                ip_address=ip_address,
-                                port=port)
-            response_text = await self.localai(model=model,
-                                               ip_address=ip_address,
-                                               port=port,
-                                               https=https)
-        elif provider == 'Ollama':
-            ip_address = self.hass.data.get(
-                DOMAIN).get(
-                entry_id).get(CONF_OLLAMA_IP_ADDRESS)
-            port = self.hass.data.get(DOMAIN).get(
-                entry_id).get(CONF_OLLAMA_PORT)
-            https = self.hass.data.get(DOMAIN).get(
-                entry_id).get(
-                CONF_OLLAMA_HTTPS, False)
-            self._validate_call(provider=provider,
-                                api_key=None,
-                                base64_images=self.base64_images,
-                                ip_address=ip_address,
-                                port=port)
-            response_text = await self.ollama(model=model,
-                                              ip_address=ip_address,
-                                              port=port,
-                                              https=https)
-        elif provider == 'Custom OpenAI':
-            api_key = self.hass.data.get(DOMAIN).get(
-                entry_id).get(
-                CONF_CUSTOM_OPENAI_API_KEY, "")
-            endpoint = self.hass.data.get(DOMAIN).get(entry_id).get(
-                CONF_CUSTOM_OPENAI_ENDPOINT) + "/v1/chat/completions"
-            self._validate_call(provider=provider,
-                                api_key=api_key,
-                                base64_images=self.base64_images)
-            response_text = await self.openai(model=model, api_key=api_key, endpoint=endpoint)
-        else:
-            raise ServiceValidationError("invalid_provider")
-        return {"response_text": response_text}
-
-    def add_frame(self, base64_image, filename):
-        self.base64_images.append(base64_image)
-        self.filenames.append(filename)
-
-    # Request Handlers
-    async def openai(self, model, api_key, endpoint=ENDPOINT_OPENAI):
-        # Set headers and payload
-        headers = {'Content-type': 'application/json',
-                   'Authorization': 'Bearer ' + api_key}
-        data = {"model": model,
-                "messages": [{"role": "user", "content": [
-                ]}],
-                "max_tokens": self.max_tokens,
-                "temperature": self.temperature
-                }
-
-        # Add the images to the request
-        for image, filename in zip(self.base64_images, self.filenames):
-            tag = ("Image " + str(self.base64_images.index(image) + 1)
-                   ) if filename == "" else filename
-            data["messages"][0]["content"].append(
-                {"type": "text", "text": tag + ":"})
-            data["messages"][0]["content"].append(
-                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}", "detail": self.detail}})
-
-        # append the message to the end of the request
-        data["messages"][0]["content"].append(
-            {"type": "text", "text": self.message}
-        )
-
-        response = await self._post(
-            url=endpoint, headers=headers, data=data)
-
-        response_text = response.get(
-            "choices")[0].get("message").get("content")
-        return response_text
-
-    async def anthropic(self, model, api_key):
-        # Set headers and payload
-        headers = {'content-type': 'application/json',
-                   'x-api-key': api_key,
-                   'anthropic-version': VERSION_ANTHROPIC}
-        data = {"model": model,
-                "messages": [
-                    {"role": "user", "content": []}
-                ],
-                "max_tokens": self.max_tokens,
-                "temperature": self.temperature
-                }
-
-        # Add the images to the request
-        for image, filename in zip(self.base64_images, self.filenames):
-            tag = ("Image " + str(self.base64_images.index(image) + 1)
-                   ) if filename == "" or not filename else filename
-            data["messages"][0]["content"].append(
-                {
-                    "type": "text",
-                    "text": tag + ":"
-                })
-            data["messages"][0]["content"].append(
-                {"type": "image", "source":
-                    {"type": "base64",
-                     "media_type": "image/jpeg",
-                     "data": f"{image}"
-                     }
-                 }
-            )
-
-        # append the message to the end of the request
-        data["messages"][0]["content"].append(
-            {"type": "text", "text": self.message}
-        )
-
-        response = await self._post(
-            url=ENDPOINT_ANTHROPIC, headers=headers, data=data)
-
-        response_text = response.get("content")[0].get("text")
response.get("content")[0].get("text") - return response_text - - async def google(self, model, api_key): - # Set headers and payload - headers = {'content-type': 'application/json'} - data = {"contents": [ - ], - "generationConfig": { - "maxOutputTokens": self.max_tokens, - "temperature": self.temperature - } - } - - # Add the images to the request - for image, filename in zip(self.base64_images, self.filenames): - tag = ("Image " + str(self.base64_images.index(image) + 1) - ) if filename == "" or not filename else filename - data["contents"].append( - { - "role": "user", - "parts": [ - { - "text": tag + ":" - }, - { - "inline_data": { - "mime_type": "image/jpeg", - "data": image - } - } - ] - } - ) - - # append the message to the end of the request - data["contents"].append( - {"role": "user", - "parts": [{"text": self.message} - ] - } - ) - - response = await self._post( - url=ENDPOINT_GOOGLE.format(model=model, api_key=api_key), headers=headers, data=data) - - response_text = response.get("candidates")[0].get( - "content").get("parts")[0].get("text") - return response_text - - async def groq(self, model, api_key, endpoint=ENDPOINT_GROQ): - first_image = self.base64_images[0] - # Set headers and payload - headers = {'Content-type': 'application/json', - 'Authorization': 'Bearer ' + api_key} - data = { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": self.message}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{first_image}"} - } - ] - } - ], - "model": model - } - - response = await self._post( - url=endpoint, headers=headers, data=data) - - print(response) - - response_text = response.get( - "choices")[0].get("message").get("content") - return response_text - - async def localai(self, model, ip_address, port, https): - data = {"model": model, - "messages": [{"role": "user", "content": [ - ]}], - "max_tokens": self.max_tokens, - "temperature": self.temperature - } - for image, filename in zip(self.base64_images, self.filenames): - tag = ("Image " + str(self.base64_images.index(image) + 1) - ) if filename == "" or not filename else filename - data["messages"][0]["content"].append( - {"type": "text", "text": tag + ":"}) - data["messages"][0]["content"].append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}}) - - # append the message to the end of the request - data["messages"][0]["content"].append( - {"type": "text", "text": self.message} - ) - - protocol = "https" if https else "http" - response = await self._post( - url=ENDPOINT_LOCALAI.format(ip_address=ip_address, port=port, protocol=protocol), headers={}, data=data) - - response_text = response.get( - "choices")[0].get("message").get("content") - return response_text - - async def ollama(self, model, ip_address, port, https): - data = { - "model": model, - "messages": [], - "stream": False, - "options": { - "num_predict": self.max_tokens, - "temperature": self.temperature - } - } - - for image, filename in zip(self.base64_images, self.filenames): - tag = ("Image " + str(self.base64_images.index(image) + 1) - ) if filename == "" or not filename else filename - image_message = { - "role": "user", - "content": tag + ":", - "images": [image] - } - data["messages"].append(image_message) - # append to the end of the request - prompt_message = { - "role": "user", - "content": self.message - } - data["messages"].append(prompt_message) - - protocol = "https" if https else "http" - response = await 
-        response = await self._post(
-            url=ENDPOINT_OLLAMA.format(ip_address=ip_address, port=port, protocol=protocol), headers={}, data=data)
-        response_text = response.get("message").get("content")
-        return response_text
-
-    # Helpers
-    async def _post(self, url, headers, data):
-        """Post data to url and return response data"""
-        _LOGGER.info(f"Request data: {sanitize_data(data)}")
-
-        try:
-            response = await self.session.post(url, headers=headers, json=data)
-        except Exception as e:
-            raise ServiceValidationError(f"Request failed: {e}")
-
-        if response.status != 200:
-            provider = inspect.stack()[1].function
-            parsed_response = await self._resolve_error(response, provider)
-            raise ServiceValidationError(parsed_response)
-        else:
-            response_data = await response.json()
-            _LOGGER.info(f"Response data: {response_data}")
-            return response_data
-
-    async def _fetch(self, url, max_retries=2, retry_delay=1):
-        """Fetch image from url and return image data"""
-        retries = 0
-        while retries < max_retries:
-            _LOGGER.info(
-                f"Fetching {url} (attempt {retries + 1}/{max_retries})")
-            try:
-                response = await self.session.get(url)
-                if response.status != 200:
-                    _LOGGER.warning(
-                        f"Couldn't fetch frame (status code: {response.status})")
-                    retries += 1
-                    await asyncio.sleep(retry_delay)
-                    continue
-                data = await response.read()
-                return data
-            except Exception as e:
-                _LOGGER.error(f"Fetch failed: {e}")
-                retries += 1
-                await asyncio.sleep(retry_delay)
-        _LOGGER.warning(f"Failed to fetch {url} after {max_retries} retries")
-        return None
-
-    def _validate_call(self, provider, api_key, base64_images, ip_address=None, port=None):
-        """Validate the service call data"""
-        # Checks for OpenAI
-        if provider == 'OpenAI':
-            if not api_key:
-                raise ServiceValidationError(ERROR_OPENAI_NOT_CONFIGURED)
-        # Checks for Anthropic
-        elif provider == 'Anthropic':
-            if not api_key:
-                raise ServiceValidationError(ERROR_ANTHROPIC_NOT_CONFIGURED)
-        elif provider == 'Google':
-            if not api_key:
-                raise ServiceValidationError(ERROR_GOOGLE_NOT_CONFIGURED)
-        # Checks for Groq
-        elif provider == 'Groq':
-            if not api_key:
-                raise ServiceValidationError(ERROR_GROQ_NOT_CONFIGURED)
-            if len(base64_images) > 1:
-                raise ServiceValidationError(ERROR_GROQ_MULTIPLE_IMAGES)
-        # Checks for LocalAI
-        elif provider == 'LocalAI':
-            if not ip_address or not port:
-                raise ServiceValidationError(ERROR_LOCALAI_NOT_CONFIGURED)
-        # Checks for Ollama
-        elif provider == 'Ollama':
-            if not ip_address or not port:
-                raise ServiceValidationError(ERROR_OLLAMA_NOT_CONFIGURED)
-        elif provider == 'Custom OpenAI':
-            pass
-        else:
-            raise ServiceValidationError(
-                "Invalid provider selected. The event calendar cannot be used for analysis.")
-        # Check media input
-        if base64_images == []:
-            raise ServiceValidationError(ERROR_NO_IMAGE_INPUT)
-
-    async def _resolve_error(self, response, provider) -> str:
-        """Translate response status to error message"""
-        import json
-        full_response_text = await response.text()
-        _LOGGER.info(f"[INFO] Full Response: {full_response_text}")
-
-        try:
-            response_json = json.loads(full_response_text)
-            if provider == 'anthropic':
-                error_info = response_json.get('error', {})
-                error_message = f"{error_info.get('type', 'Unknown error')}: {error_info.get('message', 'Unknown error')}"
-            elif provider == 'ollama':
-                error_message = response_json.get('error', 'Unknown error')
-            else:
-                error_info = response_json.get('error', {})
-                error_message = error_info.get('message', 'Unknown error')
-        except json.JSONDecodeError:
-            error_message = 'Unknown error'
-
-        return error_message
diff --git a/custom_components/llmvision/services.yaml b/custom_components/llmvision/services.yaml
index cbd5339..f430a87 100644
--- a/custom_components/llmvision/services.yaml
+++ b/custom_components/llmvision/services.yaml
@@ -68,17 +68,6 @@ image_analyzer:
         number:
           min: 512
           max: 1920
-    detail:
-      name: Detail
-      required: false
-      description: "Detail parameter. Leave empty for 'auto'"
-      default: 'low'
-      example: 'low'
-      selector:
-        select:
-          options:
-            - 'high'
-            - 'low'
     max_tokens:
       name: Maximum Tokens
       description: 'Maximum number of tokens to generate'
@@ -100,6 +89,13 @@ image_analyzer:
           min: 0.1
           max: 1.0
           step: 0.1
+    generate_title:
+      name: Generate Title
+      required: false
+      description: Generate a title. (Used for notifications)
+      default: false
+      selector:
+        boolean:
     expose_images:
       name: Expose Images
       description: (Experimental) Expose analyzed frames after processing. This will save analyzed frames in /www/llmvision so they can be used for notifications. (Only works for entity input, include camera name should be enabled). Existing files will be overwritten.
@@ -213,17 +209,6 @@ video_analyzer:
         number:
          min: 512
          max: 1920
-    detail:
-      name: Detail
-      required: false
-      description: "Detail parameter, leave empty for 'auto'"
-      default: 'low'
-      example: 'low'
-      selector:
-        select:
-          options:
-            - 'high'
-            - 'low'
     max_tokens:
       name: Maximum Tokens
       description: 'Maximum number of tokens to generate'
@@ -346,17 +331,6 @@ stream_analyzer:
         number:
          min: 512
          max: 1920
-    detail:
-      name: Detail
-      required: false
-      description: "Detail parameter, leave empty for 'auto'"
-      default: 'low'
-      example: 'low'
-      selector:
-        select:
-          options:
-            - 'high'
-            - 'low'
     max_tokens:
       name: Maximum Tokens
       description: 'Maximum number of tokens to generate'
@@ -459,17 +433,6 @@ data_analyzer:
         number:
          min: 512
          max: 1920
-    detail:
-      name: Detail
-      required: false
-      description: "Detail parameter. Leave empty for 'auto'"
-      default: 'high'
-      example: 'high'
-      selector:
-        select:
-          options:
-            - 'high'
-            - 'low'
     max_tokens:
       name: Maximum Tokens
       description: 'Maximum number of tokens to generate. A low value is recommended since this will likely result in a number.'
diff --git a/custom_components/llmvision/strings.json b/custom_components/llmvision/strings.json
index 2e038d6..e3b6304 100644
--- a/custom_components/llmvision/strings.json
+++ b/custom_components/llmvision/strings.json
@@ -30,6 +30,16 @@
           "openai_api_key": "Your API key"
         }
       },
+      "azure": {
+        "title": "Configure Azure",
+        "description": "Provide a valid Azure API key, base URL, deployment and API version.\nThe Base URL must be in the format `https://domain.openai.azure.com/` including the trailing slash.",
+        "data": {
+          "azure_api_key": "Your API key",
+          "azure_base_url": "Base URL",
+          "azure_deployment": "Deployment",
+          "azure_version": "API Version"
+        }
+      },
       "anthropic": {
         "title": "Configure Anthropic Claude",
         "description": "Provide a valid Anthropic API key.",
diff --git a/custom_components/llmvision/translations/de.json b/custom_components/llmvision/translations/de.json
index 49956eb..fa8156a 100644
--- a/custom_components/llmvision/translations/de.json
+++ b/custom_components/llmvision/translations/de.json
@@ -28,6 +28,16 @@
           "api_key": "Dein API-key"
         }
       },
+      "azure": {
+        "title": "Azure konfigurieren",
+        "description": "Gib einen gültigen Azure API-key, die Base URL, den Namen des Deployments und die API-Version an.\nDie Base URL muss dieses Format haben: `https://domain.openai.azure.com/` (einschliesslich des abschliessenden '/')",
+        "data": {
+          "azure_api_key": "Dein API key",
+          "azure_base_url": "Base URL",
+          "azure_deployment": "Deployment",
+          "azure_version": "API Version"
+        }
+      },
       "anthropic": {
         "title": "Anthropic Claude konfigurieren",
         "description": "Gib einen gültigen Anthropic API-key ein.",
diff --git a/custom_components/llmvision/translations/en.json b/custom_components/llmvision/translations/en.json
index 2e038d6..e3b6304 100644
--- a/custom_components/llmvision/translations/en.json
+++ b/custom_components/llmvision/translations/en.json
@@ -30,6 +30,16 @@
           "openai_api_key": "Your API key"
         }
       },
+      "azure": {
+        "title": "Configure Azure",
+        "description": "Provide a valid Azure API key, base URL, deployment and API version.\nThe Base URL must be in the format `https://domain.openai.azure.com/` including the trailing slash.",
+        "data": {
+          "azure_api_key": "Your API key",
+          "azure_base_url": "Base URL",
+          "azure_deployment": "Deployment",
+          "azure_version": "API Version"
+        }
+      },
       "anthropic": {
         "title": "Configure Anthropic Claude",
         "description": "Provide a valid Anthropic API key.",