forked from harry0703/MoneyPrinterTurbo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add redis support for task state management
- Loading branch information
kevin.zhang
committed
Apr 10, 2024
1 parent
a0944fa
commit 3d45348
Showing
5 changed files
with
113 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,96 @@ | ||
# State Management | ||
# This module is responsible for managing the state of the application. | ||
import math | ||
import ast | ||
import json | ||
from abc import ABC, abstractmethod | ||
import redis | ||
from app.config import config | ||
from app.models import const | ||
|
||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。 | ||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。 | ||
|
||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database. | ||
# If your application is single-node, you can use memory to store the state. | ||
# Base class for state management | ||
class BaseState(ABC): | ||
|
||
from app.models import const | ||
from app.utils import utils | ||
|
||
_tasks = {} | ||
|
||
|
||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): | ||
""" | ||
Set the state of the task. | ||
""" | ||
progress = int(progress) | ||
if progress > 100: | ||
progress = 100 | ||
|
||
_tasks[task_id] = { | ||
"state": state, | ||
"progress": progress, | ||
**kwargs, | ||
} | ||
|
||
def get_task(task_id: str): | ||
""" | ||
Get the state of the task. | ||
""" | ||
return _tasks.get(task_id, None) | ||
@abstractmethod | ||
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs): | ||
pass | ||
|
||
@abstractmethod | ||
def get_task(self, task_id: str): | ||
pass | ||
|
||
|
||
# Memory state management | ||
class MemoryState(BaseState): | ||
|
||
def __init__(self): | ||
self._tasks = {} | ||
|
||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): | ||
progress = int(progress) | ||
if progress > 100: | ||
progress = 100 | ||
|
||
self._tasks[task_id] = { | ||
"state": state, | ||
"progress": progress, | ||
**kwargs, | ||
} | ||
|
||
def get_task(self, task_id: str): | ||
return self._tasks.get(task_id, None) | ||
|
||
|
||
# Redis state management | ||
class RedisState(BaseState): | ||
|
||
def __init__(self, host='localhost', port=6379, db=0): | ||
self._redis = redis.StrictRedis(host=host, port=port, db=db) | ||
|
||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): | ||
progress = int(progress) | ||
if progress > 100: | ||
progress = 100 | ||
|
||
fields = { | ||
"state": state, | ||
"progress": progress, | ||
**kwargs, | ||
} | ||
|
||
for field, value in fields.items(): | ||
self._redis.hset(task_id, field, str(value)) | ||
|
||
def get_task(self, task_id: str): | ||
task_data = self._redis.hgetall(task_id) | ||
if not task_data: | ||
return None | ||
|
||
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()} | ||
return task | ||
|
||
@staticmethod | ||
def _convert_to_original_type(value): | ||
""" | ||
Convert the value from byte string to its original data type. | ||
You can extend this method to handle other data types as needed. | ||
""" | ||
value_str = value.decode('utf-8') | ||
|
||
try: | ||
# try to convert byte string array to list | ||
return ast.literal_eval(value_str) | ||
except (ValueError, SyntaxError): | ||
pass | ||
|
||
if value_str.isdigit(): | ||
return int(value_str) | ||
# Add more conversions here if needed | ||
return value_str | ||
|
||
|
||
# Global state | ||
_enable_redis = config.app.get("enable_redis", False) | ||
_redis_host = config.app.get("redis_host", "localhost") | ||
_redis_port = config.app.get("redis_port", 6379) | ||
_redis_db = config.app.get("redis_db", 0) | ||
|
||
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters