Skip to content

Commit

Permalink
Fixed #34901 -- Added async-compatible interface to session engines.
Browse files Browse the repository at this point in the history
Thanks Andrew-Chen-Wang for the initial implementation which was posted
to the Django forum thread about asyncifying contrib modules.
  • Loading branch information
bigfootjon authored and felixxm committed Mar 13, 2024
1 parent 33c06ca commit f5c3406
Show file tree
Hide file tree
Showing 12 changed files with 975 additions and 9 deletions.
4 changes: 3 additions & 1 deletion django/contrib/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,6 @@ def update_session_auth_hash(request, user):

async def aupdate_session_auth_hash(request, user):
"""See update_session_auth_hash()."""
return await sync_to_async(update_session_auth_hash)(request, user)
await request.session.acycle_key()
if hasattr(user, "get_session_auth_hash") and request.user == user:
await request.session.aset(HASH_SESSION_KEY, user.get_session_auth_hash())
159 changes: 159 additions & 0 deletions django/contrib/sessions/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import string
from datetime import datetime, timedelta

from asgiref.sync import sync_to_async

from django.conf import settings
from django.core import signing
from django.utils import timezone
Expand Down Expand Up @@ -56,6 +58,10 @@ def __setitem__(self, key, value):
self._session[key] = value
self.modified = True

async def aset(self, key, value):
(await self._aget_session())[key] = value
self.modified = True

def __delitem__(self, key):
del self._session[key]
self.modified = True
Expand All @@ -67,27 +73,52 @@ def key_salt(self):
def get(self, key, default=None):
return self._session.get(key, default)

async def aget(self, key, default=None):
return (await self._aget_session()).get(key, default)

def pop(self, key, default=__not_given):
self.modified = self.modified or key in self._session
args = () if default is self.__not_given else (default,)
return self._session.pop(key, *args)

async def apop(self, key, default=__not_given):
self.modified = self.modified or key in (await self._aget_session())
args = () if default is self.__not_given else (default,)
return (await self._aget_session()).pop(key, *args)

def setdefault(self, key, value):
if key in self._session:
return self._session[key]
else:
self[key] = value
return value

async def asetdefault(self, key, value):
session = await self._aget_session()
if key in session:
return session[key]
else:
await self.aset(key, value)
return value

def set_test_cookie(self):
self[self.TEST_COOKIE_NAME] = self.TEST_COOKIE_VALUE

async def aset_test_cookie(self):
await self.aset(self.TEST_COOKIE_NAME, self.TEST_COOKIE_VALUE)

def test_cookie_worked(self):
return self.get(self.TEST_COOKIE_NAME) == self.TEST_COOKIE_VALUE

async def atest_cookie_worked(self):
return (await self.aget(self.TEST_COOKIE_NAME)) == self.TEST_COOKIE_VALUE

def delete_test_cookie(self):
del self[self.TEST_COOKIE_NAME]

async def adelete_test_cookie(self):
del (await self._aget_session())[self.TEST_COOKIE_NAME]

def encode(self, session_dict):
"Return the given session dictionary serialized and encoded as a string."
return signing.dumps(
Expand Down Expand Up @@ -115,18 +146,34 @@ def update(self, dict_):
self._session.update(dict_)
self.modified = True

async def aupdate(self, dict_):
(await self._aget_session()).update(dict_)
self.modified = True

def has_key(self, key):
return key in self._session

async def ahas_key(self, key):
return key in (await self._aget_session())

def keys(self):
return self._session.keys()

async def akeys(self):
return (await self._aget_session()).keys()

def values(self):
return self._session.values()

async def avalues(self):
return (await self._aget_session()).values()

def items(self):
return self._session.items()

async def aitems(self):
return (await self._aget_session()).items()

def clear(self):
# To avoid unnecessary persistent storage accesses, we set up the
# internals directly (loading data wastes time, since we are going to
Expand All @@ -149,11 +196,22 @@ def _get_new_session_key(self):
if not self.exists(session_key):
return session_key

async def _aget_new_session_key(self):
while True:
session_key = get_random_string(32, VALID_KEY_CHARS)
if not await self.aexists(session_key):
return session_key

def _get_or_create_session_key(self):
if self._session_key is None:
self._session_key = self._get_new_session_key()
return self._session_key

async def _aget_or_create_session_key(self):
if self._session_key is None:
self._session_key = await self._aget_new_session_key()
return self._session_key

def _validate_session_key(self, key):
"""
Key must be truthy and at least 8 characters long. 8 characters is an
Expand Down Expand Up @@ -191,6 +249,17 @@ def _get_session(self, no_load=False):
self._session_cache = self.load()
return self._session_cache

async def _aget_session(self, no_load=False):
self.accessed = True
try:
return self._session_cache
except AttributeError:
if self.session_key is None or no_load:
self._session_cache = {}
else:
self._session_cache = await self.aload()
return self._session_cache

_session = property(_get_session)

def get_session_cookie_age(self):
Expand Down Expand Up @@ -223,6 +292,25 @@ def get_expiry_age(self, **kwargs):
delta = expiry - modification
return delta.days * 86400 + delta.seconds

async def aget_expiry_age(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")

if not expiry: # Checks both None and 0 cases
return self.get_session_cookie_age()
if not isinstance(expiry, (datetime, str)):
return expiry
if isinstance(expiry, str):
expiry = datetime.fromisoformat(expiry)
delta = expiry - modification
return delta.days * 86400 + delta.seconds

def get_expiry_date(self, **kwargs):
"""Get session the expiry date (as a datetime object).
Expand All @@ -246,6 +334,23 @@ def get_expiry_date(self, **kwargs):
expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry)

async def aget_expiry_date(self, **kwargs):
try:
modification = kwargs["modification"]
except KeyError:
modification = timezone.now()
try:
expiry = kwargs["expiry"]
except KeyError:
expiry = await self.aget("_session_expiry")

if isinstance(expiry, datetime):
return expiry
elif isinstance(expiry, str):
return datetime.fromisoformat(expiry)
expiry = expiry or self.get_session_cookie_age()
return modification + timedelta(seconds=expiry)

def set_expiry(self, value):
"""
Set a custom expiration for the session. ``value`` can be an integer,
Expand Down Expand Up @@ -274,6 +379,20 @@ def set_expiry(self, value):
value = value.isoformat()
self["_session_expiry"] = value

async def aset_expiry(self, value):
if value is None:
# Remove any custom expiration for this session.
try:
await self.apop("_session_expiry")
except KeyError:
pass
return
if isinstance(value, timedelta):
value = timezone.now() + value
if isinstance(value, datetime):
value = value.isoformat()
await self.aset("_session_expiry", value)

def get_expire_at_browser_close(self):
"""
Return ``True`` if the session is set to expire when the browser
Expand All @@ -285,6 +404,11 @@ def get_expire_at_browser_close(self):
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0

async def aget_expire_at_browser_close(self):
if (expiry := await self.aget("_session_expiry")) is None:
return settings.SESSION_EXPIRE_AT_BROWSER_CLOSE
return expiry == 0

def flush(self):
"""
Remove the current session data from the database and regenerate the
Expand All @@ -294,6 +418,11 @@ def flush(self):
self.delete()
self._session_key = None

async def aflush(self):
self.clear()
await self.adelete()
self._session_key = None

def cycle_key(self):
"""
Create a new session key, while retaining the current session data.
Expand All @@ -305,6 +434,17 @@ def cycle_key(self):
if key:
self.delete(key)

async def acycle_key(self):
"""
Create a new session key, while retaining the current session data.
"""
data = await self._aget_session()
key = self.session_key
await self.acreate()
self._session_cache = data
if key:
await self.adelete(key)

# Methods that child classes must implement.

def exists(self, session_key):
Expand All @@ -315,6 +455,9 @@ def exists(self, session_key):
"subclasses of SessionBase must provide an exists() method"
)

async def aexists(self, session_key):
return await sync_to_async(self.exists)(session_key)

def create(self):
"""
Create a new session instance. Guaranteed to create a new object with
Expand All @@ -325,6 +468,9 @@ def create(self):
"subclasses of SessionBase must provide a create() method"
)

async def acreate(self):
return await sync_to_async(self.create)()

def save(self, must_create=False):
"""
Save the session data. If 'must_create' is True, create a new session
Expand All @@ -335,6 +481,9 @@ def save(self, must_create=False):
"subclasses of SessionBase must provide a save() method"
)

async def asave(self, must_create=False):
return await sync_to_async(self.save)(must_create)

def delete(self, session_key=None):
"""
Delete the session data under this key. If the key is None, use the
Expand All @@ -344,6 +493,9 @@ def delete(self, session_key=None):
"subclasses of SessionBase must provide a delete() method"
)

async def adelete(self, session_key=None):
return await sync_to_async(self.delete)(session_key)

def load(self):
"""
Load the session data and return a dictionary.
Expand All @@ -352,6 +504,9 @@ def load(self):
"subclasses of SessionBase must provide a load() method"
)

async def aload(self):
return await sync_to_async(self.load)()

@classmethod
def clear_expired(cls):
"""
Expand All @@ -362,3 +517,7 @@ def clear_expired(cls):
a built-in expiration mechanism, it should be a no-op.
"""
raise NotImplementedError("This backend does not support clear_expired().")

@classmethod
async def aclear_expired(cls):
return await sync_to_async(cls.clear_expired)()
Loading

0 comments on commit f5c3406

Please sign in to comment.