Skip to content

Commit

Permalink
Make it easier to customize authentication.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed May 23, 2021
1 parent c0750da commit dd6d6bc
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ They may change at any time.

* Optimized default compression settings to reduce memory usage.

* Made it easier to customize authentication with
:meth:`~auth.BasicAuthWebSocketServerProtocol.check_credentials`.

9.0.2
.....

Expand Down
9 changes: 7 additions & 2 deletions docs/reference/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,15 @@ Basic authentication

.. autoclass:: BasicAuthWebSocketServerProtocol

.. automethod:: process_request
.. attribute:: realm

Scope of protection.

If provided, it should contain only ASCII characters because the
encoding of non-ASCII characters is undefined.

.. attribute:: username

Username of the authenticated user.


.. automethod:: check_credentials
36 changes: 29 additions & 7 deletions src/websockets/legacy/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,44 @@ class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
"""

realm = ""

def __init__(
self,
*args: Any,
realm: str,
check_credentials: Callable[[str, str], Awaitable[bool]],
realm: Optional[str] = None,
check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
**kwargs: Any,
) -> None:
self.realm = realm
self.check_credentials = check_credentials
if realm is not None:
self.realm = realm # shadow class attribute
self._check_credentials = check_credentials
super().__init__(*args, **kwargs)

async def check_credentials(self, username: str, password: str) -> bool:
"""
Check whether credentials are authorized.
If ``check_credentials`` returns ``True``, the WebSocket handshake
continues. If it returns ``False``, the handshake fails with a HTTP
401 error.
This coroutine may be overridden in a subclass, for example to
authenticate against a database or an external service.
"""
if self._check_credentials is not None:
return await self._check_credentials(username, password)

return False

async def process_request(
self, path: str, request_headers: Headers
self,
path: str,
request_headers: Headers,
) -> Optional[HTTPResponse]:
"""
Check HTTP Basic Auth and return a HTTP 401 or 403 response if needed.
Check HTTP Basic Auth and return a HTTP 401 response if needed.
"""
try:
Expand Down Expand Up @@ -84,7 +106,7 @@ async def process_request(


def basic_auth_protocol_factory(
realm: str,
realm: Optional[str] = None,
credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None,
Expand Down
18 changes: 18 additions & 0 deletions tests/legacy/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ async def process_request(self, path, request_headers):
return await super().process_request(path, request_headers)


class CheckWebSocketServerProtocol(BasicAuthWebSocketServerProtocol):
async def check_credentials(self, username, password):
return password == "letmein"


class AuthClientServerTests(ClientServerTestsMixin, AsyncioTestCase):

create_protocol = basic_auth_protocol_factory(
Expand Down Expand Up @@ -103,6 +108,19 @@ def test_basic_auth_custom_protocol(self):
self.loop.run_until_complete(self.client.send("Hello!"))
self.loop.run_until_complete(self.client.recv())

@with_server(create_protocol=CheckWebSocketServerProtocol)
@with_client(user_info=("hello", "letmein"))
def test_basic_auth_custom_protocol_subclass(self):
self.loop.run_until_complete(self.client.send("Hello!"))
self.loop.run_until_complete(self.client.recv())

# CustomWebSocketServerProtocol doesn't override check_credentials
@with_server(create_protocol=CustomWebSocketServerProtocol)
def test_basic_auth_defaults_to_deny_all(self):
with self.assertRaises(InvalidStatusCode) as raised:
self.start_client(user_info=("hello", "iloveyou"))
self.assertEqual(raised.exception.status_code, 401)

@with_server(create_protocol=create_protocol)
def test_basic_auth_missing_credentials(self):
with self.assertRaises(InvalidStatusCode) as raised:
Expand Down

0 comments on commit dd6d6bc

Please sign in to comment.