Skip to content

Commit

Permalink
feat: add redis support for task state management
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin.zhang committed Apr 10, 2024
1 parent a0944fa commit 3d45348
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 46 deletions.
4 changes: 2 additions & 2 deletions app/controllers/v1/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"request_id": request_id,
"params": body.dict(),
}
sm.update_task(task_id)
sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
Expand All @@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
endpoint = endpoint.rstrip("/")

request_id = base.get_task_id(request)
task = sm.get_task(task_id)
task = sm.state.get_task(task_id)
if task:
task_dir = utils.task_dir()

Expand Down
125 changes: 93 additions & 32 deletions app/services/state.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,96 @@
# State Management
# This module is responsible for managing the state of the application.
import math
import ast
import json
from abc import ABC, abstractmethod
import redis
from app.config import config
from app.models import const

# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
# 如果你的应用程序是单机的,你可以使用内存来存储状态。

# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
# Base class for state management
class BaseState(ABC):

from app.models import const
from app.utils import utils

_tasks = {}


def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
"""
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100

_tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}

def get_task(task_id: str):
"""
Get the state of the task.
"""
return _tasks.get(task_id, None)
@abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass

@abstractmethod
def get_task(self, task_id: str):
pass


# Memory state management
class MemoryState(BaseState):

def __init__(self):
self._tasks = {}

def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100

self._tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}

def get_task(self, task_id: str):
return self._tasks.get(task_id, None)


# Redis state management
class RedisState(BaseState):

def __init__(self, host='localhost', port=6379, db=0):
self._redis = redis.StrictRedis(host=host, port=port, db=db)

def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100

fields = {
"state": state,
"progress": progress,
**kwargs,
}

for field, value in fields.items():
self._redis.hset(task_id, field, str(value))

def get_task(self, task_id: str):
task_data = self._redis.hgetall(task_id)
if not task_data:
return None

task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
return task

@staticmethod
def _convert_to_original_type(value):
"""
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')

try:
# try to convert byte string array to list
return ast.literal_eval(value_str)
except (ValueError, SyntaxError):
pass

if value_str.isdigit():
return int(value_str)
# Add more conversions here if needed
return value_str


# Global state
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)

state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()
22 changes: 11 additions & 11 deletions app/services/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
}
"""
logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)

video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
Expand All @@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
else:
logger.debug(f"video script: \n{video_script}")

sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)

logger.info("\n\n## generating video terms")
video_terms = params.video_terms
Expand All @@ -70,21 +70,21 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data))

sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)

logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return

audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration)

sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)

subtitle_path = ""
if params.subtitle_enabled:
Expand All @@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = ""

sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)

logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id,
Expand All @@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
)
if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return

sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)

final_video_paths = []
combined_video_paths = []
Expand All @@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
threads=n_threads)

_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
sm.state.update_task(task_id, progress=_progress)

final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")

Expand All @@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
)

_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
sm.state.update_task(task_id, progress=_progress)

final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path)
Expand All @@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
"videos": final_video_paths,
"combined_videos": combined_video_paths
}
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs
5 changes: 5 additions & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@

material_directory = ""

# Used for state management of the task
enable_redis = true
redis_host = "localhost"
redis_port = 6379
redis_db = 0

[whisper]
# Only effective when subtitle_provider is "whisper"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ pydantic~=2.6.3
g4f~=0.2.5.4
dashscope~=1.15.0
google.generativeai~=0.4.1
python-multipart~=0.0.9
python-multipart~=0.0.9
redis==5.0.3

0 comments on commit 3d45348

Please sign in to comment.