From 71945a8f5d080569b20bbd2418d6ffcbaea13e44 Mon Sep 17 00:00:00 2001 From: moisses89 <7888669+moisses89@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:08:57 +0100 Subject: [PATCH] Refactor to a cache class --- config/settings/base.py | 4 +- safe_transaction_service/history/cache.py | 183 ++++++++++++++++++ safe_transaction_service/history/signals.py | 2 +- .../history/tests/test_views.py | 45 ++++- safe_transaction_service/history/views.py | 15 +- safe_transaction_service/utils/redis.py | 154 +-------------- 6 files changed, 233 insertions(+), 170 deletions(-) create mode 100644 safe_transaction_service/history/cache.py diff --git a/config/settings/base.py b/config/settings/base.py index 08608b772..c443e4217 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -683,7 +683,9 @@ # Compression level – an integer from 0 to 9. 0 means not compression CACHE_ALL_TXS_COMPRESSION_LEVEL = env.int("CACHE_ALL_TXS_COMPRESSION_LEVEL", default=0) -DEFAULT_CACHE_PAGE_TIMEOUT = env.int("DEFAULT_CACHE_PAGE_TIMEOUT", default=60) +CACHE_VIEW_DEFAULT_TIMEOUT = env.int( + "DEFAULT_CACHE_PAGE_TIMEOUT", default=60 +) # 0 will disable the cache # Contracts reindex batch configuration # ------------------------------------------------------------------------------ diff --git a/safe_transaction_service/history/cache.py b/safe_transaction_service/history/cache.py new file mode 100644 index 000000000..04b6639cf --- /dev/null +++ b/safe_transaction_service/history/cache.py @@ -0,0 +1,183 @@ +import json +from functools import wraps +from typing import List, Optional, Union +from urllib.parse import urlencode + +from django.conf import settings + +from eth_typing import ChecksumAddress +from rest_framework import status +from rest_framework.response import Response + +from safe_transaction_service.history.models import ( + InternalTx, + ModuleTransaction, + MultisigConfirmation, + MultisigTransaction, + TokenTransfer, +) +from safe_transaction_service.utils.redis import get_redis, logger + + +class CacheSafeTxsView: + # Cache tags + LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY = "multisigtransactionsview" + LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY = "moduletransactionsview" + LIST_TRANSFERS_VIEW_CACHE_KEY = "transfersview" + + def __init__(self, cache_tag: str, address: ChecksumAddress): + self.redis = get_redis() + self.address = address + self.cache_tag = cache_tag + self.cache_name = self.get_cache_name() + + def get_cache_name(self) -> str: + """ + Calculate the cache_name from the cache_tag and address + + :param cache_tag: + :param address: + :return: + """ + return f"{self.cache_tag}:{self.address}" + + def get_cache_data(self, cache_path: str) -> Optional[str]: + """ + Return the cache for the provided cache_path + + :param cache_path: + :return: + """ + logger.debug(f"Getting from cache {self.cache_name}{cache_path}") + return self.redis.hget(self.cache_name, cache_path) + + def set_cache_data(self, cache_path: str, data: str, timeout: int): + """ + Set a cache for provided data with the provided timeout + :param cache_path: + :param data: + :param timeout: + :return: + """ + logger.debug( + f"Setting cache {self.cache_name}{cache_path} with TTL {timeout} seconds" + ) + self.redis.hset(self.cache_name, cache_path, data) + self.redis.expire(self.cache_name, timeout) + + def remove_cache(self): + """ + Remove cache key stored in redis for the provided parameters + + :param cache_name: + :return: + """ + cache_name = self.cache_name + logger.debug(f"Removing all the cache for {self.cache_name}") + self.redis.unlink(self.cache_name) + + +def cache_txs_view_for_address( + cache_tag: str, timeout: int = settings.CACHE_VIEW_DEFAULT_TIMEOUT +): + """ + Custom cache decorator that caches the view response. + This decorator caches the response of a view function for a specified timeout. + It allows you to cache the response based on a unique cache name, which can + be used for invalidating. + + :param timeout: Cache timeout in seconds. + :param cache_name: A unique identifier for the cache entry. + """ + + def decorator(view_func): + @wraps(view_func) + def _wrapped_view(request, *args, **kwargs): + # Get query parameters + query_params = request.request.GET.dict() + cache_path = f"{urlencode(query_params)}" + # Calculate cache_name + address = request.kwargs["address"] + cache_txs_view: Optional[CacheSafeTxsView] = None + if address: + cache_txs_view = CacheSafeTxsView(cache_tag, address) + else: + logger.warning( + "Address does not exist in the request, this will not be cached" + ) + cache_txs_view = None + + if cache_txs_view: + # Check if response is cached + response_data = cache_txs_view.get_cache_data(cache_path) + if response_data: + return Response( + status=status.HTTP_200_OK, data=json.loads(response_data) + ) + + # Get response from the view + response = view_func(request, *args, **kwargs) + if response.status_code == 200: + # Just store success responses and if cache is enabled with DEFAULT_CACHE_PAGE_TIMEOUT !=0 + if cache_txs_view: + cache_txs_view.set_cache_data( + cache_path, json.dumps(response.data), timeout + ) + + return response + + return _wrapped_view + + return decorator + + +def remove_cache_view_by_instance( + instance: Union[ + TokenTransfer, + InternalTx, + MultisigConfirmation, + MultisigTransaction, + ModuleTransaction, + ] +): + """ + Remove the cache stored for instance view. + + :param instance: + """ + addresses = [] + cache_tag: Optional[str] = None + if isinstance(instance, TokenTransfer): + cache_tag = CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY + addresses.append(instance.to) + addresses.append(instance._from) + elif isinstance(instance, MultisigTransaction): + cache_tag = CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY + addresses.append(instance.safe) + elif isinstance(instance, MultisigConfirmation) and instance.multisig_transaction: + cache_tag = CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY + addresses.append(instance.multisig_transaction.safe) + elif isinstance(instance, InternalTx): + cache_tag = CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY + addresses.append(instance.to) + if instance._from: + addresses.append(instance._from) + elif isinstance(instance, ModuleTransaction): + cache_tag = CacheSafeTxsView.LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY + addresses.append(instance.safe) + + if cache_tag: + remove_cache_view_for_addresses(cache_tag, addresses) + + +def remove_cache_view_for_addresses(cache_tag: str, addresses: List[ChecksumAddress]): + """ + Remove several cache for the provided cache_tag and addresses + + :param cache_tag: + :param addresses: + :return: + """ + for address in addresses: + cache_safe_txs = CacheSafeTxsView(cache_tag, address) + cache_safe_txs.remove_cache() diff --git a/safe_transaction_service/history/signals.py b/safe_transaction_service/history/signals.py index 21ed2b9eb..c32a23ef4 100644 --- a/safe_transaction_service/history/signals.py +++ b/safe_transaction_service/history/signals.py @@ -10,7 +10,7 @@ from safe_transaction_service.notifications.tasks import send_notification_task from ..events.services.queue_service import get_queue_service -from ..utils.redis import remove_cache_view_by_instance +from .cache import remove_cache_view_by_instance from .models import ( ERC20Transfer, ERC721Transfer, diff --git a/safe_transaction_service/history/tests/test_views.py b/safe_transaction_service/history/tests/test_views.py index 7fa852912..6d7b9ddfd 100644 --- a/safe_transaction_service/history/tests/test_views.py +++ b/safe_transaction_service/history/tests/test_views.py @@ -33,17 +33,24 @@ from safe_transaction_service.tokens.tests.factories import TokenFactory from safe_transaction_service.utils.utils import datetime_to_str -from ...utils.redis import get_redis, remove_cache_view_by_instance +from ...utils.redis import get_redis +from ..cache import remove_cache_view_by_instance from ..helpers import DelegateSignatureHelper, DeleteMultisigTxSignatureHelper from ..models import ( IndexingStatus, + InternalTx, + ModuleTransaction, MultisigConfirmation, MultisigTransaction, SafeContractDelegate, SafeMasterCopy, ) from ..serializers import TransferType -from ..views import SafeMultisigTransactionListView +from ..views import ( + SafeModuleTransactionListView, + SafeMultisigTransactionListView, + SafeTransferListView, +) from .factories import ( ERC20TransferFactory, ERC721TransferFactory, @@ -642,6 +649,19 @@ def test_get_module_transactions(self): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["count"], 1) + # Test that the result should be cached + # Mock get_queryset with empty queryset return value to get proper error in case of fail + with mock.patch.object( + SafeModuleTransactionListView, + "get_queryset", + return_value=ModuleTransaction.objects.none(), + ) as patched_queryset: + response = self.client.get(url, format="json") + # queryset shouldn't be called + patched_queryset.assert_not_called() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + def test_get_module_transaction(self): wrong_module_transaction_id = "wrong_module_transaction_id" url = reverse( @@ -3084,14 +3104,27 @@ def test_transfers_view(self): for result in response.data["results"]: self.assertEqual(result["type"], TransferType.ETHER_TRANSFER.name) - response = self.client.get( - reverse("v1:history:transfers", args=(safe_address,)) + "?ether=false", - format="json", - ) + url = reverse("v1:history:transfers", args=(safe_address,)) + "?ether=false" + response = self.client.get(url, format="json") self.assertGreater(len(response.data["results"]), 0) for result in response.data["results"]: self.assertNotEqual(result["type"], TransferType.ETHER_TRANSFER.name) + # Test that the result should be cached + # Mock get_queryset with empty queryset return value to get proper error in case of fail + with mock.patch.object( + SafeTransferListView, + "get_queryset", + return_value=InternalTx.objects.none(), + ) as patched_queryset: + response = self.client.get(url, format="json") + # queryset shouldn't be called + patched_queryset.assert_not_called() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertGreater(len(response.data["results"]), 0) + for result in response.data["results"]: + self.assertNotEqual(result["type"], TransferType.ETHER_TRANSFER.name) + def test_get_transfer_view(self): # test wrong random transfer_id transfer_id = FuzzyText(length=6).fuzz() diff --git a/safe_transaction_service/history/views.py b/safe_transaction_service/history/views.py index ab5111f8b..5d582c380 100644 --- a/safe_transaction_service/history/views.py +++ b/safe_transaction_service/history/views.py @@ -38,12 +38,8 @@ from safe_transaction_service.utils.ethereum import get_chain_id from safe_transaction_service.utils.utils import parse_boolean_query_param -from ..utils.redis import ( - LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY, - LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY, - cache_page_for_address, -) from . import filters, pagination, serializers +from .cache import CacheSafeTxsView, cache_txs_view_for_address from .helpers import add_tokens_to_transfers, is_valid_unique_transfer_id from .models import ( ERC20Transfer, @@ -491,8 +487,8 @@ def get_queryset(self): .order_by("-created") ) - @cache_page_for_address( - cache_tag=LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY, timeout=60 + @cache_txs_view_for_address( + cache_tag=CacheSafeTxsView.LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY ) def get(self, request, address, format=None): """ @@ -692,8 +688,8 @@ def get_serializer_class(self): ), }, ) - @cache_page_for_address( - cache_tag=LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY, timeout=2 + @cache_txs_view_for_address( + cache_tag=CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY ) def get(self, request, *args, **kwargs): """ @@ -981,6 +977,7 @@ def list(self, request, *args, **kwargs): ), }, ) + @cache_txs_view_for_address(CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY) def get(self, request, address, format=None): """ Returns the list of token transfers for a given Safe address. diff --git a/safe_transaction_service/utils/redis.py b/safe_transaction_service/utils/redis.py index 9877f66fb..332697097 100644 --- a/safe_transaction_service/utils/redis.py +++ b/safe_transaction_service/utils/redis.py @@ -1,25 +1,10 @@ import copyreg -import json import logging -from functools import cache, wraps -from typing import List, Optional, Union -from urllib.parse import urlencode +from functools import cache from django.conf import settings -from eth_typing import ChecksumAddress from redis import Redis -from rest_framework import status -from rest_framework.response import Response - -from safe_transaction_service.contracts.models import Contract -from safe_transaction_service.history.models import ( - InternalTx, - ModuleTransaction, - MultisigConfirmation, - MultisigTransaction, - TokenTransfer, -) logger = logging.getLogger(__name__) @@ -32,140 +17,3 @@ def get_redis() -> Redis: copyreg.pickle(memoryview, lambda val: (memoryview, (bytes(val),))) return Redis.from_url(settings.REDIS_URL) - - -LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY = "multisigtransactionsview" -LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY = "moduletransactionsview" -LIST_TRANSFERS_VIEW_CACHE_KEY = "transfersview" - - -def get_cache_page_name(cache_tag: str, address: ChecksumAddress) -> str: - """ - Calculate the cache_name from the cache_tag and provided address - - :param cache_tag: - :param address: - :return: - """ - return f"{cache_tag}:{address}" - - -def cache_page_for_address( - cache_tag: str, timeout: int = settings.DEFAULT_CACHE_PAGE_TIMEOUT -): - """ - Custom cache decorator that caches the view response. - This decorator caches the response of a view function for a specified timeout. - It allows you to cache the response based on a unique cache name, which can - be used for invalidating. - - :param timeout: Cache timeout in seconds. - :param cache_name: A unique identifier for the cache entry. - """ - - def decorator(view_func): - @wraps(view_func) - def _wrapped_view(request, *args, **kwargs): - redis = get_redis() - # Get query parameters - query_params = request.request.GET.dict() - cache_path = f"{urlencode(query_params)}" - # Calculate cache_name - address = request.kwargs["address"] - if address: - cache_name = get_cache_page_name(cache_tag, address) - else: - logger.warning( - "Address does not exist in the request, this will not be cached" - ) - cache_name = None - - if cache_name: - # Check if response is cached - response_data = redis.hget(cache_name, cache_path) - if response_data: - logger.debug(f"Getting from cache {cache_name}{cache_path}") - return Response( - status=status.HTTP_200_OK, data=json.loads(response_data) - ) - - # Get response from the view - response = view_func(request, *args, **kwargs) - if response.status_code == 200: - # Just store if there were not issues calculating cache_name - if cache_name: - # We just store the success result - logger.debug( - f"Setting cache {cache_name}{cache_path} with TTL {timeout} seconds" - ) - redis.hset(cache_name, cache_path, json.dumps(response.data)) - redis.expire(cache_name, timeout) - - return response - - return _wrapped_view - - return decorator - - -def remove_cache_page_by_address(cache_tag: str, address: ChecksumAddress): - """ - Remove cache key stored in redis for the provided parameters - - :param cache_name: - :return: - """ - cache_name = get_cache_page_name(cache_tag, address) - - logger.debug(f"Removing all the cache for {cache_name}") - get_redis().unlink(cache_name) - - -def remove_cache_page_for_addresses(cache_tag: str, addresses: List[ChecksumAddress]): - """ - Remove cache for provided addresses - - :param cache_tag: - :param addresses: - :return: - """ - for address in addresses: - remove_cache_page_by_address(cache_tag, address) - - -def remove_cache_view_by_instance( - instance: Union[ - TokenTransfer, - InternalTx, - MultisigConfirmation, - MultisigTransaction, - ModuleTransaction, - Contract, - ] -): - """ - Remove the cache stored for instance view. - - :param instance: - """ - addresses = [] - cache_tag: Optional[str] = None - if isinstance(instance, TokenTransfer): - cache_tag = LIST_TRANSFERS_VIEW_CACHE_KEY - addresses.append(instance.to) - addresses.append(instance._from) - elif isinstance(instance, MultisigTransaction): - cache_tag = LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY - addresses.append(instance.safe) - elif isinstance(instance, MultisigConfirmation) and instance.multisig_transaction: - cache_tag = LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY - addresses.append(instance.multisig_transaction.safe) - elif isinstance(instance, InternalTx): - cache_tag = LIST_TRANSFERS_VIEW_CACHE_KEY - addresses.append(instance.to) - elif isinstance(instance, ModuleTransaction): - cache_tag = LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY - addresses.append(instance.safe) - - if cache_tag: - remove_cache_page_for_addresses(cache_tag, addresses)