From ced809d1de396f3fc2c81bdc8463e527ad63c397 Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Wed, 17 Aug 2016 15:18:29 -0700 Subject: [PATCH] Correct query loss when using parse_qsl to dict --- oauth2client/_helpers.py | 58 ++++++++++++++++++++++++++++++++++++---- oauth2client/client.py | 28 +++++-------------- oauth2client/tools.py | 12 ++++----- tests/test__helpers.py | 40 +++++++++++++++++++++++++++ tests/test_client.py | 19 +++---------- 5 files changed, 108 insertions(+), 49 deletions(-) diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py index 79586a518..0e5a9de3c 100644 --- a/oauth2client/_helpers.py +++ b/oauth2client/_helpers.py @@ -179,6 +179,58 @@ def string_to_scopes(scopes): return scopes +def parse_unique_urlencoded(content, as_list=False): + """Parses unique key-value parameters from urlencoded content. + + Args: + content: string, URL-encoded key-value pairs. + as_list: bool, flag indicating if the values in the result should + be list or values. Defaults to False. + + Returns: + dict, The key-value pairs from ``content``. + + Raises: + ValueError: if one of the keys is repeated. + """ + urlencoded_params = urllib.parse.parse_qs(content) + params = {} + for key, value in six.iteritems(urlencoded_params): + if len(value) != 1: + msg = ('URL-encoded content contains a repeated value:' + '%s -> %s' % (key, ', '.join(value))) + raise ValueError(msg) + if as_list: + params[key] = value + else: + params[key] = value[0] + return params + + +def update_query_params(uri, params): + """Updates a URI with new query parameters. + + If a given key from ``params`` is repeated in the ``uri``, then + the new value from ``params`` will be added to that list. If the + key occurs zero or once in the ``uri``, the value from ``params`` + will be used as the only value. + + Args: + uri: string, A valid URI, with potential existing query parameters. + params: dict, A dictionary of query parameters. + + Returns: + The same URI but with the new query parameters added. + """ + parts = urllib.parse.urlparse(uri) + query_params = parse_unique_urlencoded(parts.query, as_list=True) + for key, value in six.iteritems(params): + query_params[key] = [value] + new_query = urllib.parse.urlencode(query_params, doseq=True) + new_parts = parts._replace(query=new_query) + return urllib.parse.urlunparse(new_parts) + + def _add_query_parameter(url, name, value): """Adds a query parameter to a url. @@ -195,11 +247,7 @@ def _add_query_parameter(url, name, value): if value is None: return url else: - parsed = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(parsed[4])) - query[name] = value - parsed[4] = urllib.parse.urlencode(query) - return urllib.parse.urlunparse(parsed) + return update_query_params(url, {name: value}) def validate_file(filename): diff --git a/oauth2client/client.py b/oauth2client/client.py index 0497d074f..704c61034 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -438,23 +438,6 @@ def delete(self): self.release_lock() -def _update_query_params(uri, params): - """Updates a URI with new query parameters. - - Args: - uri: string, A valid URI, with potential existing query parameters. - params: dict, A dictionary of query parameters. - - Returns: - The same URI but with the new query parameters added. - """ - parts = urllib.parse.urlparse(uri) - query_params = dict(urllib.parse.parse_qsl(parts.query)) - query_params.update(params) - new_parts = parts._replace(query=urllib.parse.urlencode(query_params)) - return urllib.parse.urlunparse(new_parts) - - class OAuth2Credentials(Credentials): """Credentials object for OAuth 2.0. @@ -850,7 +833,8 @@ def _do_revoke(self, http, token): """ logger.info('Revoking token') query_params = {'token': token} - token_revoke_uri = _update_query_params(self.revoke_uri, query_params) + token_revoke_uri = _helpers.update_query_params( + self.revoke_uri, query_params) resp, content = transport.request(http, token_revoke_uri) if resp.status == http_client.OK: self.invalid = True @@ -889,8 +873,8 @@ def _do_retrieve_scopes(self, http, token): """ logger.info('Refreshing scopes') query_params = {'access_token': token, 'fields': 'scope'} - token_info_uri = _update_query_params(self.token_info_uri, - query_params) + token_info_uri = _helpers.update_query_params( + self.token_info_uri, query_params) resp, content = transport.request(http, token_info_uri) content = _helpers._from_bytes(content) if resp.status == http_client.OK: @@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content): except Exception: # different JSON libs raise different exceptions, # so we just do a catch-all here - resp = dict(urllib.parse.parse_qsl(content)) + resp = _helpers.parse_unique_urlencoded(content) # some providers respond with 'expires', others with 'expires_in' if resp and 'expires' in resp: @@ -1943,7 +1927,7 @@ def step1_get_authorize_url(self, redirect_uri=None, state=None): query_params['code_challenge_method'] = 'S256' query_params.update(self.params) - return _update_query_params(self.auth_uri, query_params) + return _helpers.update_query_params(self.auth_uri, query_params) @_helpers.positional(1) def step1_get_device_and_user_codes(self, http=None): diff --git a/oauth2client/tools.py b/oauth2client/tools.py index 0aa671ba6..b882429f8 100644 --- a/oauth2client/tools.py +++ b/oauth2client/tools.py @@ -122,16 +122,16 @@ def do_GET(self): if an error occurred. """ self.send_response(http_client.OK) - self.send_header("Content-type", "text/html") + self.send_header('Content-type', 'text/html') self.end_headers() - query = self.path.split('?', 1)[-1] - query = dict(urllib.parse.parse_qsl(query)) + parts = urllib.parse.urlparse(self.path) + query = _helpers.parse_unique_urlencoded(parts.query) self.server.query_params = query self.wfile.write( - b"Authentication Status") + b'Authentication Status') self.wfile.write( - b"

The authentication flow has completed.

") - self.wfile.write(b"") + b'

The authentication flow has completed.

') + self.wfile.write(b'') def log_message(self, format, *args): """Do not log messages to stdout while running as cmd. line program.""" diff --git a/tests/test__helpers.py b/tests/test__helpers.py index aac5f8d59..00cd38a23 100644 --- a/tests/test__helpers.py +++ b/tests/test__helpers.py @@ -19,6 +19,7 @@ import mock from oauth2client import _helpers +from tests import test_client class PositionalTests(unittest.TestCase): @@ -242,3 +243,42 @@ def test_bad_input(self): bad_string = b'+' with self.assertRaises((TypeError, binascii.Error)): _helpers._urlsafe_b64decode(bad_string) + + +class Test_update_query_params(unittest.TestCase): + + def test_update_query_params_no_params(self): + uri = 'http://www.google.com' + updated = _helpers.update_query_params(uri, {'a': 'b'}) + self.assertEqual(updated, uri + '?a=b') + + def test_update_query_params_existing_params(self): + uri = 'http://www.google.com?x=y' + updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'}) + hardcoded_update = uri + '&a=b&c=d%26' + test_client.assertUrisEqual(self, updated, hardcoded_update) + + def test_update_query_params_replace_param(self): + base_uri = 'http://www.google.com' + uri = base_uri + '?x=a' + updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'}) + hardcoded_update = base_uri + '?x=b&y=c' + test_client.assertUrisEqual(self, updated, hardcoded_update) + + def test_update_query_params_repeated_params(self): + uri = 'http://www.google.com?x=a&x=b' + with self.assertRaises(ValueError): + _helpers.update_query_params(uri, {'a': 'c'}) + + +class Test_parse_unique_urlencoded(unittest.TestCase): + + def test_without_repeats(self): + content = 'a=b&c=d' + result = _helpers.parse_unique_urlencoded(content) + self.assertEqual(result, {'a': 'b', 'c': 'd'}) + + def test_with_repeats(self): + content = 'a=b&a=d' + with self.assertRaises(ValueError): + _helpers.parse_unique_urlencoded(content) diff --git a/tests/test_client.py b/tests/test_client.py index dbe11ebaf..a3268ba84 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1364,7 +1364,7 @@ def _do_retrieve_scopes_test_helper(self, response, content, self.assertEqual(credentials.scopes, set()) self.assertEqual(exc_manager.exception.args, (error_msg,)) - token_uri = client._update_query_params( + token_uri = _helpers.update_query_params( oauth2client.GOOGLE_TOKEN_INFO_URI, {'fields': 'scope', 'access_token': token}) @@ -1558,19 +1558,6 @@ def test_sign_blob_abstract(self): credentials.sign_blob(b'blob') -class UpdateQueryParamsTest(unittest.TestCase): - def test_update_query_params_no_params(self): - uri = 'http://www.google.com' - updated = client._update_query_params(uri, {'a': 'b'}) - self.assertEqual(updated, uri + '?a=b') - - def test_update_query_params_existing_params(self): - uri = 'http://www.google.com?x=y' - updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'}) - hardcoded_update = uri + '&a=b&c=d%26' - assertUrisEqual(self, updated, hardcoded_update) - - class ExtractIdTokenTest(unittest.TestCase): """Tests client._extract_id_token().""" @@ -1670,7 +1657,7 @@ def test_step1_get_authorize_url_redirect_override(self, logger): 'access_type': 'offline', 'response_type': 'code', } - expected = client._update_query_params(flow.auth_uri, query_params) + expected = _helpers.update_query_params(flow.auth_uri, query_params) assertUrisEqual(self, expected, result) # Check stubs. self.assertEqual(logger.warning.call_count, 1) @@ -1735,7 +1722,7 @@ def test_step1_get_authorize_url_without_login_hint(self): 'access_type': 'offline', 'response_type': 'code', } - expected = client._update_query_params(flow.auth_uri, query_params) + expected = _helpers.update_query_params(flow.auth_uri, query_params) assertUrisEqual(self, expected, result) def test_step1_get_device_and_user_codes_wo_device_uri(self):