diff --git a/app/__init__.py b/app/__init__.py index 7cdb960..0c187a3 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,14 +1,16 @@ import os -import redis from flask import Flask from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +from app.cache import Cache + db = SQLAlchemy() migrate = Migrate() -redis_client = redis.Redis(host=os.environ.get("REDIS_HOST", "localhost"), port=os.environ.get("REDIS_PORT", 6379), db=0) +# redis_client = redis.Redis(host=os.environ.get("REDIS_HOST", "localhost"), port=os.environ.get("REDIS_PORT", 6379), db=0) +app_cache = cache.Cache() def create_app(): @@ -22,18 +24,17 @@ def create_app(): with app.app_context(): from . import routes db.create_all() - init_redis() - print("started") + init_cache() return app -def init_redis(): +def init_cache(): from app.models import Auth keys = Auth.query.all() for key in keys: if key.token and key.user: - redis_client.set(key.token, key.user) + app_cache.set(key.token, key.user) diff --git a/app/cache.py b/app/cache.py new file mode 100644 index 0000000..2b73ece --- /dev/null +++ b/app/cache.py @@ -0,0 +1,36 @@ +import redis + + +class Cache: + def __init__(self, type="in_memory", **kwargs): + self.type = type + if type == "redis": + self.cache = redis.Redis(host=kwargs["host"], port=kwargs["port"], db=0) + else: + self.cache = {} + + self.func_map = { + "redis": {"set": self._set_redis_key, "get": self._get_redis_key}, + "in_memory": { + "set": self._set_in_memory_key, + "get": self._get_in_memory_key, + }, + } + + def set(self, key, value): + self.func_map[self.type]["set"](key, value) + + def get(self, key): + return self.func_map[self.type]["get"](key) + + def _set_redis_key(self, key: str, value): + self.cache.set(key, value) + + def _set_in_memory_key(self, key, value): + self.cache[key] = value + + def _get_redis_key(self, key): + return self.cache.get(key) + + def _get_in_memory_key(self, key): + return self.cache.get(key) diff --git a/app/utils.py b/app/utils.py index ad4cc90..40e0a8d 100644 --- a/app/utils.py +++ b/app/utils.py @@ -8,7 +8,7 @@ from slack import WebClient from flask import request, jsonify -from app import redis_client +from app import app_cache from app.models import User, Submission from app.constants import ( STANDUP_CHANNEL_ID, @@ -29,7 +29,7 @@ def check_authorization(*args, **kwargs): return func(*args, **kwargs) else: auth_key = request.headers.get("Authorization", "") - if redis_client.get(auth_key): + if app_cache.get(auth_key): return func(*args, **kwargs) else: return jsonify(