From bfc15e888c3992d7e1bb305fb991e8eed385a95b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 16 Feb 2022 14:44:42 +0000 Subject: [PATCH 1/7] Allow modules to set a display name on registration --- .../password_auth_provider_callbacks.md | 29 +++++++++- synapse/handlers/auth.py | 58 +++++++++++++++++++ synapse/module_api/__init__.py | 5 ++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index 88b59bb09e61..b1964734c168 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -166,6 +166,34 @@ any of the subsequent implementations of this callback. If every callback return the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_display_name_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_display_name_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback return `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + ## `is_3pid_allowed` _First introduced in Synapse v1.53.0_ @@ -196,7 +224,6 @@ The example module below implements authentication checkers for two different lo - Expects a `password` field to be sent to `/login` - Is checked by the method: `self.check_pass` - ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6959d1aa7e47..b5865290b34f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2064,6 +2064,10 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] @@ -2080,6 +2084,9 @@ def __init__(self) -> None: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_display_name_for_registration_callbacks: List[ + GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters @@ -2099,6 +2106,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_display_name_for_registration: Optional[ + GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2148,6 +2158,11 @@ def register_password_auth_provider_callbacks( get_username_for_registration, ) + if get_display_name_for_registration is not None: + self.get_display_name_for_registration_callbacks.append( + get_display_name_for_registration, + ) + if is_3pid_allowed is not None: self.is_3pid_allowed_callbacks.append(is_3pid_allowed) @@ -2350,6 +2365,49 @@ async def get_username_for_registration( return None + async def get_display_name_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the profile information to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_display_name_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_display_name_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_profile_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + async def is_3pid_allowed( self, medium: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d4fca369231a..99d274844dcf 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -70,6 +70,7 @@ from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -317,6 +318,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_display_name_for_registration: Optional[ + GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -328,6 +332,7 @@ def register_password_auth_provider_callbacks( is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_display_name_for_registration=get_display_name_for_registration, ) def register_background_update_controller_callbacks( From e050eb8e35f27163e2387d50ac02942925fe375b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 16 Feb 2022 14:45:01 +0000 Subject: [PATCH 2/7] Use the callback --- synapse/rest/client/register.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index c965e2bda2f9..f5819f181f6e 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -694,11 +694,19 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: session_id ) + desired_display_name = await ( + self.password_auth_provider.get_display_name_for_registration( + auth_result, + params, + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=desired_display_name, address=client_addr, user_agent_ips=entries, ) From fd36fe606285feb921ec2b171aea03216e668772 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 16 Feb 2022 14:45:11 +0000 Subject: [PATCH 3/7] Test --- tests/handlers/test_password_providers.py | 48 ++++++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 4740dd0a65e3..e5356ffb3542 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_display_name_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -754,7 +757,9 @@ def test_username(self): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -777,7 +782,9 @@ def test_username_uia(self): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) # Initiate the UIA flow. username = "rin" @@ -817,6 +824,35 @@ def test_3pid_allowed(self): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) + def test_display_name(self): + """Tests that the get_display_name_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + channel = self.make_request( + "POST", + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(channel.code, 200) + + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -877,20 +913,20 @@ def _test_3pid_allowed(self, username: str, registration: bool): m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_username_for_registration(self) -> Mock: + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: """Registers a get_username_for_registration callback that appends "-foo" to the username the client is trying to register. """ - async def get_username_for_registration(uia_results, params): + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m From bfe66b32a0bd7032be69f52d55fe6ba79568f22e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 16 Feb 2022 15:07:54 +0000 Subject: [PATCH 4/7] Changelog --- changelog.d/12009.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12009.feature diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature new file mode 100644 index 000000000000..c8a531481ec3 --- /dev/null +++ b/changelog.d/12009.feature @@ -0,0 +1 @@ +Enable modules to set a custom display name when registering a user. From 9b6a11c52c691e262c544c22bd18512e0b2598f3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Feb 2022 10:51:57 +0000 Subject: [PATCH 5/7] Incorporate review and rename callback for consistency --- .../password_auth_provider_callbacks.md | 8 +- synapse/handlers/auth.py | 34 ++++--- synapse/module_api/__init__.py | 8 +- synapse/rest/client/register.py | 9 +- tests/handlers/test_password_providers.py | 88 ++++++++++++------- 5 files changed, 84 insertions(+), 63 deletions(-) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index b1964734c168..d99a1599a2ad 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -162,7 +162,7 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). @@ -191,7 +191,7 @@ documentation of this callback for more information about them. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). ## `is_3pid_allowed` @@ -222,7 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b5865290b34f..444ab9f71d38 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2040,11 +2040,9 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: # need to use a tuple here for ("password",) not a list since lists aren't hashable auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password - api.register_password_auth_provider_callbacks( - check_3pid_auth=check_3pid_auth_hook, - on_logged_out=on_logged_out_hook, - auth_checkers=auth_checkers, - ) + api.register_password_auth_provider_callbacks(check_3pid_auth=check_3pid_auth_hook, + on_logged_out=on_logged_out_hook, + auth_checkers=auth_checkers) CHECK_3PID_AUTH_CALLBACK = Callable[ @@ -2064,7 +2062,7 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: [JsonDict, JsonDict], Awaitable[Optional[str]], ] -GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK = Callable[ +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], Awaitable[Optional[str]], ] @@ -2084,8 +2082,8 @@ def __init__(self) -> None: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] - self.get_display_name_for_registration_callbacks: List[ - GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK ] = [] self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] @@ -2106,8 +2104,8 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, - get_display_name_for_registration: Optional[ - GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK ] = None, ) -> None: # Register check_3pid_auth callback @@ -2158,9 +2156,9 @@ def register_password_auth_provider_callbacks( get_username_for_registration, ) - if get_display_name_for_registration is not None: - self.get_display_name_for_registration_callbacks.append( - get_display_name_for_registration, + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, ) if is_3pid_allowed is not None: @@ -2365,12 +2363,12 @@ async def get_username_for_registration( return None - async def get_display_name_for_registration( + async def get_displayname_for_registration( self, uia_results: JsonDict, params: JsonDict, ) -> Optional[str]: - """Defines the profile information to use when registering the user, using the + """Defines the display name to use when registering the user, using the credentials and parameters provided during the UIA flow. Stops at the first callback that returns a tuple containing at least one string. @@ -2383,7 +2381,7 @@ async def get_display_name_for_registration( A tuple which first element is the display name, and the second is an MXC URL to the user's avatar. """ - for callback in self.get_display_name_for_registration_callbacks: + for callback in self.get_displayname_for_registration_callbacks: try: res = await callback(uia_results, params) @@ -2395,13 +2393,13 @@ async def get_display_name_for_registration( # to make sure this is the case. logger.warning( # type: ignore[unreachable] "Ignoring non-string value returned by" - " get_display_name_for_registration callback %s: %s", + " get_displayname_for_registration callback %s: %s", callback, res, ) except Exception as e: logger.error( - "Module raised an exception in get_profile_for_registration: %s", + "Module raised an exception in get_displayname_for_registration: %s", e, ) raise SynapseError(code=500, msg="Internal Server Error") diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 99d274844dcf..8a17b912d36d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -70,7 +70,7 @@ from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, - GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, @@ -318,8 +318,8 @@ def register_password_auth_provider_callbacks( get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, - get_display_name_for_registration: Optional[ - GET_DISPLAY_NAME_FOR_REGISTRATION_CALLBACK + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -332,7 +332,7 @@ def register_password_auth_provider_callbacks( is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, - get_display_name_for_registration=get_display_name_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index f5819f181f6e..b8a5135e02e9 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -694,10 +694,9 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: session_id ) - desired_display_name = await ( - self.password_auth_provider.get_display_name_for_registration( - auth_result, - params, + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params ) ) @@ -706,7 +705,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, - default_display_name=desired_display_name, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index e5356ffb3542..898f983b73c7 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,8 +84,7 @@ def parse_config(self): def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, - ) + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}) def check_auth(self, *args): return mock_password_provider.check_auth(*args) @@ -118,12 +117,10 @@ def parse_config(self): pass def __init__(self, config, api: ModuleApi): - api.register_password_auth_provider_callbacks( - auth_checkers={ - ("test.login_type", ("test_field",)): self.check_auth, - ("m.login.password", ("password",)): self.check_auth, - }, - ) + api.register_password_auth_provider_callbacks(auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }) pass def check_auth(self, *args): @@ -164,7 +161,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ] CALLBACK_USERNAME = "get_username_for_registration" - CALLBACK_DISPLAYNAME = "get_display_name_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate @@ -786,28 +783,10 @@ def test_username_uia(self): callback_name=self.CALLBACK_USERNAME, ) - # Initiate the UIA flow. username = "rin" - channel = self.make_request( - "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, - ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) + res = self._do_uia_assert_mock_not_called(username, m) - # Check that the callback hasn't been called yet. - m.assert_not_called() - - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, - ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] + mxid = res["user_id"] self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") # Check that the callback has been called. @@ -824,7 +803,7 @@ def test_3pid_allowed(self): self._test_3pid_allowed("rin", False) self._test_3pid_allowed("kitay", True) - def test_display_name(self): + def test_displayname(self): """Tests that the get_display_name_for_registration callback can define the display name of a user when registering. """ @@ -853,6 +832,27 @@ def test_display_name(self): self.assertEqual(display_name, username + "-foo") + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + def _test_3pid_allowed(self, username: str, registration: bool): """Tests that the "is_3pid_allowed" module callback is called correctly, using either /register or /account URLs depending on the arguments. @@ -914,8 +914,9 @@ def _test_3pid_allowed(self, username: str, registration: bool): m.assert_called_once_with("email", "bar@test.com", registration) def _setup_get_name_for_registration(self, callback_name: str) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. """ async def callback(uia_results, params): @@ -930,6 +931,29 @@ async def callback(uia_results, params): return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) From bd26ed34bc0e889317fa683f16ebbe6f5662fae0 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Feb 2022 10:55:45 +0000 Subject: [PATCH 6/7] Lint --- synapse/handlers/auth.py | 8 +++++--- tests/handlers/test_password_providers.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 444ab9f71d38..572f54b1e353 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2040,9 +2040,11 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable: # need to use a tuple here for ("password",) not a list since lists aren't hashable auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password - api.register_password_auth_provider_callbacks(check_3pid_auth=check_3pid_auth_hook, - on_logged_out=on_logged_out_hook, - auth_checkers=auth_checkers) + api.register_password_auth_provider_callbacks( + check_3pid_auth=check_3pid_auth_hook, + on_logged_out=on_logged_out_hook, + auth_checkers=auth_checkers, + ) CHECK_3PID_AUTH_CALLBACK = Callable[ diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 898f983b73c7..488636ebcd7c 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -84,7 +84,8 @@ def parse_config(self): def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}) + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} + ) def check_auth(self, *args): return mock_password_provider.check_auth(*args) @@ -117,10 +118,12 @@ def parse_config(self): pass def __init__(self, config, api: ModuleApi): - api.register_password_auth_provider_callbacks(auth_checkers={ - ("test.login_type", ("test_field",)): self.check_auth, - ("m.login.password", ("password",)): self.check_auth, - }) + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + } + ) pass def check_auth(self, *args): From 899295f942d6221863c3222bebb714fcef48c40f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 17 Feb 2022 17:24:21 +0100 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Patrick Cloke --- docs/modules/password_auth_provider_callbacks.md | 4 ++-- tests/handlers/test_password_providers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index d99a1599a2ad..ec810fd292e5 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -166,12 +166,12 @@ any of the subsequent implementations of this callback. If every callback return the username provided by the user is used, if any (otherwise one is automatically generated). -### `get_display_name_for_registration` +### `get_displayname_for_registration` _First introduced in Synapse v1.54.0_ ```python -async def get_display_name_for_registration( +async def get_displayname_for_registration( uia_results: Dict[str, Any], params: Dict[str, Any], ) -> Optional[str] diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 488636ebcd7c..49d832de814d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -807,7 +807,7 @@ def test_3pid_allowed(self): self._test_3pid_allowed("kitay", True) def test_displayname(self): - """Tests that the get_display_name_for_registration callback can define the + """Tests that the get_displayname_for_registration callback can define the display name of a user when registering. """ self._setup_get_name_for_registration(