From 9e573348d81df8191bbe8c266c01999c9d57cd5f Mon Sep 17 00:00:00 2001 From: Denis Kasak Date: Wed, 24 Mar 2021 13:50:40 +0000 Subject: [PATCH] Rework hostname validation to make port checking stricter. Instead of using a regex to validate the entire hostname + port combination, we now split the hostname into components and check each component separately. This makes the regex a bit simpler and allows us to validate the port number better, including that it belongs to the valid range. --- sydent/http/servlets/registerservlet.py | 17 ++++++++-- sydent/util/stringutils.py | 42 ++++++++++++++++++++++- tests/test_auth.py | 4 +-- tests/test_register.py | 45 +++++++++++++++++++++++++ tests/test_util.py | 26 ++++++++++++++ 5 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 tests/test_register.py create mode 100644 tests/test_util.py diff --git a/sydent/http/servlets/registerservlet.py b/sydent/http/servlets/registerservlet.py index ebd07cce..21d79af4 100644 --- a/sydent/http/servlets/registerservlet.py +++ b/sydent/http/servlets/registerservlet.py @@ -25,7 +25,7 @@ from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors from sydent.http.httpclient import FederationHttpClient from sydent.users.tokens import issueToken - +from sydent.util.stringutils import is_valid_hostname logger = logging.getLogger(__name__) @@ -47,9 +47,20 @@ def render_POST(self, request): args = get_args(request, ('matrix_server_name', 'access_token')) + hostname = args['matrix_server_name'].lower() + + if not is_valid_hostname(hostname): + request.setResponseCode(400) + return { + 'errcode': 'M_INVALID_PARAM', + 'error': 'matrix_server_name must be a valid hostname' + } + result = yield self.client.get_json( - "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % ( - args['matrix_server_name'], urllib.parse.quote(args['access_token']), + "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" + % ( + hostname, + urllib.parse.quote(args['access_token']), ), 1024 * 5, ) diff --git a/sydent/util/stringutils.py b/sydent/util/stringutils.py index e41ff662..d2f37321 100644 --- a/sydent/util/stringutils.py +++ b/sydent/util/stringutils.py @@ -17,14 +17,54 @@ # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# hostname/domain name + optional port +# https://regex101.com/r/OyN1lg/2 +hostname_regex = re.compile( + r"^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)*$", + flags=re.IGNORECASE) + def is_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec :param client_secret: The client_secret to validate - :type client_secret: unicode + :type client_secret: str :return: Whether the client_secret is valid :rtype: bool """ return client_secret_regex.match(client_secret) is not None + + +def is_valid_hostname(string: str) -> bool: + """Validate that a given string is a valid hostname or domain name, with an + optional port number. + + For domain names, this only validates that the form is right (for + instance, it doesn't check that the TLD is valid). If a port is + specified, it has to be a valid port number. + + :param string: The string to validate + :type string: str + + :return: Whether the input is a valid hostname + :rtype: bool + """ + + host_parts = string.split(":", 1) + + if len(host_parts) == 1: + return hostname_regex.match(string) is not None + else: + host, port = host_parts + valid_hostname = hostname_regex.match(host) is not None + + try: + port_num = int(port) + valid_port = ( + port == str(port_num) # exclude things like '08090' or ' 8090' + and 1 <= port_num < 65536 + except ValueError: + valid_port = False + + return valid_hostname and valid_port diff --git a/tests/test_auth.py b/tests/test_auth.py index 7d259453..71f5a513 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -44,7 +44,7 @@ def setUp(self): self.sydent.db.commit() def test_can_read_token_from_headers(self): - """Tests that Sydent correct extracts an auth token from request headers""" + """Tests that Sydent correctly extracts an auth token from request headers""" self.sydent.run() request, _ = make_request( @@ -59,7 +59,7 @@ def test_can_read_token_from_headers(self): self.assertEqual(token, self.test_token) def test_can_read_token_from_query_parameters(self): - """Tests that Sydent correct extracts an auth token from query parameters""" + """Tests that Sydent correctly extracts an auth token from query parameters""" self.sydent.run() request, _ = make_request( diff --git a/tests/test_register.py b/tests/test_register.py new file mode 100644 index 00000000..abc8c16c --- /dev/null +++ b/tests/test_register.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.trial import unittest + +from tests.utils import make_request, make_sydent + + +class RegisterTestCase(unittest.TestCase): + """Tests Sydent's register servlet""" + def setUp(self): + # Create a new sydent + self.sydent = make_sydent() + + def test_sydent_rejects_invalid_hostname(self): + """Tests that the /register endpoint rejects an invalid hostname passed as matrix_server_name""" + self.sydent.run() + + bad_hostname = "example.com#" + + request, channel = make_request( + self.sydent.reactor, + "POST", + "/_matrix/identity/v2/account/register", + content={ + "matrix_server_name": bad_hostname, + "access_token": "foo" + }) + + request.render(self.sydent.servlets.registerServlet) + + self.assertEqual(channel.code, 400) diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..7c9a011e --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,26 @@ +from twisted.trial import unittest +from sydent.util.stringutils import is_valid_hostname + + +class UtilTests(unittest.TestCase): + """Tests Sydent utility functions.""" + def test_is_valid_hostname(self): + """Tests that the is_valid_hostname function accepts only valid + hostnames (or domain names), with optional port number. + """ + + self.assertTrue(is_valid_hostname("example.com")) + self.assertTrue(is_valid_hostname("EXAMPLE.COM")) + self.assertTrue(is_valid_hostname("ExAmPlE.CoM")) + self.assertTrue(is_valid_hostname("example.com:4242")) + self.assertTrue(is_valid_hostname("localhost")) + self.assertTrue(is_valid_hostname("localhost:9000")) + self.assertTrue(is_valid_hostname("a.b:1234")) + + self.assertFalse(is_valid_hostname("example.com:65536")) + self.assertFalse(is_valid_hostname("example.com:0")) + self.assertFalse(is_valid_hostname("example.com:a")) + self.assertFalse(is_valid_hostname("example.com:04242")) + self.assertFalse(is_valid_hostname("example.com: 4242")) + self.assertFalse(is_valid_hostname("example.com/example.com")) + self.assertFalse(is_valid_hostname("example.com#example.com"))