diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index 21d79af4..7267efc2 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -47,9 +47,9 @@ def render_POST(self, request): args = get_args(request, ('matrix_server_name', 'access_token')) - hostname = args['matrix_server_name'].lower() + matrix_server = args['matrix_server_name'].lower() - if not is_valid_hostname(hostname): + if not is_valid_hostname(matrix_server): request.setResponseCode(400) return { 'errcode': 'M_INVALID_PARAM', @@ -59,15 +59,50 @@ def render_POST(self, request): result = yield self.client.get_json( "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % ( - hostname, + matrix_server, urllib.parse.quote(args['access_token']), ), 1024 * 5, ) + if 'sub' not in result: raise Exception("Invalid response from homeserver") user_id = result['sub'] + + if not isinstance(user_id, str): + request.setResponseCode(500) + return { + 'errcode': 'M_UNKNOWN', + 'error': 'The Matrix homeserver returned a malformed reply' + } + + user_id_components = user_id.split(':', 1) + + # Ensure there's a localpart and domain in the returned user ID. + if len(user_id_components) != 2: + request.setResponseCode(500) + return { + 'errcode': 'M_UNKNOWN', + 'error': 'The Matrix homeserver returned an invalid MXID' + } + + user_id_server = user_id_components[1] + + if not is_valid_hostname(user_id_server): + request.setResponseCode(500) + return { + 'errcode': 'M_UNKNOWN', + 'error': 'The Matrix homeserver returned an invalid MXID' + } + + if user_id_server != matrix_server: + request.setResponseCode(500) + return { + 'errcode': 'M_UNKNOWN', + 'error': 'The Matrix homeserver returned a MXID belonging to another homeserver' + } + tok = yield issueToken(self.sydent, user_id) # XXX: `token` is correct for the spec, but we released with `access_token` diff --git a/sydent/util/stringutils.py b/sydent/util/stringutils.py index d2f37321..3f1a91e8 100644 --- a/sydent/util/stringutils.py +++ b/sydent/util/stringutils.py @@ -63,7 +63,7 @@ def is_valid_hostname(string: str) -> bool: port_num = int(port) valid_port = ( port == str(port_num) # exclude things like '08090' or ' 8090' - and 1 <= port_num < 65536 + and 1 <= port_num < 65536) except ValueError: valid_port = False