Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update the auth providers to be async. #7935

Merged
merged 5 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/password_auth_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ methods.
> will be called for each login attempt where the login type matches one
> of the keys returned by `get_supported_login_types`.
>
> It is passed the (possibly UNqualified) `user` provided by the client,
> It is passed the (possibly unqualified) `user` provided by the client,
> the login type, and a dictionary of login secrets passed by the
> client.
>
> The method should return a Twisted `Deferred` object, which resolves
> The method should return an `Awaitable` object, which resolves
> to the canonical `@localpart:domain` user id if authentication is
> successful, and `None` if not.
>
> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
> Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
> which case the second field is a callback which will be called with
> the result from the `/login` call (including `access_token`,
> `device_id`, etc.)
Expand All @@ -88,11 +88,11 @@ methods.
> passed the medium (ex. "email"), an address (ex.
> "<[email protected]>") and the user's password.
>
> The method should return a Twisted `Deferred` object, which resolves
> The method should return an `Awaitable` object, which resolves
> to a `str` containing the user's (canonical) User ID if
> authentication was successful, and `None` if not.
>
> As with `check_auth`, the `Deferred` may alternatively resolve to a
> As with `check_auth`, the `Awaitable` may alternatively resolve to a
> `(user_id, callback)` tuple.

`someprovider.check_password`(*user_id*, *password*)
Expand All @@ -102,11 +102,11 @@ methods.
> providers that just want to provide a mechanism for validating
> `m.login.password` logins.
>
> Iif implemented, it will be called to check logins with an
> If implemented, it will be called to check logins with an
> `m.login.password` login type. It is passed a qualified
> `@localpart:domain` user id, and the password provided by the user.
>
> The method should return a Twisted `Deferred` object, which resolves
> The method should return an `Awaitable` object, which resolves
> to `True` if authentication is successful, and `False` if not.

`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*)
Expand All @@ -116,5 +116,5 @@ methods.
> any: access tokens are occasionally created without an associated
> device ID), and the (now deactivated) access token.
>
> It may return a Twisted `Deferred` object; the logout request will
> It may return an `Awaitable` object; the logout request will
> wait for the deferred to complete but the result is ignored.
35 changes: 17 additions & 18 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.

import logging
from typing import Any

from canonicaljson import json

from twisted.internet import defer
from twisted.web.client import PartialDownloadError

from synapse.api.constants import LoginType
Expand All @@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
def __init__(self, hs):
pass

def is_enabled(self):
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work

Returns:
bool: True if this login type is enabled.
True if this login type is enabled.
"""

def check_auth(self, authdict, clientip):
async def check_auth(self, authdict: dict, clientip: str) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you holding off specifying the possibilities here so that you wouldn't need to specify Deferred? If so, it's worth noting that isinstance(a_deferred, Awaitable) resolves to True, so I think something not too cumbersome like Optional[Union[str, Tuple]] should work here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't specify the possibilities because I did the typing before reading the separate documentation and the docstring didn't specify what was returned. 😄 I can double check the return types to the documentation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, I know why I did this. (I should have ☕ before replying to things...)

Although this method is also named check_auth, it is NOT the same as above. It gets called twice:

result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)

checker = self.checkers.get(login_type)
if checker is not None:
res = await checker.check_auth(authdict, clientip=clientip)
return res

While the one from a password provider gets called:

result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result

Anyway, the result of UserInteractiveAuthChecker.check_auth gets saved into the database and sometimes inspected in other places.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, fair enough then! Thanks for the clear explanation :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're welcome! I'm now realizing that putting these changes in the same PR was confusing...since they're not really related. 😢 Sorry about that!

"""Given the authentication dict from the client, attempt to check this step

Args:
authdict (dict): authentication dictionary from the client
clientip (str): The IP address of the client.
authdict: authentication dictionary from the client
clientip: The IP address of the client.

Raises:
SynapseError if authentication failed

Returns:
Deferred: the result of authentication (to pass back to the client?)
The result of authentication (to pass back to the client?)
"""
raise NotImplementedError()

Expand All @@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True

def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True


class TermsAuthChecker(UserInteractiveAuthChecker):
Expand All @@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True

def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True


class RecaptchaAuthChecker(UserInteractiveAuthChecker):
Expand All @@ -89,8 +89,7 @@ def __init__(self, hs):
def is_enabled(self):
return self._enabled

@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
async def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
Expand All @@ -107,7 +106,7 @@ def check_auth(self, authdict, clientip):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
resp_body = yield self._http_client.post_urlencoded_get_json(
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
Expand Down Expand Up @@ -219,8 +218,8 @@ def is_enabled(self):
ThreepidBehaviour.LOCAL,
)

def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("email", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("email", authdict)


class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
Expand All @@ -233,8 +232,8 @@ def __init__(self, hs):
def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn)

def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("msisdn", authdict)


INTERACTIVE_AUTH_CHECKERS = [
Expand Down