Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Correct query loss when using parse_qsl to dict
Browse files Browse the repository at this point in the history
  • Loading branch information
dhermes committed Aug 17, 2016
1 parent 4c7b3be commit ced809d
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 49 deletions.
58 changes: 53 additions & 5 deletions oauth2client/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,58 @@ def string_to_scopes(scopes):
return scopes


def parse_unique_urlencoded(content, as_list=False):
"""Parses unique key-value parameters from urlencoded content.
Args:
content: string, URL-encoded key-value pairs.
as_list: bool, flag indicating if the values in the result should
be list or values. Defaults to False.
Returns:
dict, The key-value pairs from ``content``.
Raises:
ValueError: if one of the keys is repeated.
"""
urlencoded_params = urllib.parse.parse_qs(content)
params = {}
for key, value in six.iteritems(urlencoded_params):
if len(value) != 1:
msg = ('URL-encoded content contains a repeated value:'
'%s -> %s' % (key, ', '.join(value)))
raise ValueError(msg)
if as_list:
params[key] = value
else:
params[key] = value[0]
return params


def update_query_params(uri, params):
"""Updates a URI with new query parameters.
If a given key from ``params`` is repeated in the ``uri``, then
the new value from ``params`` will be added to that list. If the
key occurs zero or once in the ``uri``, the value from ``params``
will be used as the only value.
Args:
uri: string, A valid URI, with potential existing query parameters.
params: dict, A dictionary of query parameters.
Returns:
The same URI but with the new query parameters added.
"""
parts = urllib.parse.urlparse(uri)
query_params = parse_unique_urlencoded(parts.query, as_list=True)
for key, value in six.iteritems(params):
query_params[key] = [value]
new_query = urllib.parse.urlencode(query_params, doseq=True)
new_parts = parts._replace(query=new_query)
return urllib.parse.urlunparse(new_parts)


def _add_query_parameter(url, name, value):
"""Adds a query parameter to a url.
Expand All @@ -195,11 +247,7 @@ def _add_query_parameter(url, name, value):
if value is None:
return url
else:
parsed = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(parsed[4]))
query[name] = value
parsed[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(parsed)
return update_query_params(url, {name: value})


def validate_file(filename):
Expand Down
28 changes: 6 additions & 22 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,23 +438,6 @@ def delete(self):
self.release_lock()


def _update_query_params(uri, params):
"""Updates a URI with new query parameters.
Args:
uri: string, A valid URI, with potential existing query parameters.
params: dict, A dictionary of query parameters.
Returns:
The same URI but with the new query parameters added.
"""
parts = urllib.parse.urlparse(uri)
query_params = dict(urllib.parse.parse_qsl(parts.query))
query_params.update(params)
new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
return urllib.parse.urlunparse(new_parts)


class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.
Expand Down Expand Up @@ -850,7 +833,8 @@ def _do_revoke(self, http, token):
"""
logger.info('Revoking token')
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
token_revoke_uri = _helpers.update_query_params(
self.revoke_uri, query_params)
resp, content = transport.request(http, token_revoke_uri)
if resp.status == http_client.OK:
self.invalid = True
Expand Down Expand Up @@ -889,8 +873,8 @@ def _do_retrieve_scopes(self, http, token):
"""
logger.info('Refreshing scopes')
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri,
query_params)
token_info_uri = _helpers.update_query_params(
self.token_info_uri, query_params)
resp, content = transport.request(http, token_info_uri)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
Expand Down Expand Up @@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content):
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
resp = dict(urllib.parse.parse_qsl(content))
resp = _helpers.parse_unique_urlencoded(content)

# some providers respond with 'expires', others with 'expires_in'
if resp and 'expires' in resp:
Expand Down Expand Up @@ -1943,7 +1927,7 @@ def step1_get_authorize_url(self, redirect_uri=None, state=None):
query_params['code_challenge_method'] = 'S256'

query_params.update(self.params)
return _update_query_params(self.auth_uri, query_params)
return _helpers.update_query_params(self.auth_uri, query_params)

@_helpers.positional(1)
def step1_get_device_and_user_codes(self, http=None):
Expand Down
12 changes: 6 additions & 6 deletions oauth2client/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@ def do_GET(self):
if an error occurred.
"""
self.send_response(http_client.OK)
self.send_header("Content-type", "text/html")
self.send_header('Content-type', 'text/html')
self.end_headers()
query = self.path.split('?', 1)[-1]
query = dict(urllib.parse.parse_qsl(query))
parts = urllib.parse.urlparse(self.path)
query = _helpers.parse_unique_urlencoded(parts.query)
self.server.query_params = query
self.wfile.write(
b"<html><head><title>Authentication Status</title></head>")
b'<html><head><title>Authentication Status</title></head>')
self.wfile.write(
b"<body><p>The authentication flow has completed.</p>")
self.wfile.write(b"</body></html>")
b'<body><p>The authentication flow has completed.</p>')
self.wfile.write(b'</body></html>')

def log_message(self, format, *args):
"""Do not log messages to stdout while running as cmd. line program."""
Expand Down
40 changes: 40 additions & 0 deletions tests/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import mock

from oauth2client import _helpers
from tests import test_client


class PositionalTests(unittest.TestCase):
Expand Down Expand Up @@ -242,3 +243,42 @@ def test_bad_input(self):
bad_string = b'+'
with self.assertRaises((TypeError, binascii.Error)):
_helpers._urlsafe_b64decode(bad_string)


class Test_update_query_params(unittest.TestCase):

def test_update_query_params_no_params(self):
uri = 'http://www.google.com'
updated = _helpers.update_query_params(uri, {'a': 'b'})
self.assertEqual(updated, uri + '?a=b')

def test_update_query_params_existing_params(self):
uri = 'http://www.google.com?x=y'
updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'})
hardcoded_update = uri + '&a=b&c=d%26'
test_client.assertUrisEqual(self, updated, hardcoded_update)

def test_update_query_params_replace_param(self):
base_uri = 'http://www.google.com'
uri = base_uri + '?x=a'
updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'})
hardcoded_update = base_uri + '?x=b&y=c'
test_client.assertUrisEqual(self, updated, hardcoded_update)

def test_update_query_params_repeated_params(self):
uri = 'http://www.google.com?x=a&x=b'
with self.assertRaises(ValueError):
_helpers.update_query_params(uri, {'a': 'c'})


class Test_parse_unique_urlencoded(unittest.TestCase):

def test_without_repeats(self):
content = 'a=b&c=d'
result = _helpers.parse_unique_urlencoded(content)
self.assertEqual(result, {'a': 'b', 'c': 'd'})

def test_with_repeats(self):
content = 'a=b&a=d'
with self.assertRaises(ValueError):
_helpers.parse_unique_urlencoded(content)
19 changes: 3 additions & 16 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ def _do_retrieve_scopes_test_helper(self, response, content,
self.assertEqual(credentials.scopes, set())
self.assertEqual(exc_manager.exception.args, (error_msg,))

token_uri = client._update_query_params(
token_uri = _helpers.update_query_params(
oauth2client.GOOGLE_TOKEN_INFO_URI,
{'fields': 'scope', 'access_token': token})

Expand Down Expand Up @@ -1558,19 +1558,6 @@ def test_sign_blob_abstract(self):
credentials.sign_blob(b'blob')


class UpdateQueryParamsTest(unittest.TestCase):
def test_update_query_params_no_params(self):
uri = 'http://www.google.com'
updated = client._update_query_params(uri, {'a': 'b'})
self.assertEqual(updated, uri + '?a=b')

def test_update_query_params_existing_params(self):
uri = 'http://www.google.com?x=y'
updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'})
hardcoded_update = uri + '&a=b&c=d%26'
assertUrisEqual(self, updated, hardcoded_update)


class ExtractIdTokenTest(unittest.TestCase):
"""Tests client._extract_id_token()."""

Expand Down Expand Up @@ -1670,7 +1657,7 @@ def test_step1_get_authorize_url_redirect_override(self, logger):
'access_type': 'offline',
'response_type': 'code',
}
expected = client._update_query_params(flow.auth_uri, query_params)
expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)
# Check stubs.
self.assertEqual(logger.warning.call_count, 1)
Expand Down Expand Up @@ -1735,7 +1722,7 @@ def test_step1_get_authorize_url_without_login_hint(self):
'access_type': 'offline',
'response_type': 'code',
}
expected = client._update_query_params(flow.auth_uri, query_params)
expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)

def test_step1_get_device_and_user_codes_wo_device_uri(self):
Expand Down

0 comments on commit ced809d

Please sign in to comment.