From 6e50fce11b232fa56968f3979e80d8d96cb2d10d Mon Sep 17 00:00:00 2001 From: yihong1120 Date: Sat, 2 Nov 2024 22:33:28 +0800 Subject: [PATCH] Refactor and add Asynchronous processing --- .github/workflows/ci.yml | 2 +- .github/workflows/python-app.yml | 4 +- README-zh-tw.md | 1 - README.md | 2 - main.py | 736 +++++++++++++----------- src/drawing_manager.py | 14 +- src/live_stream_detection.py | 155 ++--- src/stream_capture.py | 59 +- src/utils.py | 97 +++- tests/src/live_stream_detection_test.py | 102 ++-- tests/src/stream_capture_test.py | 125 ++-- tests/src/utils_test.py | 39 +- 12 files changed, 738 insertions(+), 598 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2f2ccf1..0a31e19 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Set up Python 3.13.0 + - name: Set up Python 3.12.7 uses: actions/setup-python@v4 with: python-version: "3.13.0" diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 292de34..108e194 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -19,10 +19,10 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Set up Python 3.13.0 + - name: Set up Python 3.12.7 uses: actions/setup-python@v3 with: - python-version: "3.13.0" + python-version: "3.12.7" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/README-zh-tw.md b/README-zh-tw.md index 0807d16..6cc5b5d 100644 --- a/README-zh-tw.md +++ b/README-zh-tw.md @@ -308,7 +308,6 @@ - 新增對 WhatsApp 通知的支援, - 在 server_api 和 streaming_web 中從 Flask 切換到 Fastapi -- 重構 main.py ## 授權 此項目根據 [AGPL-3.0](LICENSE.md) 授權。 diff --git a/README.md b/README.md index ba5ef5e..22ec686 100644 --- a/README.md +++ b/README.md @@ -316,8 +316,6 @@ We welcome contributions to this project. Please follow these steps: - Add support for WhatsApp notifications, - Switch from Flask to Fastapi in server_api and streaming_web -- Refactor main.py - ## License diff --git a/main.py b/main.py index 54b09da..1155bd1 100644 --- a/main.py +++ b/main.py @@ -1,15 +1,16 @@ from __future__ import annotations import argparse +import asyncio import gc import logging import os -import threading import time from datetime import datetime from multiprocessing import Process from typing import TypedDict +import anyio import cv2 import yaml from dotenv import load_dotenv @@ -36,7 +37,10 @@ redis_manager = RedisManager() -class StreamConfig(TypedDict, total=False): +class AppConfig(TypedDict, total=False): + """ + Typed dictionary for the configuration of a video stream. + """ video_url: str model_key: str site: str | None @@ -48,24 +52,25 @@ class StreamConfig(TypedDict, total=False): language: str | None -def run_multiple_streams(config_file: str) -> None: +class MainApp: """ - Manage multiple video streams based on a config file. - - Args: - config_file (str): The path to the YAML configuration file. - - Returns: - None + Main application class for managing multiple video streams. """ - running_processes: dict[str, Process] = {} - current_config_hashes: dict[str, str] = {} - lock = threading.Lock() - logger = logging.getLogger(__name__) - logging.basicConfig(level=logging.INFO) + def __init__(self, config_file: str): + """ + Initialise the MainApp class. - def compute_config_hash(config: dict) -> str: + Args: + config_file (str): The path to the YAML configuration file. 
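+
+        Example (illustrative sketch; the config path is assumed):
+
+            app = MainApp('config/configuration.yaml')
+            anyio.run(app.run_multiple_streams)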
+ """ + self.config_file = config_file + self.running_processes: dict[str, dict] = {} + self.current_config_hashes: dict[str, str] = {} + self.lock = anyio.Lock() + self.logger = LoggerConfig().get_logger() + + def compute_config_hash(self, config: dict) -> str: """ Compute a hash based on relevant configuration parameters. @@ -85,265 +90,316 @@ def compute_config_hash(config: dict) -> str: } return str(relevant_config) # Convert to string for hashing - def reload_configurations(): - with open(config_file, encoding='utf-8') as file: + async def reload_configurations(self): + """ + Reload the configurations from the YAML file. + """ + with open(self.config_file, encoding='utf-8') as file: configurations = yaml.safe_load(file) current_configs = { config['video_url']: config for config in configurations } - with lock: + async with self.lock: + # Track keys that exist in the current config + current_keys = { + ( + f"{config['site']}_" + f"{config.get('stream_name', 'prediction_visual')}" + ) + for config in configurations + } + # Stop processes for removed or updated configurations - for video_url in list(running_processes.keys()): + for video_url in list(self.running_processes.keys()): + config_data = self.running_processes[video_url] config = current_configs.get(video_url) + # Get the key to be deleted + # if config has been removed or modified + site = config_data['config']['site'] + stream_name = config_data['config'].get( + 'stream_name', 'prediction_visual', + ) + key_to_delete = f"{site}_{stream_name}" + # Stop the process if the configuration is removed if not config or Utils.is_expired(config.get('expire_date')): - logger.info(f"Stop workflow: {video_url}") - stop_process(running_processes[video_url]) - del running_processes[video_url] - del current_config_hashes[video_url] + self.logger.info(f"Stop workflow: {video_url}") + self.stop_process(config_data['process']) + del self.running_processes[video_url] + del self.current_config_hashes[video_url] + + # Delete old key in Redis + # if it no longer exists in the config + if key_to_delete not in current_keys: + await redis_manager.delete(key_to_delete) + self.logger.info(f"Deleted Redis key: {key_to_delete}") # Restart the process if the configuration is updated - elif compute_config_hash(config) != current_config_hashes.get( - video_url, + elif self.compute_config_hash(config) != ( + self.current_config_hashes.get( + video_url, + ) ): - logger.info( + self.logger.info( f"Config changed for {video_url}. 
" 'Restarting workflow.', ) - stop_process(running_processes[video_url]) - running_processes[video_url] = start_process(config) - current_config_hashes[video_url] = compute_config_hash( - config, + self.stop_process(config_data['process']) + + # Delete old key in Redis + # if it no longer exists in the config + if key_to_delete not in current_keys: + await redis_manager.delete(key_to_delete) + self.logger.info(f"Deleted Redis key: {key_to_delete}") + + # Start the new process + self.running_processes[video_url] = { + 'process': self.start_process(config), + 'config': config, + } + self.current_config_hashes[video_url] = ( + self.compute_config_hash( + config, + ) ) # Start processes for new configurations for video_url, config in current_configs.items(): if Utils.is_expired(config.get('expire_date')): - logger.info(f"Skip expired configuration: {video_url}") + self.logger.info( + f"Skip expired configuration: {video_url}", + ) continue - if video_url not in running_processes: - logger.info(f"Launch new workflow: {video_url}") - running_processes[video_url] = start_process(config) - current_config_hashes[video_url] = compute_config_hash( - config, + if video_url not in self.running_processes: + self.logger.info(f"Launch new workflow: {video_url}") + self.running_processes[video_url] = { + 'process': self.start_process(config), + 'config': config, + } + self.current_config_hashes[video_url] = ( + self.compute_config_hash( + config, + ) ) - # Initial load of configurations - reload_configurations() - - # Set up watchdog observer - event_handler = FileEventHandler(config_file, reload_configurations) - observer = Observer() - observer.schedule( - event_handler, path=os.path.dirname( - config_file, - ), recursive=False, - ) - observer.start() - - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - observer.stop() - observer.join() - - -def process_single_stream( - logger: logging.Logger, - video_url: str, - model_key: str = 'yolo11n', - site: str | None = None, - stream_name: str = 'prediction_visual', - notifications: dict[str, str] | None = None, - detect_with_server: bool = False, -) -> None: - """ - Function to detect hazards, notify, log, save images (optional). - - Args: - logger (logging.Logger): A logger instance for logging messages. - video_url (str): The URL of the live stream to monitor. - site (Optional[str]): The site for stream processing. - stream_name (str, optional): Image file name for notifications. - Defaults to 'demo_data/{site}/prediction_visual.png'. - notifications (Optional[dict]): Line tokens with their languages. - detect_with_server (bool): Whether to run detection using a server api. - """ - # Initialise the stream capture object - streaming_capture = StreamCapture(stream_url=video_url) - - # Get the API URL from environment variables - api_url = os.getenv('API_URL', 'http://localhost:5000') - - # Initialise the live stream detector - live_stream_detector = LiveStreamDetector( - api_url=api_url, - model_key=model_key, - output_folder=site, - detect_with_server=detect_with_server, - ) - - # Initialise the drawing manager - drawing_manager = DrawingManager() + async def run_multiple_streams(self) -> None: + """ + Manage multiple video streams based on a config file. - # Initialise the LINE notifier - line_notifier = LineNotifier() + Args: + config_file (str): The path to the YAML configuration file. 
- # Initialise the DangerDetector - danger_detector = DangerDetector() + Returns: + None + """ + # Initial load of configurations + await self.reload_configurations() - # Init last_notification_time to 300s ago, no microseconds - last_notification_time = int(time.time()) - 300 + # Set up watchdog observer + event_handler = FileEventHandler( + self.config_file, self.reload_configurations, + ) + observer = Observer() + observer.schedule( + event_handler, path=os.path.dirname( + self.config_file, + ), recursive=False, + ) + observer.start() + + try: + while True: + await anyio.sleep(1) + except KeyboardInterrupt: + observer.stop() + observer.join() + + async def process_single_stream( + self, + logger: logging.Logger, + video_url: str, + model_key: str = 'yolo11n', + site: str | None = None, + stream_name: str = 'prediction_visual', + notifications: dict[str, str] | None = None, + detect_with_server: bool = False, + ) -> None: + """ + Function to detect hazards, notify, log, save images (optional). - # Use the generator function to process detections - for frame, timestamp in streaming_capture.execute_capture(): - start_time = time.time() - # Convert UNIX timestamp to datetime object and format it as string - detection_time = datetime.fromtimestamp(timestamp) - current_hour = detection_time.hour + Args: + logger (logging.Logger): A logger instance for logging messages. + video_url (str): The URL of the live stream to monitor. + site (Optional[str]): The site for stream processing. + stream_name (str, optional): Image file name for notifications. + Defaults to 'demo_data/{site}/prediction_visual.png'. + notifications (Optional[dict]): Line tokens with their languages. + detect_with_server (bool): If run detection with server api or not. + """ + # Initialise the stream capture object + streaming_capture = StreamCapture(stream_url=video_url) - # Detect hazards in the frame - datas, _ = live_stream_detector.generate_detections(frame) + # Get the API URL from environment variables + api_url = os.getenv('API_URL', 'http://localhost:5000') - # Check for warnings and send notifications if necessary - warnings, controlled_zone_polygon = danger_detector.detect_danger( - datas, + # Initialise the live stream detector + live_stream_detector = LiveStreamDetector( + api_url=api_url, + model_key=model_key, + output_folder=site, + detect_with_server=detect_with_server, ) - # Check if there is a warning for people in the controlled zone - controlled_zone_warning_str = next( - # Find the first warning containing 'controlled area' - (warning for warning in warnings if 'controlled area' in warning), - None, - ) + # Initialise the drawing manager + drawing_manager = DrawingManager() - # Convert the warning to a list for translation - controlled_zone_warning: list[str] = [ - controlled_zone_warning_str, - ] if controlled_zone_warning_str else [] + # Initialise the LINE notifier + line_notifier = LineNotifier() - # Track whether we sent any notification - notification_sent = False - last_line_token = None - last_language = None - frame_with_detections = None + # Initialise the DangerDetector + danger_detector = DangerDetector() - if not notifications: - logger.info('No notifications provided.') + # Init last_notification_time to 300s ago, no microseconds + last_notification_time = int(time.time()) - 300 - # Draw the detections on the frame - frame_with_detections = ( - drawing_manager.draw_detections_on_frame( - frame, controlled_zone_polygon, datas, - ) - ) + # Use the generator function to process detections + 
async for frame, timestamp in streaming_capture.execute_capture(): + start_time = time.time() + # Convert UNIX timestamp to datetime object and format it as string + detection_time = datetime.fromtimestamp(timestamp) + current_hour = detection_time.hour - # Convert the frame to a byte array - _, buffer = cv2.imencode('.png', frame_with_detections) - frame_bytes = buffer.tobytes() - continue - - # Check if notifications are provided - for line_token, language in notifications.items(): - # Check if notification should be skipped - # (sent within last 300 seconds) - if (timestamp - last_notification_time) < 300: - # Store the last token and language, - # but don't send notifications - last_line_token = line_token - last_language = language - # Skip the current notification but remember the last one - continue - - # Translate the warnings - translated_warnings = Translator.translate_warning( - warnings, language, - ) + # Detect hazards in the frame + datas, _ = await live_stream_detector.generate_detections(frame) - # Draw the detections on the frame - frame_with_detections = ( - drawing_manager.draw_detections_on_frame( - frame, controlled_zone_polygon, datas, - language=language, - ) + # Check for warnings and send notifications if necessary + warnings, controlled_zone_polygon = danger_detector.detect_danger( + datas, ) - # Save the frame with detections - # save_file_name = f'{site}_{stream_name}_{detection_time}' - # drawing_manager.save_frame( - # frame_with_detections, - # save_file_name - # ) - - # Convert the frame to a byte array - _, buffer = cv2.imencode('.png', frame_with_detections) - frame_bytes = buffer.tobytes() + # Check if there is a warning for people in the controlled zone + controlled_zone_warning_str = next( + # Find the first warning containing 'controlled area' + ( + warning + for warning in warnings + if 'controlled area' in warning + ), + None, + ) - # Log the detection results - logger.info(f"{site} - {stream_name}") - logger.info(f"Detection time: {detection_time}") + # Convert the warning to a list for translation + controlled_zone_warning: list[str] = [ + controlled_zone_warning_str, + ] if controlled_zone_warning_str else [] - # If it is outside working hours and there is - # a warning for people in the controlled zone - if controlled_zone_warning and not (7 <= current_hour < 18): - translated_controlled_zone_warning: list[str] = ( - Translator.translate_warning( - controlled_zone_warning, language, - ) - ) - message = ( - f"{stream_name}\n[{detection_time}]\n" - f"{translated_controlled_zone_warning}" - ) + # Track whether we sent any notification + last_line_token = None + last_language = None + frame_with_detections = None - elif translated_warnings and (7 <= current_hour < 18): - # During working hours, combine all warnings - message = ( - f"{stream_name}\n[{detection_time}]\n" - + '\n'.join(translated_warnings) - ) + if not notifications: + logger.info('No notifications provided.') else: - message = None - - # If a notification needs to be sent - if not message: - logger.info('No warnings or outside notification time.') - continue - - notification_status = line_notifier.send_notification( - message, - image=frame_bytes - if frame_bytes is not None - else None, - line_token=line_token, - ) + # Check if notifications are provided + for line_token, language in notifications.items(): + # Check if notification should be skipped + # (sent within last 300 seconds) + if (timestamp - last_notification_time) < 300: + # Store the last token and language, + # but don't 
send notifications + last_line_token = line_token + last_language = language + # Skip the current notification + # but remember the last one + continue + + # Translate the warnings + translated_warnings = Translator.translate_warning( + warnings, language, + ) - # If you want to connect to the broadcast system, do it here: - # broadcast_status = ( - # broadcast_notifier.broadcast_message(message) - # ) - # logger.info(f"Broadcast status: {broadcast_status}") + # Draw the detections on the frame + frame_with_detections = ( + drawing_manager.draw_detections_on_frame( + frame, controlled_zone_polygon, datas, + language=language, + ) + ) - if notification_status == 200: - logger.info( - f"Notification sent successfully: {message}", - ) - notification_sent = True # Mark that a notification was sent - else: - logger.error(f"Failed to send notification: {message}") + # Convert the frame to a byte array + _, buffer = cv2.imencode('.png', frame_with_detections) + frame_bytes = buffer.tobytes() + + # If it is outside working hours and there is + # a warning for people in the controlled zone + if ( + controlled_zone_warning + and not (7 <= current_hour < 18) + ): + translated_controlled_zone_warning: list[str] = ( + Translator.translate_warning( + controlled_zone_warning, language, + ) + ) + message = ( + f"{stream_name}\n[{detection_time}]\n" + f"{translated_controlled_zone_warning}" + ) + + elif translated_warnings and (7 <= current_hour < 18): + # During working hours, combine all warnings + message = ( + f"{stream_name}\n[{detection_time}]\n" + + '\n'.join(translated_warnings) + ) + + else: + message = None + + # If a notification needs to be sent + if not message: + logger.info( + 'No warnings or outside notification time.', + ) + continue + + notification_status = line_notifier.send_notification( + message, + image=frame_bytes + if frame_bytes is not None + else None, + line_token=line_token, + ) - # Log the notification token and language - logger.info(f"Notification sent to {line_token} in {language}.") + # To connect to the broadcast system, do it here: + # broadcast_status = ( + # broadcast_notifier.broadcast_message(message) + # ) + # logger.info(f"Broadcast status: {broadcast_status}") + + if notification_status == 200: + logger.info( + f"Notification sent successfully: {message}", + ) + last_notification_time = int(timestamp) + else: + logger.error(f"Failed to send notification: {message}") + + # Log the notification token and language + logger.info( + f"Notification sent to {line_token} in {language}.", + ) - # If no notification was sent and the time condition was met, - # only draw the image - if last_line_token and last_language: + # If no notification was sent and the time condition was met, + # only draw the image + if last_line_token and last_language: + language = last_language # Draw the detections on the frame for the last token/language # (if not already drawn) @@ -352,7 +408,7 @@ def process_single_stream( drawing_manager.draw_detections_on_frame( frame, controlled_zone_polygon, datas, - language=last_language, + language=last_language or 'en', ) ) @@ -360,134 +416,129 @@ def process_single_stream( _, buffer = cv2.imencode('.png', frame_with_detections) frame_bytes = buffer.tobytes() - # Optionally save the frame with detections - # drawing_manager.save_frame(frame_with_detections, save_file_name) - - # Update last_notification_time only - # if at least one notification was sent - if notification_sent: - last_notification_time = int(timestamp) - - # Store the frame in Redis 
if not running on Windows - if not is_windows: - try: - # Use a unique key for each thread or process - key = f"{site}_{stream_name}" + # Save the frame with detections + # save_file_name = f'{site}_{stream_name}_{detection_time}' + # drawing_manager.save_frame( + # frame_with_detections, + # save_file_name + # ) - # Store the frame in Redis - redis_manager.set(key, frame_bytes) - except Exception as e: - logger.error(f"Failed to store frame in Redis: {e}") + # Store the frame in Redis if not running on Windows + if not is_windows: + try: + # Use a unique key for each thread or process + key = f"{site}_{stream_name}" - end_time = time.time() + # Store the frame in Redis Stream + # with a maximum length of 10 + await redis_manager.add_to_stream( + key, {'frame': frame_bytes}, maxlen=10, + ) + except Exception as e: + logger.error(f"Failed to store frame in Redis: {e}") - # Calculate the processing time - processing_time = end_time - start_time + # Update the capture interval based on processing time + end_time = time.time() + processing_time = end_time - start_time + new_interval = int(processing_time) + 5 + streaming_capture.update_capture_interval(new_interval) - # Update the capture interval based on the processing time - new_interval = int(processing_time) + 5 - streaming_capture.update_capture_interval(new_interval) + # Log the detection results + logger.info(f"{site} - {stream_name}") + logger.info(f"Detection time: {detection_time}") + logger.info(f"Processing time: {processing_time:.2f} seconds") - # Log the processing time - logger.info(f"Processing time: {processing_time:.2f} seconds") + # Clear variables to free up memory + del datas, frame, timestamp, detection_time + del frame_with_detections, buffer, frame_bytes + gc.collect() - # Clear variables to free up memory - del datas, frame, timestamp, detection_time - del frame_with_detections, buffer, frame_bytes + # Release resources after processing + await streaming_capture.release_resources() gc.collect() - # Release resources after processing - streaming_capture.release_resources() - gc.collect() - - -def process_streams(config: StreamConfig) -> None: - """ - Process a video stream based on the given configuration. - - Args: - config (StreamConfig): The configuration for the stream processing. - - Returns: - None - """ - # Load the logger configuration - logger_config = LoggerConfig() + async def process_streams(self, config: AppConfig) -> None: + """ + Process a video stream based on the given configuration. - # Initialise the logger - logger = logger_config.get_logger() + Args: + config (StreamConfig): The configuration for the stream processing. 
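+
+                Example entry (illustrative values; keys follow the
+                AppConfig TypedDict):
+
+                    {
+                        'video_url': 'https://example.com/stream',
+                        'model_key': 'yolo11n',
+                        'site': 'site_a',
+                        'stream_name': 'prediction_visual',
+                        'detect_with_server': False,
+                    }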
- try: - # Check if 'notifications' field exists (new format) - if 'notifications' in config and config['notifications'] is not None: - notifications = config['notifications'] - # Otherwise, handle the old format - elif 'line_token' in config and 'language' in config: - line_token = config.get('line_token') - language = config.get('language') - if line_token is not None and language is not None: - notifications = {line_token: language} + Returns: + None + """ + try: + # Check if 'notifications' field exists (new format) + if ( + 'notifications' in config and + config['notifications'] is not None + ): + notifications = config['notifications'] + # Otherwise, handle the old format + elif 'line_token' in config and 'language' in config: + line_token = config.get('line_token') + language = config.get('language') + if line_token is not None and language is not None: + notifications = {line_token: language} + else: + notifications = None else: notifications = None - else: - notifications = None - - # Continue processing the remaining configuration - video_url = config.get('video_url', '') - model_key = config.get('model_key', 'yolo11n') - site = config.get('site') - stream_name = config.get('stream_name', 'prediction_visual') - detect_with_server = config.get('detect_with_server', False) - - # Run hazard detection on a single video stream - process_single_stream( - logger, - video_url=video_url, - model_key=model_key, - site=site, - stream_name=stream_name, - notifications=notifications, - detect_with_server=detect_with_server, - ) - finally: - if not is_windows: + + # Continue processing the remaining configuration + video_url = config.get('video_url', '') + model_key = config.get('model_key', 'yolo11n') site = config.get('site') stream_name = config.get('stream_name', 'prediction_visual') - key = f"{site}_{stream_name}" - redis_manager.delete(key) - logger.info(f"Deleted Redis key: {key}") - - -def start_process(config: StreamConfig) -> Process: - """ - Start a new process for processing a video stream. + detect_with_server = config.get('detect_with_server', False) + + # Run hazard detection on a single video stream + await self.process_single_stream( + self.logger, + video_url=video_url, + model_key=model_key, + site=site, + stream_name=stream_name, + notifications=notifications, + detect_with_server=detect_with_server, + ) + finally: + if not is_windows: + site = config.get('site') + stream_name = config.get('stream_name', 'prediction_visual') + key = f"{site}_{stream_name}" + await redis_manager.delete(key) + self.logger.info(f"Deleted Redis key: {key}") - Args: - config (StreamConfig): The configuration for the stream processing. + def start_process(self, config: AppConfig) -> Process: + """ + Start a new process for processing a video stream. - Returns: - Process: The newly started process. - """ - p = Process(target=process_streams, args=(config,)) - p.start() - return p + Args: + config (StreamConfig): The configuration for the stream processing. + Returns: + Process: The newly started process. + """ + p = Process(target=lambda: asyncio.run(self.process_streams(config))) + p.start() + return p -def stop_process(process: Process) -> None: - """ - Stop a running process. + def stop_process(self, process: Process) -> None: + """ + Stop a running process. - Args: - process (Process): The process to be terminated. + Args: + process (Process): The process to be terminated. 
- Returns: - None - """ - process.terminate() - process.join() + Returns: + None + """ + process.terminate() + process.join() -def process_single_image( +async def process_single_image( image_path: str, model_key: str = 'yolo11n', output_folder: str = 'output_images', @@ -496,6 +547,16 @@ def process_single_image( ) -> None: """ Process a single image for hazard detection and save the result. + + Args: + image_path (str): The path to the image file. + model_key (str): The model key to use for detection. + output_folder (str): The folder to save the output image. + stream_name (str): The name of the output image file. + language (str): The language for labels on the output image. + + Returns: + None """ try: # Check if the image path exists @@ -521,7 +582,7 @@ def process_single_image( drawing_manager = DrawingManager() # Detect hazards in the image - detections, _ = live_stream_detector.generate_detections(image) + detections, _ = await live_stream_detector.generate_detections(image) # For this example, no polygons are needed, so pass an empty list frame_with_detections = drawing_manager.draw_detections_on_frame( @@ -547,7 +608,7 @@ def process_single_image( print(f"Error processing the image: {str(e)}") -if __name__ == '__main__': +async def main(): parser = argparse.ArgumentParser( description=( 'Run hazard detection on multiple video streams or a single image.' @@ -586,7 +647,7 @@ def process_single_image( # If an image path is provided, process the single image if args.image: - process_single_image( + await process_single_image( image_path=args.image, model_key=args.model_key, output_folder=args.output_folder, @@ -594,4 +655,9 @@ def process_single_image( ) else: # Otherwise, run hazard detection on multiple video streams - run_multiple_streams(args.config) + app = MainApp(args.config) + await app.run_multiple_streams() + + +if __name__ == '__main__': + anyio.run(main) diff --git a/src/drawing_manager.py b/src/drawing_manager.py index 96ac45c..c675fd5 100644 --- a/src/drawing_manager.py +++ b/src/drawing_manager.py @@ -18,20 +18,25 @@ class DrawingManager: """ # Class variable for caching default font - default_font: ImageFont.ImageFont | None = None + default_font: ImageFont.FreeTypeFont | ImageFont.ImageFont | None = None def __init__(self) -> None: """ Initialise the DrawingManager class. """ # Font cache to avoid repeated loading - self.font_cache: dict[str, ImageFont.ImageFont] = {} + self.font_cache: dict[ + str, ImageFont.FreeTypeFont | + ImageFont.ImageFont, + ] = {} # Load default font if not already loaded if DrawingManager.default_font is None: DrawingManager.default_font = ImageFont.load_default() - def get_font(self, language: str) -> ImageFont.ImageFont: + def get_font( + self, language: str, + ) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: """ Load the appropriate font based on the language input, with caching. @@ -39,7 +44,8 @@ def get_font(self, language: str) -> ImageFont.ImageFont: language (str): The language to use for the font. Returns: - ImageFont.ImageFont: The loaded font object. + ImageFont.FreeTypeFont | ImageFont.ImageFont: + The loaded font object. 
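+
+        Example (illustrative):
+            font = DrawingManager().get_font('th')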
""" # Select font path based on language if language == 'th': diff --git a/src/live_stream_detection.py b/src/live_stream_detection.py index 227c28a..1313b7d 100644 --- a/src/live_stream_detection.py +++ b/src/live_stream_detection.py @@ -8,18 +8,17 @@ from pathlib import Path from typing import TypedDict +import aiohttp +import anyio import cv2 import numpy as np -import requests from dotenv import load_dotenv -from requests.adapters import HTTPAdapter from sahi import AutoDetectionModel from sahi.predict import get_sliced_prediction from tenacity import retry from tenacity import retry_if_exception_type from tenacity import stop_after_attempt from tenacity import wait_fixed -from urllib3.util import Retry load_dotenv() @@ -60,82 +59,44 @@ def __init__( model_key (str): The model key for detection. output_folder (Optional[str]): Folder for detected frames. """ - self.api_url = api_url - self.model_key = model_key - self.output_folder = output_folder - self.session = self.requests_retry_session() - self.detect_with_server = detect_with_server - self.model = None - self.access_token = None - self.token_expiry = 0.0 - - def requests_retry_session( - self, - retries: int = 7, - backoff_factor: int = 1, - status_forcelist: tuple[int, ...] = (500, 502, 504, 401, 104), - session: requests.Session | None = None, - allowed_methods: frozenset = frozenset( - ['HEAD', 'GET', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'TRACE'], - ), - ) -> requests.Session: - """ - Configures a requests session with retry logic. - - Args: - retries (int): The number of retry attempts. - backoff_factor (int): The backoff factor for retries. - status_forcelist (Tuple[int]): List of HTTP status codes for retry. - session (Optional[requests.Session]): An optional requests session. - allowed_methods (frozenset): The set of allowed HTTP methods. - - Returns: - requests.Session: The configured requests requests session. - """ - session = session or requests.Session() - retry = Retry( - total=retries, - read=retries, - connect=retries, - backoff_factor=backoff_factor, - status_forcelist=status_forcelist, - allowed_methods=allowed_methods, - raise_on_status=False, + self.api_url: str = ( + api_url if api_url.startswith('http') else f"http://{api_url}" ) - adapter = HTTPAdapter(max_retries=retry) - session.mount('http://', adapter) - session.mount('https://', adapter) - return session + self.model_key: str = model_key + self.output_folder: str | None = output_folder + self.detect_with_server: bool = detect_with_server + self.model: AutoDetectionModel | None = None + self.access_token: str | None = None + self.token_expiry: float = 0 @retry( stop=stop_after_attempt(3), wait=wait_fixed(2), - retry=retry_if_exception_type(requests.RequestException), + retry=retry_if_exception_type(aiohttp.ClientError), ) - def authenticate(self) -> None: + async def authenticate(self) -> None: """ Authenticates with the API and retrieves the access token. 
""" - response = self.session.post( - f"{self.api_url}/token", - json={ - 'username': os.getenv( - 'API_USERNAME', - ), - 'password': os.getenv('API_PASSWORD'), - }, - ) - response.raise_for_status() - token_data = response.json() - if 'msg' in token_data: - raise Exception(token_data['msg']) - elif 'access_token' in token_data: - self.access_token = token_data['access_token'] - else: - raise Exception( - "Token data does not contain 'msg' or 'access_token'", - ) - self.token_expiry = time.time() + 850 + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.api_url}/token", + json={ + 'username': os.getenv('API_USERNAME'), + 'password': os.getenv('API_PASSWORD'), + }, + ) as response: + response.raise_for_status() + token_data = await response.json() + if 'msg' in token_data: + raise Exception(token_data['msg']) + elif 'access_token' in token_data: + self.access_token = token_data['access_token'] + else: + raise Exception( + "Token data does not contain 'msg' or 'access_token'", + ) + self.token_expiry = time.time() + 850 def token_expired(self) -> bool: """ @@ -146,19 +107,19 @@ def token_expired(self) -> bool: """ return time.time() >= self.token_expiry - def ensure_authenticated(self) -> None: + async def ensure_authenticated(self) -> None: """ Ensures that the access token is valid and not expired. """ if self.access_token is None or self.token_expired(): - self.authenticate() + await self.authenticate() @retry( stop=stop_after_attempt(2), wait=wait_fixed(3), - retry=retry_if_exception_type(requests.RequestException), + retry=retry_if_exception_type(aiohttp.ClientError), ) - def generate_detections_cloud( + async def generate_detections_cloud( self, frame: np.ndarray, ) -> list[list[float]]: @@ -171,7 +132,7 @@ def generate_detections_cloud( Returns: List[List[float]]: The detection data. """ - self.ensure_authenticated() + await self.ensure_authenticated() _, frame_encoded = cv2.imencode('.png', frame) frame_encoded_bytes = frame_encoded.tobytes() @@ -179,18 +140,24 @@ def generate_detections_cloud( filename = f"frame_{timestamp}.png" headers = {'Authorization': f"Bearer {self.access_token}"} - files = {'image': (filename, frame_encoded_bytes, 'image/png')} - response = self.session.post( - f"{self.api_url}/detect", - files=files, - params={'model': self.model_key}, - headers=headers, + data = aiohttp.FormData() + data.add_field( + 'image', frame_encoded_bytes, + filename=filename, content_type='image/png', ) - response.raise_for_status() - detections = response.json() - return detections - def generate_detections_local( + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.api_url}/detect", + data=data, + params={'model': self.model_key}, + headers=headers, + ) as response: + response.raise_for_status() + detections = await response.json() + return detections + + async def generate_detections_local( self, frame: np.ndarray, ) -> list[list[float]]: @@ -391,7 +358,7 @@ def remove_completely_contained_labels(self, datas): return datas - def generate_detections( + async def generate_detections( self, frame: np.ndarray, ) -> tuple[list[list[float]], np.ndarray]: """ @@ -405,13 +372,12 @@ def generate_detections( Detections and original frame. 
""" if self.detect_with_server: - datas = self.generate_detections_cloud(frame) + datas = await self.generate_detections_cloud(frame) else: - datas = self.generate_detections_local(frame) - + datas = await self.generate_detections_local(frame) return datas, frame - def run_detection(self, stream_url: str) -> None: + async def run_detection(self, stream_url: str) -> None: """ Runs detection on the live stream. @@ -430,6 +396,9 @@ def run_detection(self, stream_url: str) -> None: continue # Perform detection + datas, frame = await self.generate_detections(frame) + print(datas) # You can replace this with actual processing + cv2.imshow('Frame', frame) if cv2.waitKey(1) & 0xFF == ord('q'): break @@ -438,7 +407,7 @@ def run_detection(self, stream_url: str) -> None: cv2.destroyAllWindows() -def main(): +async def main(): parser = argparse.ArgumentParser( description='Perform live stream detection and tracking using YOLO.', ) @@ -478,8 +447,8 @@ def main(): output_folder=args.output_folder, detect_with_server=args.detect_with_server, ) - detector.run_detection(args.url) + await detector.run_detection(args.url) if __name__ == '__main__': - main() + anyio.run(main) diff --git a/src/stream_capture.py b/src/stream_capture.py index 23c84ad..3183322 100644 --- a/src/stream_capture.py +++ b/src/stream_capture.py @@ -1,10 +1,10 @@ from __future__ import annotations import argparse +import asyncio import datetime import gc -import time -from collections.abc import Generator +from collections.abc import AsyncGenerator from typing import TypedDict import cv2 @@ -34,6 +34,8 @@ def __init__(self, stream_url: str, capture_interval: int = 15): Args: stream_url (str): The URL of the video stream. + capture_interval (int, optional): The interval at which frames + should be captured. Defaults to 15. """ # Video stream URL self.stream_url = stream_url @@ -44,7 +46,7 @@ def __init__(self, stream_url: str, capture_interval: int = 15): # Flag to indicate successful capture self.successfully_captured = False - def initialise_stream(self, stream_url: str) -> None: + async def initialise_stream(self, stream_url: str) -> None: """ Initialises the video stream. @@ -56,10 +58,10 @@ def initialise_stream(self, stream_url: str) -> None: # self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264')) if not self.cap.isOpened(): - time.sleep(5) + await asyncio.sleep(5) self.cap.open(stream_url) - def release_resources(self) -> None: + async def release_resources(self) -> None: """ Releases resources like the capture object. """ @@ -68,14 +70,16 @@ def release_resources(self) -> None: self.cap = None gc.collect() - def execute_capture(self) -> Generator[tuple[np.ndarray, float]]: + async def execute_capture( + self, + ) -> AsyncGenerator[tuple[np.ndarray, float]]: """ Captures frames from the stream and yields them with timestamps. Yields: Tuple[np.ndarray, float]: The captured frame and the timestamp. """ - self.initialise_stream(self.stream_url) + await self.initialise_stream(self.stream_url) last_process_time = datetime.datetime.now() - datetime.timedelta( seconds=self.capture_interval, ) @@ -83,7 +87,7 @@ def execute_capture(self) -> Generator[tuple[np.ndarray, float]]: while True: if self.cap is None: - self.initialise_stream(self.stream_url) + await self.initialise_stream(self.stream_url) ret, frame = ( self.cap.read() if self.cap is not None else (False, None) @@ -95,12 +99,15 @@ def execute_capture(self) -> Generator[tuple[np.ndarray, float]]: 'Failed to read frame, trying to reinitialise stream. 
' f"Fail count: {fail_count}", ) - self.release_resources() - self.initialise_stream(self.stream_url) + await self.release_resources() + await self.initialise_stream(self.stream_url) # Switch to generic frame capture after 5 consecutive failures if fail_count >= 5 and not self.successfully_captured: print('Switching to generic frame capture method.') - yield from self.capture_generic_frames() + async for generic_frame, timestamp in ( + self.capture_generic_frames() + ): + yield generic_frame, timestamp return continue else: @@ -108,7 +115,7 @@ def execute_capture(self) -> Generator[tuple[np.ndarray, float]]: fail_count = 0 # Mark as successfully captured - self.successfully_captured = True # + self.successfully_captured = True # Process the frame if the capture interval has elapsed current_time = datetime.datetime.now() @@ -124,9 +131,9 @@ def execute_capture(self) -> Generator[tuple[np.ndarray, float]]: del frame, timestamp gc.collect() - time.sleep(0.01) # Adjust the sleep time as needed + await asyncio.sleep(0.01) # Adjust the sleep time as needed - self.release_resources() + await self.release_resources() def check_internet_speed(self) -> tuple[float, float]: """ @@ -137,7 +144,7 @@ def check_internet_speed(self) -> tuple[float, float]: """ st = speedtest.Speedtest() st.get_best_server() - download_speed = st.download() / 1_000_000 + download_speed = st.download() / 1_000_000 # Turn into Mbps upload_speed = st.upload() / 1_000_000 return download_speed, upload_speed @@ -183,9 +190,9 @@ def select_quality_based_on_speed(self) -> str | None: print(f"Error selecting quality based on speed: {e}") return None - def capture_generic_frames( + async def capture_generic_frames( self, - ) -> Generator[tuple[np.ndarray, float]]: + ) -> AsyncGenerator[tuple[np.ndarray, float]]: """ Captures frames from a generic stream. 
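+
+        Yields:
+            Tuple[np.ndarray, float]: The captured frame and the
+                timestamp.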
@@ -199,7 +206,7 @@ def capture_generic_frames( return # Initialise the stream with the selected URL - self.initialise_stream(stream_url) + await self.initialise_stream(stream_url) last_process_time = datetime.datetime.now() fail_count = 0 # Counter for consecutive failures @@ -221,8 +228,8 @@ def capture_generic_frames( # Reinitialise the stream after 5 consecutive failures if fail_count >= 5 and not self.successfully_captured: print('Reinitialising the generic stream.') - self.release_resources() - time.sleep(5) + await self.release_resources() + await asyncio.sleep(5) stream_url = self.select_quality_based_on_speed() # Exit if no suitable stream quality is available @@ -231,7 +238,7 @@ def capture_generic_frames( continue # Reinitialise the stream with the new URL - self.initialise_stream(stream_url) + await self.initialise_stream(stream_url) fail_count = 0 continue else: @@ -253,7 +260,7 @@ def capture_generic_frames( del frame, timestamp gc.collect() - time.sleep(0.01) # Adjust the sleep time as needed + await asyncio.sleep(0.01) # Adjust the sleep time as needed def update_capture_interval(self, new_interval: int) -> None: """ @@ -265,9 +272,9 @@ def update_capture_interval(self, new_interval: int) -> None: self.capture_interval = new_interval -def main(): +async def main(): parser = argparse.ArgumentParser( - description='Capture video stream frames.', + description='Capture video stream frames asynchronously.', ) parser.add_argument( '--url', @@ -278,7 +285,7 @@ def main(): args = parser.parse_args() stream_capture = StreamCapture(args.url) - for frame, timestamp in stream_capture.execute_capture(): + async for frame, timestamp in stream_capture.execute_capture(): # Process the frame here print(f"Frame at {timestamp} displayed") # Release the frame resources @@ -287,4 +294,4 @@ def main(): if __name__ == '__main__': - main() + asyncio.run(main()) diff --git a/src/utils.py b/src/utils.py index a34ac9a..4ae3480 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,14 +1,19 @@ from __future__ import annotations +import asyncio import logging import os from datetime import datetime -from redis import Redis +import redis.asyncio as redis from watchdog.events import FileSystemEventHandler class Utils: + """ + A class to provide utility functions. + """ + @staticmethod def is_expired(expire_date_str: str | None) -> bool: """ @@ -28,6 +33,10 @@ def is_expired(expire_date_str: str | None) -> bool: class FileEventHandler(FileSystemEventHandler): + """ + A class to handle file events. + """ + def __init__(self, file_path: str, callback): """ Initialises the FileEventHandler instance. @@ -47,27 +56,32 @@ def on_modified(self, event): event (FileSystemEvent): The event object. """ if event.src_path == self.file_path: - self.callback() + # Run the callback function + asyncio.run(self.callback()) class RedisManager: + """ + A class to manage Redis operations. + """ + def __init__(self): """ Initialise the RedisManager by connecting to Redis. 
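+
+        Connection settings are read from the redis_host, redis_port
+        and redis_password environment variables, defaulting to an
+        unauthenticated local instance on port 6379.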
""" self.redis_host: str = os.getenv('redis_host', 'localhost') - self.redis_port: int = int(os.getenv('redis_port', '6379')) - self.redis_password: str | None = os.getenv('redis_password', None) + self.redis_port: int = int(os.getenv('redis_port', 6379)) + self.redis_password: str | None = os.getenv('redis_password') - # Set decode_responses=False to allow bytes storage - self.redis: Redis = Redis( + # Create Redis connection + self.redis = redis.Redis( host=self.redis_host, port=self.redis_port, password=self.redis_password, decode_responses=False, ) - def set(self, key: str, value: bytes) -> None: + async def set(self, key: str, value: bytes) -> None: """ Set a key-value pair in Redis. @@ -76,11 +90,11 @@ def set(self, key: str, value: bytes) -> None: value (bytes): The value to store (in bytes). """ try: - self.redis.set(key, value) + await self.redis.set(key, value) except Exception as e: logging.error(f"Error setting Redis key {key}: {str(e)}") - def get(self, key: str) -> bytes | None: + async def get(self, key: str) -> bytes | None: """ Retrieve a value from Redis based on the key. @@ -91,12 +105,12 @@ def get(self, key: str) -> bytes | None: bytes | None: The value if found, None otherwise. """ try: - return self.redis.get(key) + return await self.redis.get(key) except Exception as e: logging.error(f"Error retrieving Redis key {key}: {str(e)}") return None - def delete(self, key: str) -> None: + async def delete(self, key: str) -> None: """ Delete a key from Redis. @@ -104,6 +118,65 @@ def delete(self, key: str) -> None: key (str): The key to delete from Redis. """ try: - self.redis.delete(key) + await self.redis.delete(key) except Exception as e: logging.error(f"Error deleting Redis key {key}: {str(e)}") + + async def add_to_stream( + self, + stream_name: str, + data: dict, + maxlen: int = 10, + ) -> None: + """ + Add data to a Redis stream with a maximum length. + + Args: + stream_name (str): The name of the Redis stream. + data (dict): The data to add to the stream. + maxlen (int): The maximum length of the stream. + """ + try: + await self.redis.xadd(stream_name, data, maxlen=maxlen) + except Exception as e: + logging.error( + f"Error adding to Redis stream {stream_name}: {str(e)}", + ) + + async def read_from_stream( + self, + stream_name: str, + last_id: str = '0', + ) -> list: + """ + Read data from a Redis stream. + + Args: + stream_name (str): The name of the Redis stream. + last_id (str): The ID of the last read message. + + Returns: + list: A list of messages from the stream. + """ + try: + return await self.redis.xread({stream_name: last_id}) + except Exception as e: + logging.error( + f"Error reading from Redis stream {stream_name}: {str(e)}", + ) + return [] + + async def delete_stream(self, stream_name: str) -> None: + """ + Delete a Redis stream. + + Args: + stream_name (str): The name of the Redis stream to delete. 
+ """ + try: + await self.redis.delete(stream_name) + logging.info(f"Deleted Redis stream: {stream_name}") + except Exception as e: + logging.error( + f"Error deleting Redis stream {stream_name}: {str(e)}", + ) diff --git a/tests/src/live_stream_detection_test.py b/tests/src/live_stream_detection_test.py index 6a6098a..b51fc64 100644 --- a/tests/src/live_stream_detection_test.py +++ b/tests/src/live_stream_detection_test.py @@ -8,6 +8,7 @@ import cv2 import numpy as np +import pytest from src.live_stream_detection import LiveStreamDetector from src.live_stream_detection import main @@ -22,7 +23,7 @@ def setUp(self) -> None: """ Set up the LiveStreamDetector instance for tests. """ - self.api_url: str = 'http://localhost:5000' + self.api_url: str = 'http://127.0.0.1:8001' self.model_key: str = 'yolo11n' self.output_folder: str = 'test_output' self.detect_with_server: bool = False @@ -33,24 +34,17 @@ def setUp(self) -> None: detect_with_server=self.detect_with_server, ) - @patch( - 'src.live_stream_detection.LiveStreamDetector.requests_retry_session', - ) @patch('PIL.ImageFont.truetype') def test_initialisation( self, mock_truetype: MagicMock, - mock_requests_retry_session: MagicMock, ) -> None: """ Test the initialisation of the LiveStreamDetector instance. Args: mock_truetype (MagicMock): Mock for PIL.ImageFont.truetype. - mock_requests_retry_session (MagicMock): Mock for - requests_retry_session. """ - mock_requests_retry_session.return_value = MagicMock() mock_truetype.return_value = MagicMock() detector = LiveStreamDetector( @@ -65,13 +59,13 @@ def test_initialisation( self.assertEqual(detector.model_key, self.model_key) self.assertEqual(detector.output_folder, self.output_folder) self.assertEqual(detector.detect_with_server, self.detect_with_server) - self.assertIsNotNone(detector.session) self.assertEqual(detector.access_token, None) self.assertEqual(detector.token_expiry, 0.0) @patch('src.live_stream_detection.cv2.VideoCapture') @patch('src.live_stream_detection.AutoDetectionModel.from_pretrained') - def test_generate_detections_local( + @pytest.mark.asyncio + async def test_generate_detections_local( self, mock_from_pretrained: MagicMock, mock_video_capture: MagicMock, @@ -103,11 +97,11 @@ def test_generate_detections_local( ] mock_model.predict.return_value = mock_result - datas: list[list[Any]] = self.detector.generate_detections_local( + datas: list[list[Any]] = await self.detector.generate_detections_local( frame, ) - # Assert the structure and types of the detection data + # Assert the structure and types of the detection dataㄋ self.assertIsInstance(datas, list) for data in datas: self.assertIsInstance(data, list) @@ -119,23 +113,11 @@ def test_generate_detections_local( self.assertIsInstance(data[4], float) self.assertIsInstance(data[5], int) - @patch('src.live_stream_detection.cv2.destroyAllWindows') - @patch('src.live_stream_detection.LiveStreamDetector.generate_detections') - def test_run_detection( - self, - mock_generate_detections: MagicMock, - mock_destroyAllWindows: MagicMock, - ) -> None: + @pytest.mark.asyncio + async def test_run_detection(self) -> None: """ Test the run_detection method. - - Args: - mock_generate_detections (MagicMock): Mock for generate_detections. - mock_destroyAllWindows (MagicMock): Mock for cv2.destroyAllWindows. 
""" - mock_generate_detections.return_value = ( - [], np.zeros((480, 640, 3), dtype=np.uint8), - ) stream_url: str = 'http://example.com/virtual_stream' cap_mock: MagicMock = MagicMock() cap_mock.read.side_effect = [ @@ -154,35 +136,33 @@ def test_run_detection( 'src.live_stream_detection.cv2.waitKey', side_effect=[-1, ord('q')], ): - self.detector.run_detection(stream_url) + await self.detector.run_detection(stream_url) cap_mock.read.assert_called() cap_mock.release.assert_called_once() - mock_destroyAllWindows.assert_called() - @patch('src.live_stream_detection.requests.Session.post') - @patch('src.live_stream_detection.LiveStreamDetector.authenticate') - def test_generate_detections_cloud( + @pytest.mark.asyncio + @patch('aiohttp.ClientSession.post') + async def test_generate_detections_cloud( self, - mock_authenticate: MagicMock, mock_post: MagicMock, ) -> None: """ Test cloud detection generation. Args: - mock_authenticate (MagicMock): Mock for authenticate. - mock_post (MagicMock): Mock for requests.Session.post. + mock_post (MagicMock): Mock for aiohttp.ClientSession.post. """ frame: np.ndarray = np.zeros((480, 640, 3), dtype=np.uint8) - # mat_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) mock_response: MagicMock = MagicMock() mock_response.json.return_value = [ [10, 10, 50, 50, 0.9, 0], [20, 20, 60, 60, 0.8, 1], ] - mock_post.return_value = mock_response + mock_post.return_value.__aenter__.return_value = mock_response - datas: list[list[Any]] = self.detector.generate_detections_cloud(frame) + datas: list[list[Any]] = await self.detector.generate_detections_cloud( + frame, + ) # Assert the structure and types of the detection data self.assertIsInstance(datas, list) @@ -196,26 +176,23 @@ def test_generate_detections_cloud( self.assertIsInstance(data[4], float) self.assertIsInstance(data[5], int) - @patch('src.live_stream_detection.LiveStreamDetector.ensure_authenticated') - @patch('src.live_stream_detection.requests.Session.post') - def test_authenticate( + @pytest.mark.asyncio + @patch('aiohttp.ClientSession.post') + async def test_authenticate( self, mock_post: MagicMock, - mock_ensure_authenticated: MagicMock, ) -> None: """ Test the authentication process. Args: mock_post (MagicMock): Mock for requests.Session.post. - mock_ensure_authenticated (MagicMock): Mock - for ensure_authenticated. """ mock_response: MagicMock = MagicMock() mock_response.json.return_value = {'access_token': 'fake_token'} - mock_post.return_value = mock_response + mock_post.return_value.__aenter__.return_value = mock_response - self.detector.authenticate() + await self.detector.authenticate() # Assert the access token and token expiry self.assertEqual(self.detector.access_token, 'fake_token') @@ -231,7 +208,8 @@ def test_token_expired(self) -> None: self.detector.token_expiry = time.time() + 1000 self.assertFalse(self.detector.token_expired()) - def test_ensure_authenticated(self) -> None: + @pytest.mark.asyncio + async def test_ensure_authenticated(self) -> None: """ Test the ensure_authenticated method. 
""" @@ -239,7 +217,7 @@ def test_ensure_authenticated(self) -> None: with patch.object( self.detector, 'authenticate', ) as mock_authenticate: - self.detector.ensure_authenticated() + await self.detector.ensure_authenticated() mock_authenticate.assert_called_once() def test_remove_overlapping_labels(self) -> None: @@ -293,24 +271,26 @@ def test_overlap_percentage(self) -> None: overlap = self.detector.overlap_percentage(bbox1, bbox2) self.assertAlmostEqual(overlap, 0.262344, places=6) - @patch('src.live_stream_detection.requests.Session.post') - def test_authenticate_error(self, mock_post: MagicMock) -> None: + @pytest.mark.asyncio + @patch('aiohttp.ClientSession.post') + async def test_authenticate_error(self, mock_post: MagicMock) -> None: """ Test the authenticate method with error response. Args: - mock_post (MagicMock): Mock for requests.Session.post. + mock_post (MagicMock): Mock for aiohttp.ClientSession.post. """ mock_response: MagicMock = MagicMock() mock_response.json.return_value = {'msg': 'Authentication failed'} - mock_post.return_value = mock_response + mock_post.return_value.__aenter__.return_value = mock_response with self.assertRaises(Exception) as context: - self.detector.authenticate() + await self.detector.authenticate() self.assertIn('Authentication failed', str(context.exception)) - def test_generate_detections(self) -> None: + @pytest.mark.asyncio + async def test_generate_detections(self) -> None: """ Test the generate_detections method. """ @@ -321,7 +301,7 @@ def test_generate_detections(self) -> None: self.detector, 'generate_detections_local', return_value=[[10, 10, 50, 50, 0.9, 0]], ) as mock_local: - datas, _ = self.detector.generate_detections(mat_frame) + datas, _ = await self.detector.generate_detections(mat_frame) self.assertEqual(len(datas), 1) self.assertEqual(datas[0][5], 0) mock_local.assert_called_once_with(mat_frame) @@ -331,12 +311,13 @@ def test_generate_detections(self) -> None: self.detector, 'generate_detections_cloud', return_value=[[20, 20, 60, 60, 0.8, 1]], ) as mock_cloud: - datas, _ = self.detector.generate_detections(mat_frame) + datas, _ = await self.detector.generate_detections(mat_frame) self.assertEqual(len(datas), 1) self.assertEqual(datas[0][5], 1) mock_cloud.assert_called_once_with(mat_frame) - def test_run_detection_fail_read_frame(self) -> None: + @pytest.mark.asyncio + async def test_run_detection_fail_read_frame(self) -> None: """ Test the run_detection method when failing to read a frame. """ @@ -355,7 +336,7 @@ def test_run_detection_fail_read_frame(self) -> None: side_effect=[-1, -1, ord('q')], ): try: - self.detector.run_detection(stream_url) + await self.detector.run_detection(stream_url) except StopIteration: pass @@ -431,7 +412,8 @@ def test_remove_completely_contained_labels(self) -> None: ], ) @patch('src.live_stream_detection.LiveStreamDetector.run_detection') - def test_main(self, mock_run_detection: MagicMock) -> None: + @pytest.mark.asyncio + async def test_main(self, mock_run_detection: MagicMock) -> None: """ Test the main function. 
@@ -444,9 +426,9 @@ def test_main(self, mock_run_detection: MagicMock) -> None: return_value=None, ) as mock_init: mock_init.return_value = None - main() + await main() mock_init.assert_called_once_with( - api_url='http://localhost:5000', + api_url='http://127.0.0.1:8001', model_key='yolo11n', output_folder=None, detect_with_server=True, diff --git a/tests/src/stream_capture_test.py b/tests/src/stream_capture_test.py index fd95683..faed21e 100644 --- a/tests/src/stream_capture_test.py +++ b/tests/src/stream_capture_test.py @@ -7,6 +7,8 @@ from unittest.mock import MagicMock from unittest.mock import patch +import pytest + from src.stream_capture import main as stream_capture_main from src.stream_capture import StreamCapture @@ -19,22 +21,29 @@ class TestStreamCapture(TestCase): def setUp(self) -> None: """Set up a StreamCapture instance for use in tests.""" # Initialise StreamCapture instance with a presumed stream URL - self.stream_capture = StreamCapture('http://example.com/stream') + self.stream_capture: StreamCapture = StreamCapture( + 'http://example.com/stream', + ) @patch('cv2.VideoCapture') - def test_initialise_stream_success( + async def test_initialise_stream_success( self, mock_video_capture: MagicMock, ) -> None: """ Test that the stream is successfully initialised. + + Args: + mock_video_capture (MagicMock): Mock for cv2.VideoCapture. """ # Mock VideoCapture object's isOpened method to # return True, indicating the stream opened successfully mock_video_capture.return_value.isOpened.return_value = True # Call initialise_stream method to initialise the stream - self.stream_capture.initialise_stream(self.stream_capture.stream_url) + await self.stream_capture.initialise_stream( + self.stream_capture.stream_url, + ) # Assert that the cap object is successfully initialised self.assertIsNotNone(self.stream_capture.cap) @@ -45,17 +54,21 @@ def test_initialise_stream_success( ) # Release resources - self.stream_capture.release_resources() + await self.stream_capture.release_resources() @patch('cv2.VideoCapture') - @patch('time.sleep') - def test_initialise_stream_retry( + @patch('time.sleep', return_value=None) + async def test_initialise_stream_retry( self, mock_sleep: MagicMock, mock_video_capture: MagicMock, ) -> None: """ Test that the stream initialisation retries if it fails initially. + + Args: + mock_sleep (MagicMock): Mock for time.sleep. + mock_video_capture (MagicMock): Mock for cv2.VideoCapture. """ # Mock VideoCapture object's isOpened method to # return False on the first call and True on the second @@ -63,7 +76,9 @@ def test_initialise_stream_retry( instance.isOpened.side_effect = [False, True] # Call initialise_stream method to simulate retry mechanism - self.stream_capture.initialise_stream(self.stream_capture.stream_url) + await self.stream_capture.initialise_stream( + self.stream_capture.stream_url, + ) # Assert that the cap object is eventually successfully initialised self.assertIsNotNone(self.stream_capture.cap) @@ -71,22 +86,16 @@ def test_initialise_stream_retry( # Verify that sleep method was called once to wait before retrying mock_sleep.assert_called_once_with(5) - # Verify that VideoCapture object was initialised once - # and opened successfully on the second attempt - self.assertEqual(mock_video_capture.call_count, 1) - self.assertEqual(instance.open.call_count, 1) - - def test_release_resources(self) -> None: + async def test_release_resources(self) -> None: """ Test that resources are released correctly. 
""" # Initialise StreamCapture instance and mock cap object - stream_capture = StreamCapture('test_stream_url') + stream_capture: StreamCapture = StreamCapture('test_stream_url') stream_capture.cap = MagicMock() - stream_capture.cap.release = MagicMock() # Call release_resources method to release resources - stream_capture.release_resources() + await stream_capture.release_resources() # Assert that cap object is set to None self.assertIsNone(stream_capture.cap) @@ -94,7 +103,7 @@ def test_release_resources(self) -> None: @patch('cv2.VideoCapture') @patch('cv2.Mat') @patch('time.sleep', return_value=None) - def test_execute_capture( + async def test_execute_capture( self, mock_sleep: MagicMock, mock_mat: MagicMock, @@ -102,6 +111,10 @@ def test_execute_capture( ) -> None: """ Test that frames are captured and returned with a timestamp. + + Args: + mock_sleep (MagicMock): Mock for time.sleep. + mock_video_capture (MagicMock): Mock for cv2.VideoCapture. """ # Mock VideoCapture object's read method to # return a frame and True indicating successful read @@ -110,7 +123,7 @@ def test_execute_capture( # Execute capture frame generator and get the first frame and timestamp generator = self.stream_capture.execute_capture() - frame, timestamp = next(generator) + frame, timestamp = await generator.__anext__() # Assert that the captured frame is not None # and the timestamp is a float @@ -118,12 +131,15 @@ def test_execute_capture( self.assertIsInstance(timestamp, float) # Release resources - self.stream_capture.release_resources() + await self.stream_capture.release_resources() @patch('speedtest.Speedtest') def test_check_internet_speed(self, mock_speedtest: MagicMock) -> None: """ Test that internet speed is correctly checked and returned. + + Args: + mock_speedtest (MagicMock): Mock for speedtest.Speedtest. """ # Mock Speedtest object's download and upload methods # to return download and upload speeds @@ -146,6 +162,9 @@ def test_select_quality_based_on_speed_high_speed( """ Test that the highest quality stream is selected for high internet speed. + + Args: + mock_streams (MagicMock): Mock for streamlink.streams. """ # Mock streamlink to return different quality streams mock_streams.return_value = { @@ -175,6 +194,9 @@ def test_select_quality_based_on_speed_medium_speed( """ Test that an appropriate quality stream is selected for medium internet speed. + + Args: + mock_streams (MagicMock): Mock for streamlink.streams. """ # Mock streamlink to return medium quality streams mock_streams.return_value = { @@ -202,6 +224,9 @@ def test_select_quality_based_on_speed_low_speed( ) -> None: """ Test that a lower quality stream is selected for low internet speed. + + Args: + mock_streams (MagicMock): Mock for streamlink.streams. """ # Mock streamlink to return low quality streams mock_streams.return_value = { @@ -231,6 +256,10 @@ def test_select_quality_based_on_speed_no_quality( ) -> None: """ Test that None is returned if no suitable stream quality is available. + + Args: + mock_check_speed (MagicMock): Mock for check_internet_speed method. + mock_streams (MagicMock): Mock for streamlink.streams. 
""" # Mock internet speed and stream quality check result to be empty selected_quality = self.stream_capture.select_quality_based_on_speed() @@ -246,7 +275,7 @@ def test_select_quality_based_on_speed_no_quality( @patch.object(StreamCapture, 'check_internet_speed', return_value=(20, 5)) @patch('cv2.VideoCapture') @patch('time.sleep', return_value=None) - def test_capture_generic_frames( + async def test_capture_generic_frames( self, mock_sleep: MagicMock, mock_video_capture: MagicMock, @@ -255,6 +284,12 @@ def test_capture_generic_frames( ) -> None: """ Test that generic frames are captured and returned with a timestamp. + + Args: + mock_sleep (MagicMock): Mock for time.sleep. + mock_video_capture (MagicMock): Mock for cv2.VideoCapture. + mock_check_speed (MagicMock): Mock for check_internet_speed method. + mock_streams (MagicMock): Mock for streamlink.streams. """ # Mock VideoCapture object's behaviour mock_video_capture.return_value.read.return_value = (True, MagicMock()) @@ -262,14 +297,14 @@ def test_capture_generic_frames( # Execute capture frame generator generator = self.stream_capture.capture_generic_frames() - frame, timestamp = next(generator) + frame, timestamp = await generator.__anext__() # Verify the returned frame and timestamp self.assertIsNotNone(frame) self.assertIsInstance(timestamp, float) # Release resources - self.stream_capture.release_resources() + await self.stream_capture.release_resources() def test_update_capture_interval(self) -> None: """ @@ -279,53 +314,53 @@ def test_update_capture_interval(self) -> None: self.stream_capture.update_capture_interval(20) self.assertEqual(self.stream_capture.capture_interval, 20) - @patch('src.stream_capture.StreamCapture') @patch('argparse.ArgumentParser.parse_args') - def test_main_function( + async def test_main_function( self, mock_parse_args: MagicMock, - mock_stream_capture: MagicMock, ) -> None: """ Test that the main function correctly initialises and executes StreamCapture. + + Args: + mock_parse_args (MagicMock): + Mock for argparse.ArgumentParser.parse_args. """ # Mock command line argument parsing mock_parse_args.return_value = argparse.Namespace( url='test_stream_url', ) - # Create mock instance of StreamCapture + # Mock command line argument parsing mock_capture_instance = MagicMock() - mock_stream_capture.return_value = mock_capture_instance - - # Mock execute_capture() method to return data - mock_capture_instance.execute_capture.return_value = iter( - [('frame_data', 1234567890.0)], - ) - - # Simulate executing the main function - with patch.object( - sys, 'argv', ['stream_capture.py', '--url', 'test_stream_url'], + with patch( + 'src.stream_capture.StreamCapture', + return_value=mock_capture_instance, ): - stream_capture_main() - - # Verify that StreamCapture was correctly initialised and called - mock_stream_capture.assert_called_once_with('test_stream_url') - mock_capture_instance.execute_capture.assert_called_once() + with patch.object( + sys, 'argv', ['stream_capture.py', '--url', 'test_stream_url'], + ): + await stream_capture_main() + mock_capture_instance.execute_capture.assert_called_once() @patch('cv2.VideoCapture') @patch('time.sleep', return_value=None) - def test_execute_capture_failures( + @pytest.mark.asyncio + async def test_execute_capture_failures( self, mock_sleep: MagicMock, mock_video_capture: MagicMock, ) -> None: """ Test that execute_capture handles multiple failures before success. + + Args: + mock_sleep (MagicMock): Mock for time.sleep. 
+ mock_video_capture (MagicMock): Mock for cv2.VideoCapture. """ # Mock VideoCapture object's multiple failures and one success read - instance = mock_video_capture.return_value + instance: MagicMock = mock_video_capture.return_value instance.read.side_effect = [(False, None)] * 5 + [(True, MagicMock())] instance.isOpened.return_value = True @@ -336,7 +371,7 @@ def test_execute_capture_failures( return_value=iter([(MagicMock(), 1234567890.0)]), ) as mock_capture_generic_frames: generator = self.stream_capture.execute_capture() - frame, timestamp = next(generator) + frame, timestamp = await generator.__anext__() self.assertIsNotNone(frame) self.assertIsInstance(timestamp, float) @@ -344,7 +379,7 @@ def test_execute_capture_failures( mock_capture_generic_frames.assert_called_once() # Release resources - self.stream_capture.release_resources() + await self.stream_capture.release_resources() if __name__ == '__main__': diff --git a/tests/src/utils_test.py b/tests/src/utils_test.py index a112a11..3953cfd 100644 --- a/tests/src/utils_test.py +++ b/tests/src/utils_test.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock from unittest.mock import patch +import pytest from watchdog.events import FileModifiedEvent from src.utils import FileEventHandler @@ -28,8 +29,9 @@ def test_is_expired_with_none(self): self.assertFalse(Utils.is_expired(None)) +@pytest.mark.asyncio class TestFileEventHandler(unittest.TestCase): - def test_on_modified_triggers_callback(self): + async def test_on_modified_triggers_callback(self): # Create a mock callback function mock_callback = MagicMock() @@ -45,12 +47,14 @@ def test_on_modified_triggers_callback(self): event = FileModifiedEvent(file_path) # Trigger the on_modified event - event_handler.on_modified(event) + await event_handler.on_modified(event) # Assert that the callback was called mock_callback.assert_called_once() - def test_on_modified_does_not_trigger_callback_for_different_file(self): + async def test_on_modified_does_not_trigger_callback_for_different_file( + self, + ): # Create a mock callback function mock_callback = MagicMock() @@ -67,17 +71,18 @@ def test_on_modified_does_not_trigger_callback_for_different_file(self): event = FileModifiedEvent(different_file_path) # Trigger the on_modified event - event_handler.on_modified(event) + await event_handler.on_modified(event) # Assert that the callback was not called mock_callback.assert_not_called() +@pytest.mark.asyncio class TestRedisManager(unittest.TestCase): """ Test cases for the RedisManager class """ - @patch('src.utils.Redis') + @patch('src.utils.redis.Redis') def setUp(self, mock_redis): """ Set up a RedisManager instance with a mocked Redis connection @@ -89,7 +94,7 @@ def setUp(self, mock_redis): # Initialize RedisManager self.redis_manager = RedisManager() - def test_set_success(self): + async def test_set_success(self): """ Test successful set operation """ @@ -97,12 +102,12 @@ def test_set_success(self): value = b'test_value' # Call the set method - self.redis_manager.set(key, value) + await self.redis_manager.set(key, value) # Assert that the Redis set method was called with correct parameters self.mock_redis_instance.set.assert_called_once_with(key, value) - def test_set_error(self): + async def test_set_error(self): """ Simulate an exception during the Redis set operation """ @@ -112,9 +117,9 @@ def test_set_error(self): # Call the set method and verify it handles the exception with self.assertLogs(level='ERROR'): - self.redis_manager.set(key, value) + await self.redis_manager.set(key, 
value) - def test_get_success(self): + async def test_get_success(self): """ Mock the Redis get method to return a value """ @@ -123,7 +128,7 @@ def test_get_success(self): self.mock_redis_instance.get.return_value = expected_value # Call the get method - value = self.redis_manager.get(key) + value = await self.redis_manager.get(key) # Assert that the Redis get method was called with correct parameters self.mock_redis_instance.get.assert_called_once_with(key) @@ -131,7 +136,7 @@ def test_get_success(self): # Assert the value returned is correct self.assertEqual(value, expected_value) - def test_get_error(self): + async def test_get_error(self): """ Simulate an exception during the Redis get operation """ @@ -140,23 +145,23 @@ def test_get_error(self): # Call the get method and verify it handles the exception with self.assertLogs(level='ERROR'): - value = self.redis_manager.get(key) + value = await self.redis_manager.get(key) self.assertIsNone(value) - def test_delete_success(self): + async def test_delete_success(self): """ Test successful delete operation """ key = 'test_key' # Call the delete method - self.redis_manager.delete(key) + await self.redis_manager.delete(key) # Assert that the Redis delete method # was called with correct parameters self.mock_redis_instance.delete.assert_called_once_with(key) - def test_delete_error(self): + async def test_delete_error(self): """ Simulate an exception during the Redis delete operation """ @@ -165,7 +170,7 @@ def test_delete_error(self): # Call the delete method and verify it handles the exception with self.assertLogs(level='ERROR'): - self.redis_manager.delete(key) + await self.redis_manager.delete(key) if __name__ == '__main__':
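
Note: the tests above stub aiohttp's async context manager protocol via
mock_post.return_value.__aenter__.return_value. The sketch below is a
minimal, self-contained illustration of that pattern under stated
assumptions: fetch_auth_token and its URL are hypothetical stand-ins, not
the project's API, and response.json() is given an AsyncMock because
aiohttp's json() is a coroutine.

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp


async def fetch_auth_token(url: str) -> dict:
    # Hypothetical helper; stands in for an authenticate-style method
    async with aiohttp.ClientSession() as session:
        async with session.post(url) as response:
            return await response.json()


async def demo() -> None:
    mock_response = MagicMock()
    # aiohttp's response.json() is a coroutine, so mock it with AsyncMock
    mock_response.json = AsyncMock(
        return_value={'msg': 'Authentication failed'},
    )
    with patch('aiohttp.ClientSession.post') as mock_post:
        # `async with session.post(...)` obtains the response via __aenter__
        mock_post.return_value.__aenter__.return_value = mock_response
        data = await fetch_auth_token('http://127.0.0.1:8001/api/token')
    assert data == {'msg': 'Authentication failed'}


asyncio.run(demo())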
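Note: `frame, timestamp = await generator.__anext__()` in the stream-capture
tests is the async counterpart of `frame, timestamp = next(generator)`: it
pulls a single item from an async generator such as execute_capture without
exhausting it. A minimal sketch, with fake_frame_stream as a hypothetical
stand-in for the real capture generator:

import asyncio
import time
from collections.abc import AsyncIterator


async def fake_frame_stream() -> AsyncIterator[tuple[str, float]]:
    # Hypothetical async generator yielding (frame, timestamp) pairs,
    # loosely shaped like StreamCapture.execute_capture
    for i in range(3):
        await asyncio.sleep(0)  # Yield control, as a capture loop would
        yield f'frame-{i}', time.time()


async def demo() -> None:
    generator = fake_frame_stream()
    # Async counterpart of `frame, timestamp = next(generator)`
    frame, timestamp = await generator.__anext__()
    assert frame == 'frame-0'
    assert isinstance(timestamp, float)
    # Ordinary consumption would use `async for` instead
    async for frame, timestamp in generator:
        print(frame, timestamp)


asyncio.run(demo())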
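Note: pytest-asyncio marks do not apply to async test methods defined on
unittest.TestCase subclasses such as TestRedisManager above; the
standard-library way to run such methods is unittest.IsolatedAsyncioTestCase
(Python 3.8+), which awaits async def test_* natively with no plugin. A
minimal sketch under that assumption, with FakeAsyncStore as a hypothetical
stand-in rather than the project's RedisManager:

import unittest
from unittest.mock import AsyncMock


class FakeAsyncStore:
    # Hypothetical async wrapper, loosely shaped like RedisManager
    def __init__(self, client) -> None:
        self.client = client

    async def get(self, key: str) -> bytes | None:
        return await self.client.get(key)


class TestFakeAsyncStore(unittest.IsolatedAsyncioTestCase):
    async def test_get(self) -> None:
        # AsyncMock makes client.get awaitable, unlike a plain MagicMock
        client = AsyncMock()
        client.get.return_value = b'value'
        store = FakeAsyncStore(client)
        self.assertEqual(await store.get('key'), b'value')
        client.get.assert_awaited_once_with('key')


if __name__ == '__main__':
    unittest.main()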