Skip to content

Commit

Permalink
work around a memory leak in channels_redis
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanpetrello committed Aug 11, 2020
1 parent def79de commit aa3f553
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
50 changes: 50 additions & 0 deletions awx/main/consumers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import json
import logging
import time
Expand All @@ -12,12 +13,59 @@
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from channels.layers import get_channel_layer
from channels.db import database_sync_to_async
from channels_redis.core import RedisChannelLayer


logger = logging.getLogger('awx.main.consumers')
XRF_KEY = '_auth_user_xrf'


class ExpiringCache(collections.defaultdict):

next_clean = 0

def __init__(self, default, ttl=60, *args, **kw):
collections.defaultdict.__init__(self, default)
self._expires = collections.OrderedDict()
self.ttl = ttl

def __setitem__(self, k, v):
collections.defaultdict.__setitem__(self, k, v)
self._expires[k] = time.time() + self.ttl

def __delitem__(self, k):
try:
collections.defaultdict.__delitem__(self, k)
except KeyError:
# RedisChannelLayer itself _does_ periodically clean up this
# dictionary (e.g., when exceptions like asyncio.CancelledError
# occur)
pass

def expire(self):
if time.time() < self.next_clean:
return
self.next_clean = time.time() + self.ttl
expired = []
for k in self._expires.keys():
if self._expires[k] < time.time():
expired.append(k)
else:
# as this is an OrderedDict, every key after this
# was inserted *later*, so if _this_ key is *not* expired,
# the ones after it aren't either (so we can stop iterating)
break
for k in expired:
del self._expires[k]
del self[k]


class ExpiringRedisChannelLayer(RedisChannelLayer):
def __init__(self, *args, **kw):
super(ExpiringRedisChannelLayer, self).__init__(*args, **kw)
self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry)


class WebsocketSecretAuthHelper:
"""
Middlewareish for websockets to verify node websocket broadcast interconnect.
Expand Down Expand Up @@ -109,6 +157,7 @@ async def disconnect(self, code):

async def internal_message(self, event):
await self.send(event['text'])
self.channel_layer.receive_buffer.expire()


class EventConsumer(AsyncJsonWebsocketConsumer):
Expand Down Expand Up @@ -204,6 +253,7 @@ async def receive_json(self, data):

async def internal_message(self, event):
await self.send(event['text'])
self.channel_layer.receive_buffer.expire()


def run_sync(func):
Expand Down
2 changes: 1 addition & 1 deletion awx/settings/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def IS_TESTING(argv=None):

CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"BACKEND": "awx.main.consumers.ExpiringRedisChannelLayer",
"CONFIG": {
"hosts": [BROKER_URL],
"capacity": 10000,
Expand Down

0 comments on commit aa3f553

Please sign in to comment.