diff --git a/awx/main/consumers.py b/awx/main/consumers.py index b6d8872ebdc1..85acb2cc5fa2 100644 --- a/awx/main/consumers.py +++ b/awx/main/consumers.py @@ -1,3 +1,4 @@ +import collections import json import logging import time @@ -12,12 +13,53 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer from channels.layers import get_channel_layer from channels.db import database_sync_to_async +from channels_redis.core import RedisChannelLayer logger = logging.getLogger('awx.main.consumers') XRF_KEY = '_auth_user_xrf' +class ExpiringCache(collections.defaultdict): + def __init__(self, default, ttl=60, *args, **kw): + collections.defaultdict.__init__(self, default) + self._expires = collections.OrderedDict() + self.ttl = ttl + + def __setitem__(self, k, v): + collections.defaultdict.__setitem__(self, k, v) + self._expires[k] = time.time() + self.ttl + + def __delitem__(self, k): + try: + collections.defaultdict.__delitem__(self, k) + except KeyError: + # RedisChannelLayer itself _does_ periodically clean up this + # dictionary (e.g., when exceptions like asyncio.CancelledError + # occur) + pass + + def expire(self): + expired = [] + for k in self._expires.keys(): + if self._expires[k] < time.time(): + expired.append(k) + else: + # as this is an OrderedDict, every key after this + # was inserted *later*, so if _this_ key is *not* expired, + # the ones after it aren't either (so we can stop iterating) + break + for k in expired: + del self._expires[k] + del self[k] + + +class ExpiringRedisChannelLayer(RedisChannelLayer): + def __init__(self, *args, **kw): + super(ExpiringRedisChannelLayer, self).__init__(*args, **kw) + self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry) + + class WebsocketSecretAuthHelper: """ Middlewareish for websockets to verify node websocket broadcast interconnect. @@ -106,6 +148,7 @@ async def connect(self): async def disconnect(self, code): logger.info(f"client '{self.channel_name}' disconnected from the broadcast group.") await self.channel_layer.group_discard(settings.BROADCAST_WEBSOCKET_GROUP_NAME, self.channel_name) + self.channel_layer.receive_buffer.expire() async def internal_message(self, event): await self.send(event['text']) @@ -137,6 +180,7 @@ async def disconnect(self, code): group_name, self.channel_name, ) + self.channel_layer.receive_buffer.expire() @database_sync_to_async def user_can_see_object_id(self, user_access, oid): diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index fe7c8c0ba345..76cf46060a44 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -916,7 +916,7 @@ def IS_TESTING(argv=None): CHANNEL_LAYERS = { "default": { - "BACKEND": "channels_redis.core.RedisChannelLayer", + "BACKEND": "awx.main.consumers.ExpiringRedisChannelLayer", "CONFIG": { "hosts": [BROKER_URL], "capacity": 10000,