From 3d453486627234937c7bfe6f176890360074696b Mon Sep 17 00:00:00 2001 From: "kevin.zhang" Date: Wed, 10 Apr 2024 10:42:56 +0800 Subject: [PATCH] feat: add redis support for task state management --- app/controllers/v1/video.py | 4 +- app/services/state.py | 125 +++++++++++++++++++++++++++--------- app/services/task.py | 22 +++---- config.example.toml | 5 ++ requirements.txt | 3 +- 5 files changed, 113 insertions(+), 46 deletions(-) diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index 9c5ef176..8b080266 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task "request_id": request_id, "params": body.dict(), } - sm.update_task(task_id) + sm.state.update_task(task_id) background_tasks.add_task(tm.start, task_id=task_id, params=body) logger.success(f"video created: {utils.to_json(task)}") return utils.get_response(200, task) @@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"), endpoint = endpoint.rstrip("/") request_id = base.get_task_id(request) - task = sm.get_task(task_id) + task = sm.state.get_task(task_id) if task: task_dir = utils.task_dir() diff --git a/app/services/state.py b/app/services/state.py index 606a2c15..0aa95efe 100644 --- a/app/services/state.py +++ b/app/services/state.py @@ -1,35 +1,96 @@ -# State Management -# This module is responsible for managing the state of the application. -import math +import ast +import json +from abc import ABC, abstractmethod +import redis +from app.config import config +from app.models import const -# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。 -# 如果你的应用程序是单机的,你可以使用内存来存储状态。 -# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database. -# If your application is single-node, you can use memory to store the state. +# Base class for state management +class BaseState(ABC): -from app.models import const -from app.utils import utils - -_tasks = {} - - -def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): - """ - Set the state of the task. - """ - progress = int(progress) - if progress > 100: - progress = 100 - - _tasks[task_id] = { - "state": state, - "progress": progress, - **kwargs, - } - -def get_task(task_id: str): - """ - Get the state of the task. - """ - return _tasks.get(task_id, None) + @abstractmethod + def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs): + pass + + @abstractmethod + def get_task(self, task_id: str): + pass + + +# Memory state management +class MemoryState(BaseState): + + def __init__(self): + self._tasks = {} + + def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): + progress = int(progress) + if progress > 100: + progress = 100 + + self._tasks[task_id] = { + "state": state, + "progress": progress, + **kwargs, + } + + def get_task(self, task_id: str): + return self._tasks.get(task_id, None) + + +# Redis state management +class RedisState(BaseState): + + def __init__(self, host='localhost', port=6379, db=0): + self._redis = redis.StrictRedis(host=host, port=port, db=db) + + def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): + progress = int(progress) + if progress > 100: + progress = 100 + + fields = { + "state": state, + "progress": progress, + **kwargs, + } + + for field, value in fields.items(): + self._redis.hset(task_id, field, str(value)) + + def get_task(self, task_id: str): + task_data = self._redis.hgetall(task_id) + if not task_data: + return None + + task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()} + return task + + @staticmethod + def _convert_to_original_type(value): + """ + Convert the value from byte string to its original data type. + You can extend this method to handle other data types as needed. + """ + value_str = value.decode('utf-8') + + try: + # try to convert byte string array to list + return ast.literal_eval(value_str) + except (ValueError, SyntaxError): + pass + + if value_str.isdigit(): + return int(value_str) + # Add more conversions here if needed + return value_str + + +# Global state +_enable_redis = config.app.get("enable_redis", False) +_redis_host = config.app.get("redis_host", "localhost") +_redis_port = config.app.get("redis_port", 6379) +_redis_db = config.app.get("redis_db", 0) + +state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState() diff --git a/app/services/task.py b/app/services/task.py index d5a62182..595ddc03 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -28,7 +28,7 @@ def start(task_id, params: VideoParams): } """ logger.info(f"start task: {task_id}") - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) video_subject = params.video_subject voice_name = voice.parse_voice_name(params.voice_name) @@ -44,7 +44,7 @@ def start(task_id, params: VideoParams): else: logger.debug(f"video script: \n{video_script}") - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) logger.info("\n\n## generating video terms") video_terms = params.video_terms @@ -70,13 +70,13 @@ def start(task_id, params: VideoParams): with open(script_file, "w", encoding="utf-8") as f: f.write(utils.to_json(script_data)) - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) logger.info("\n\n## generating audio") audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file) if sub_maker is None: - sm.update_task(task_id, state=const.TASK_STATE_FAILED) + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error( "failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.") return @@ -84,7 +84,7 @@ def start(task_id, params: VideoParams): audio_duration = voice.get_audio_duration(sub_maker) audio_duration = math.ceil(audio_duration) - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) subtitle_path = "" if params.subtitle_enabled: @@ -108,7 +108,7 @@ def start(task_id, params: VideoParams): logger.warning(f"subtitle file is invalid: {subtitle_path}") subtitle_path = "" - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) logger.info("\n\n## downloading videos") downloaded_videos = material.download_videos(task_id=task_id, @@ -119,12 +119,12 @@ def start(task_id, params: VideoParams): max_clip_duration=max_clip_duration, ) if not downloaded_videos: - sm.update_task(task_id, state=const.TASK_STATE_FAILED) + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error( "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") return - sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) final_video_paths = [] combined_video_paths = [] @@ -146,7 +146,7 @@ def start(task_id, params: VideoParams): threads=n_threads) _progress += 50 / params.video_count / 2 - sm.update_task(task_id, progress=_progress) + sm.state.update_task(task_id, progress=_progress) final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") @@ -160,7 +160,7 @@ def start(task_id, params: VideoParams): ) _progress += 50 / params.video_count / 2 - sm.update_task(task_id, progress=_progress) + sm.state.update_task(task_id, progress=_progress) final_video_paths.append(final_video_path) combined_video_paths.append(combined_video_path) @@ -171,5 +171,5 @@ def start(task_id, params: VideoParams): "videos": final_video_paths, "combined_videos": combined_video_paths } - sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) + sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) return kwargs diff --git a/config.example.toml b/config.example.toml index 634706fd..14efeb82 100644 --- a/config.example.toml +++ b/config.example.toml @@ -129,6 +129,11 @@ material_directory = "" + # Used for state management of the task + enable_redis = true + redis_host = "localhost" + redis_port = 6379 + redis_db = 0 [whisper] # Only effective when subtitle_provider is "whisper" diff --git a/requirements.txt b/requirements.txt index 3098c5cc..28358af9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ pydantic~=2.6.3 g4f~=0.2.5.4 dashscope~=1.15.0 google.generativeai~=0.4.1 -python-multipart~=0.0.9 \ No newline at end of file +python-multipart~=0.0.9 +redis==5.0.3 \ No newline at end of file