Skip to content

Commit

Permalink
Refactor to a cache class
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Nov 11, 2024
1 parent 6199cbc commit 71945a8
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 170 deletions.
4 changes: 3 additions & 1 deletion config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,9 @@

# Compression level – an integer from 0 to 9. 0 means not compression
CACHE_ALL_TXS_COMPRESSION_LEVEL = env.int("CACHE_ALL_TXS_COMPRESSION_LEVEL", default=0)
DEFAULT_CACHE_PAGE_TIMEOUT = env.int("DEFAULT_CACHE_PAGE_TIMEOUT", default=60)
CACHE_VIEW_DEFAULT_TIMEOUT = env.int(
"DEFAULT_CACHE_PAGE_TIMEOUT", default=60
) # 0 will disable the cache

# Contracts reindex batch configuration
# ------------------------------------------------------------------------------
Expand Down
183 changes: 183 additions & 0 deletions safe_transaction_service/history/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import json
from functools import wraps
from typing import List, Optional, Union
from urllib.parse import urlencode

from django.conf import settings

from eth_typing import ChecksumAddress
from rest_framework import status
from rest_framework.response import Response

from safe_transaction_service.history.models import (
InternalTx,
ModuleTransaction,
MultisigConfirmation,
MultisigTransaction,
TokenTransfer,
)
from safe_transaction_service.utils.redis import get_redis, logger


class CacheSafeTxsView:
# Cache tags
LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY = "multisigtransactionsview"
LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY = "moduletransactionsview"
LIST_TRANSFERS_VIEW_CACHE_KEY = "transfersview"

def __init__(self, cache_tag: str, address: ChecksumAddress):
self.redis = get_redis()
self.address = address
self.cache_tag = cache_tag
self.cache_name = self.get_cache_name()

def get_cache_name(self) -> str:
"""
Calculate the cache_name from the cache_tag and address
:param cache_tag:
:param address:
:return:
"""
return f"{self.cache_tag}:{self.address}"

def get_cache_data(self, cache_path: str) -> Optional[str]:
"""
Return the cache for the provided cache_path
:param cache_path:
:return:
"""
logger.debug(f"Getting from cache {self.cache_name}{cache_path}")
return self.redis.hget(self.cache_name, cache_path)

def set_cache_data(self, cache_path: str, data: str, timeout: int):
"""
Set a cache for provided data with the provided timeout
:param cache_path:
:param data:
:param timeout:
:return:
"""
logger.debug(
f"Setting cache {self.cache_name}{cache_path} with TTL {timeout} seconds"
)
self.redis.hset(self.cache_name, cache_path, data)
self.redis.expire(self.cache_name, timeout)

def remove_cache(self):
"""
Remove cache key stored in redis for the provided parameters
:param cache_name:
:return:
"""
cache_name = self.cache_name
logger.debug(f"Removing all the cache for {self.cache_name}")
self.redis.unlink(self.cache_name)


def cache_txs_view_for_address(
cache_tag: str, timeout: int = settings.CACHE_VIEW_DEFAULT_TIMEOUT
):
"""
Custom cache decorator that caches the view response.
This decorator caches the response of a view function for a specified timeout.
It allows you to cache the response based on a unique cache name, which can
be used for invalidating.
:param timeout: Cache timeout in seconds.
:param cache_name: A unique identifier for the cache entry.
"""

def decorator(view_func):
@wraps(view_func)
def _wrapped_view(request, *args, **kwargs):
# Get query parameters
query_params = request.request.GET.dict()
cache_path = f"{urlencode(query_params)}"
# Calculate cache_name
address = request.kwargs["address"]
cache_txs_view: Optional[CacheSafeTxsView] = None
if address:
cache_txs_view = CacheSafeTxsView(cache_tag, address)
else:
logger.warning(
"Address does not exist in the request, this will not be cached"
)
cache_txs_view = None

if cache_txs_view:
# Check if response is cached
response_data = cache_txs_view.get_cache_data(cache_path)
if response_data:
return Response(
status=status.HTTP_200_OK, data=json.loads(response_data)
)

# Get response from the view
response = view_func(request, *args, **kwargs)
if response.status_code == 200:
# Just store success responses and if cache is enabled with DEFAULT_CACHE_PAGE_TIMEOUT !=0
if cache_txs_view:
cache_txs_view.set_cache_data(
cache_path, json.dumps(response.data), timeout
)

return response

return _wrapped_view

return decorator


def remove_cache_view_by_instance(
instance: Union[
TokenTransfer,
InternalTx,
MultisigConfirmation,
MultisigTransaction,
ModuleTransaction,
]
):
"""
Remove the cache stored for instance view.
:param instance:
"""
addresses = []
cache_tag: Optional[str] = None
if isinstance(instance, TokenTransfer):
cache_tag = CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY
addresses.append(instance.to)
addresses.append(instance._from)
elif isinstance(instance, MultisigTransaction):
cache_tag = CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY
addresses.append(instance.safe)
elif isinstance(instance, MultisigConfirmation) and instance.multisig_transaction:
cache_tag = CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY
addresses.append(instance.multisig_transaction.safe)
elif isinstance(instance, InternalTx):
cache_tag = CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY
addresses.append(instance.to)
if instance._from:
addresses.append(instance._from)
elif isinstance(instance, ModuleTransaction):
cache_tag = CacheSafeTxsView.LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY
addresses.append(instance.safe)

if cache_tag:
remove_cache_view_for_addresses(cache_tag, addresses)


def remove_cache_view_for_addresses(cache_tag: str, addresses: List[ChecksumAddress]):
"""
Remove several cache for the provided cache_tag and addresses
:param cache_tag:
:param addresses:
:return:
"""
for address in addresses:
cache_safe_txs = CacheSafeTxsView(cache_tag, address)
cache_safe_txs.remove_cache()
2 changes: 1 addition & 1 deletion safe_transaction_service/history/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from safe_transaction_service.notifications.tasks import send_notification_task

from ..events.services.queue_service import get_queue_service
from ..utils.redis import remove_cache_view_by_instance
from .cache import remove_cache_view_by_instance
from .models import (
ERC20Transfer,
ERC721Transfer,
Expand Down
45 changes: 39 additions & 6 deletions safe_transaction_service/history/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,24 @@
from safe_transaction_service.tokens.tests.factories import TokenFactory
from safe_transaction_service.utils.utils import datetime_to_str

from ...utils.redis import get_redis, remove_cache_view_by_instance
from ...utils.redis import get_redis
from ..cache import remove_cache_view_by_instance
from ..helpers import DelegateSignatureHelper, DeleteMultisigTxSignatureHelper
from ..models import (
IndexingStatus,
InternalTx,
ModuleTransaction,
MultisigConfirmation,
MultisigTransaction,
SafeContractDelegate,
SafeMasterCopy,
)
from ..serializers import TransferType
from ..views import SafeMultisigTransactionListView
from ..views import (
SafeModuleTransactionListView,
SafeMultisigTransactionListView,
SafeTransferListView,
)
from .factories import (
ERC20TransferFactory,
ERC721TransferFactory,
Expand Down Expand Up @@ -642,6 +649,19 @@ def test_get_module_transactions(self):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["count"], 1)

# Test that the result should be cached
# Mock get_queryset with empty queryset return value to get proper error in case of fail
with mock.patch.object(
SafeModuleTransactionListView,
"get_queryset",
return_value=ModuleTransaction.objects.none(),
) as patched_queryset:
response = self.client.get(url, format="json")
# queryset shouldn't be called
patched_queryset.assert_not_called()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["count"], 1)

def test_get_module_transaction(self):
wrong_module_transaction_id = "wrong_module_transaction_id"
url = reverse(
Expand Down Expand Up @@ -3084,14 +3104,27 @@ def test_transfers_view(self):
for result in response.data["results"]:
self.assertEqual(result["type"], TransferType.ETHER_TRANSFER.name)

response = self.client.get(
reverse("v1:history:transfers", args=(safe_address,)) + "?ether=false",
format="json",
)
url = reverse("v1:history:transfers", args=(safe_address,)) + "?ether=false"
response = self.client.get(url, format="json")
self.assertGreater(len(response.data["results"]), 0)
for result in response.data["results"]:
self.assertNotEqual(result["type"], TransferType.ETHER_TRANSFER.name)

# Test that the result should be cached
# Mock get_queryset with empty queryset return value to get proper error in case of fail
with mock.patch.object(
SafeTransferListView,
"get_queryset",
return_value=InternalTx.objects.none(),
) as patched_queryset:
response = self.client.get(url, format="json")
# queryset shouldn't be called
patched_queryset.assert_not_called()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertGreater(len(response.data["results"]), 0)
for result in response.data["results"]:
self.assertNotEqual(result["type"], TransferType.ETHER_TRANSFER.name)

def test_get_transfer_view(self):
# test wrong random transfer_id
transfer_id = FuzzyText(length=6).fuzz()
Expand Down
15 changes: 6 additions & 9 deletions safe_transaction_service/history/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@
from safe_transaction_service.utils.ethereum import get_chain_id
from safe_transaction_service.utils.utils import parse_boolean_query_param

from ..utils.redis import (
LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY,
LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY,
cache_page_for_address,
)
from . import filters, pagination, serializers
from .cache import CacheSafeTxsView, cache_txs_view_for_address
from .helpers import add_tokens_to_transfers, is_valid_unique_transfer_id
from .models import (
ERC20Transfer,
Expand Down Expand Up @@ -491,8 +487,8 @@ def get_queryset(self):
.order_by("-created")
)

@cache_page_for_address(
cache_tag=LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY, timeout=60
@cache_txs_view_for_address(
cache_tag=CacheSafeTxsView.LIST_MODULETRANSACTIONS_VIEW_CACHE_KEY
)
def get(self, request, address, format=None):
"""
Expand Down Expand Up @@ -692,8 +688,8 @@ def get_serializer_class(self):
),
},
)
@cache_page_for_address(
cache_tag=LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY, timeout=2
@cache_txs_view_for_address(
cache_tag=CacheSafeTxsView.LIST_MULTISIGTRANSACTIONS_VIEW_CACHE_KEY
)
def get(self, request, *args, **kwargs):
"""
Expand Down Expand Up @@ -981,6 +977,7 @@ def list(self, request, *args, **kwargs):
),
},
)
@cache_txs_view_for_address(CacheSafeTxsView.LIST_TRANSFERS_VIEW_CACHE_KEY)
def get(self, request, address, format=None):
"""
Returns the list of token transfers for a given Safe address.
Expand Down
Loading

0 comments on commit 71945a8

Please sign in to comment.