diff --git a/sql_caching/middleware.py b/sql_caching/middleware.py index 24da6b0..2ef9372 100644 --- a/sql_caching/middleware.py +++ b/sql_caching/middleware.py @@ -8,11 +8,11 @@ ''' from threading import local import itertools +from django.utils.deprecation import MiddlewareMixin from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.datastructures import EmptyResultSet from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE - _thread_locals = local() @@ -30,7 +30,7 @@ def execute_sql_cache(self, *args, **kwargs): if hasattr(_thread_locals, 'query_cache'): sql = get_sql(self) # ('SELECT * FROM ...', (50)) <= sql string, args tuple - if sql[0][:6].upper() == 'SELECT': + if sql[0].upper().lstrip().startswith('SELECT'): # uses the tuple of sql + args as the cache key if sql in _thread_locals.query_cache: @@ -57,7 +57,7 @@ def execute_sql_cache(self, *args, **kwargs): return self._execute_sql(*args, **kwargs) -class QueryCacheMiddleware(object): +class QueryCacheMiddleware(MiddlewareMixin): def process_request(self, request): _thread_locals.query_cache = {}