Skip to content

Commit

Permalink
Add ConnectionRefusedError and handling for it
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyrusanov committed Jan 28, 2019
1 parent a8b13fc commit 38edd2c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
16 changes: 12 additions & 4 deletions socketio/asyncio_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import engineio

from . import asyncio_manager
from . import exceptions
from . import packet
from . import server

Expand Down Expand Up @@ -320,11 +321,18 @@ async def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if await self._trigger_event('connect', namespace, sid,
self.environ[sid]) is False:

try:
success = await self._trigger_event('connect', namespace, sid, self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.get_info()
success = False
else:
fail_reason = None

if success is False:
self.manager.disconnect(sid, namespace)
await self._send_packet(sid, packet.Packet(packet.ERROR,
namespace=namespace))
await self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False
Expand Down
14 changes: 14 additions & 0 deletions socketio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,17 @@ class SocketIOError(Exception):

class ConnectionError(SocketIOError):
pass


class ConnectionRefusedError(ConnectionError):
"""
Raised when connection is refused on the application level
"""
def __init__(self, info):
self._info = info

def get_info(self):
"""
This method could be overridden in subclass to add extra logic for data output
"""
return self._info
15 changes: 11 additions & 4 deletions socketio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import six

from . import base_manager
from . import exceptions
from . import packet
from . import namespace

Expand Down Expand Up @@ -485,11 +486,17 @@ def _handle_connect(self, sid, namespace):
"""Handle a client connection request."""
namespace = namespace or '/'
self.manager.connect(sid, namespace)
if self._trigger_event('connect', namespace, sid,
self.environ[sid]) is False:
try:
success = self._trigger_event('connect', namespace, sid, self.environ[sid])
except exceptions.ConnectionRefusedError as exc:
fail_reason = exc.get_info()
success = False
else:
fail_reason = None

if success is False:
self.manager.disconnect(sid, namespace)
self._send_packet(sid, packet.Packet(packet.ERROR,
namespace=namespace))
self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace))
if sid in self.environ: # pragma: no cover
del self.environ[sid]
return False
Expand Down

0 comments on commit 38edd2c

Please sign in to comment.