diff --git a/socketio/asyncio_server.py b/socketio/asyncio_server.py index cbd812bf..f0e4679a 100644 --- a/socketio/asyncio_server.py +++ b/socketio/asyncio_server.py @@ -3,6 +3,7 @@ import engineio from . import asyncio_manager +from . import exceptions from . import packet from . import server @@ -320,11 +321,18 @@ async def _handle_connect(self, sid, namespace): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - if await self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + + try: + success = await self._trigger_event('connect', namespace, sid, self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.get_info() + success = False + else: + fail_reason = None + + if success is False: self.manager.disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet(packet.ERROR, - namespace=namespace)) + await self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False diff --git a/socketio/exceptions.py b/socketio/exceptions.py index 5bd86979..2c325f2e 100644 --- a/socketio/exceptions.py +++ b/socketio/exceptions.py @@ -4,3 +4,17 @@ class SocketIOError(Exception): class ConnectionError(SocketIOError): pass + + +class ConnectionRefusedError(ConnectionError): + """ + Raised when connection is refused on the application level + """ + def __init__(self, info): + self._info = info + + def get_info(self): + """ + This method could be overridden in subclass to add extra logic for data output + """ + return self._info diff --git a/socketio/server.py b/socketio/server.py index 449c94a1..8151e535 100644 --- a/socketio/server.py +++ b/socketio/server.py @@ -4,6 +4,7 @@ import six from . import base_manager +from . import exceptions from . import packet from . import namespace @@ -485,11 +486,17 @@ def _handle_connect(self, sid, namespace): """Handle a client connection request.""" namespace = namespace or '/' self.manager.connect(sid, namespace) - if self._trigger_event('connect', namespace, sid, - self.environ[sid]) is False: + try: + success = self._trigger_event('connect', namespace, sid, self.environ[sid]) + except exceptions.ConnectionRefusedError as exc: + fail_reason = exc.get_info() + success = False + else: + fail_reason = None + + if success is False: self.manager.disconnect(sid, namespace) - self._send_packet(sid, packet.Packet(packet.ERROR, - namespace=namespace)) + self._send_packet(sid, packet.Packet(packet.ERROR, data=fail_reason, namespace=namespace)) if sid in self.environ: # pragma: no cover del self.environ[sid] return False