Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Correct query loss when using parse_qsl to dict #622

Merged
merged 1 commit into from
Aug 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions oauth2client/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,54 @@ def string_to_scopes(scopes):
return scopes


def parse_unique_urlencoded(content):
"""Parses unique key-value parameters from urlencoded content.

Args:
content: string, URL-encoded key-value pairs.

Returns:
dict, The key-value pairs from ``content``.

Raises:
ValueError: if one of the keys is repeated.
"""
urlencoded_params = urllib.parse.parse_qs(content)
params = {}
for key, value in six.iteritems(urlencoded_params):
if len(value) != 1:
msg = ('URL-encoded content contains a repeated value:'
'%s -> %s' % (key, ', '.join(value)))
raise ValueError(msg)
params[key] = value[0]
return params


def update_query_params(uri, params):
"""Updates a URI with new query parameters.

If a given key from ``params`` is repeated in the ``uri``, then
the URI will be considered invalid and an error will occur.

If the URI is valid, then each value from ``params`` will
replace the corresponding value in the query parameters (if
it exists).

Args:
uri: string, A valid URI, with potential existing query parameters.
params: dict, A dictionary of query parameters.

Returns:
The same URI but with the new query parameters added.
"""
parts = urllib.parse.urlparse(uri)
query_params = parse_unique_urlencoded(parts.query)
query_params.update(params)
new_query = urllib.parse.urlencode(query_params)
new_parts = parts._replace(query=new_query)
return urllib.parse.urlunparse(new_parts)


def _add_query_parameter(url, name, value):
"""Adds a query parameter to a url.

Expand All @@ -195,11 +243,7 @@ def _add_query_parameter(url, name, value):
if value is None:

This comment was marked as spam.

return url
else:
parsed = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(parsed[4]))
query[name] = value
parsed[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(parsed)
return update_query_params(url, {name: value})


def validate_file(filename):
Expand Down
28 changes: 6 additions & 22 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,23 +438,6 @@ def delete(self):
self.release_lock()


def _update_query_params(uri, params):
"""Updates a URI with new query parameters.

Args:
uri: string, A valid URI, with potential existing query parameters.
params: dict, A dictionary of query parameters.

Returns:
The same URI but with the new query parameters added.
"""
parts = urllib.parse.urlparse(uri)
query_params = dict(urllib.parse.parse_qsl(parts.query))
query_params.update(params)
new_parts = parts._replace(query=urllib.parse.urlencode(query_params))
return urllib.parse.urlunparse(new_parts)


class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.

Expand Down Expand Up @@ -850,7 +833,8 @@ def _do_revoke(self, http, token):
"""
logger.info('Revoking token')
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
token_revoke_uri = _helpers.update_query_params(
self.revoke_uri, query_params)
resp, content = transport.request(http, token_revoke_uri)
if resp.status == http_client.OK:
self.invalid = True
Expand Down Expand Up @@ -889,8 +873,8 @@ def _do_retrieve_scopes(self, http, token):
"""
logger.info('Refreshing scopes')
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri,
query_params)
token_info_uri = _helpers.update_query_params(
self.token_info_uri, query_params)
resp, content = transport.request(http, token_info_uri)
content = _helpers._from_bytes(content)
if resp.status == http_client.OK:
Expand Down Expand Up @@ -1610,7 +1594,7 @@ def _parse_exchange_token_response(content):
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
resp = dict(urllib.parse.parse_qsl(content))
resp = _helpers.parse_unique_urlencoded(content)

# some providers respond with 'expires', others with 'expires_in'
if resp and 'expires' in resp:
Expand Down Expand Up @@ -1943,7 +1927,7 @@ def step1_get_authorize_url(self, redirect_uri=None, state=None):
query_params['code_challenge_method'] = 'S256'

query_params.update(self.params)
return _update_query_params(self.auth_uri, query_params)
return _helpers.update_query_params(self.auth_uri, query_params)

@_helpers.positional(1)
def step1_get_device_and_user_codes(self, http=None):
Expand Down
12 changes: 6 additions & 6 deletions oauth2client/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@ def do_GET(self):
if an error occurred.
"""
self.send_response(http_client.OK)
self.send_header("Content-type", "text/html")
self.send_header('Content-type', 'text/html')
self.end_headers()
query = self.path.split('?', 1)[-1]
query = dict(urllib.parse.parse_qsl(query))
parts = urllib.parse.urlparse(self.path)
query = _helpers.parse_unique_urlencoded(parts.query)
self.server.query_params = query
self.wfile.write(
b"<html><head><title>Authentication Status</title></head>")
b'<html><head><title>Authentication Status</title></head>')
self.wfile.write(
b"<body><p>The authentication flow has completed.</p>")
self.wfile.write(b"</body></html>")
b'<body><p>The authentication flow has completed.</p>')
self.wfile.write(b'</body></html>')

def log_message(self, format, *args):
"""Do not log messages to stdout while running as cmd. line program."""
Expand Down
40 changes: 40 additions & 0 deletions tests/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import mock

from oauth2client import _helpers
from tests import test_client


class PositionalTests(unittest.TestCase):
Expand Down Expand Up @@ -242,3 +243,42 @@ def test_bad_input(self):
bad_string = b'+'
with self.assertRaises((TypeError, binascii.Error)):
_helpers._urlsafe_b64decode(bad_string)


class Test_update_query_params(unittest.TestCase):

def test_update_query_params_no_params(self):
uri = 'http://www.google.com'
updated = _helpers.update_query_params(uri, {'a': 'b'})
self.assertEqual(updated, uri + '?a=b')

def test_update_query_params_existing_params(self):
uri = 'http://www.google.com?x=y'
updated = _helpers.update_query_params(uri, {'a': 'b', 'c': 'd&'})
hardcoded_update = uri + '&a=b&c=d%26'
test_client.assertUrisEqual(self, updated, hardcoded_update)

def test_update_query_params_replace_param(self):
base_uri = 'http://www.google.com'
uri = base_uri + '?x=a'
updated = _helpers.update_query_params(uri, {'x': 'b', 'y': 'c'})
hardcoded_update = base_uri + '?x=b&y=c'
test_client.assertUrisEqual(self, updated, hardcoded_update)

def test_update_query_params_repeated_params(self):
uri = 'http://www.google.com?x=a&x=b'
with self.assertRaises(ValueError):
_helpers.update_query_params(uri, {'a': 'c'})


class Test_parse_unique_urlencoded(unittest.TestCase):

def test_without_repeats(self):
content = 'a=b&c=d'
result = _helpers.parse_unique_urlencoded(content)
self.assertEqual(result, {'a': 'b', 'c': 'd'})

def test_with_repeats(self):
content = 'a=b&a=d'
with self.assertRaises(ValueError):
_helpers.parse_unique_urlencoded(content)
19 changes: 3 additions & 16 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ def _do_retrieve_scopes_test_helper(self, response, content,
self.assertEqual(credentials.scopes, set())
self.assertEqual(exc_manager.exception.args, (error_msg,))

token_uri = client._update_query_params(
token_uri = _helpers.update_query_params(
oauth2client.GOOGLE_TOKEN_INFO_URI,
{'fields': 'scope', 'access_token': token})

Expand Down Expand Up @@ -1558,19 +1558,6 @@ def test_sign_blob_abstract(self):
credentials.sign_blob(b'blob')


class UpdateQueryParamsTest(unittest.TestCase):
def test_update_query_params_no_params(self):
uri = 'http://www.google.com'
updated = client._update_query_params(uri, {'a': 'b'})
self.assertEqual(updated, uri + '?a=b')

def test_update_query_params_existing_params(self):
uri = 'http://www.google.com?x=y'
updated = client._update_query_params(uri, {'a': 'b', 'c': 'd&'})
hardcoded_update = uri + '&a=b&c=d%26'
assertUrisEqual(self, updated, hardcoded_update)


class ExtractIdTokenTest(unittest.TestCase):
"""Tests client._extract_id_token()."""

Expand Down Expand Up @@ -1670,7 +1657,7 @@ def test_step1_get_authorize_url_redirect_override(self, logger):
'access_type': 'offline',
'response_type': 'code',
}
expected = client._update_query_params(flow.auth_uri, query_params)
expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)
# Check stubs.
self.assertEqual(logger.warning.call_count, 1)
Expand Down Expand Up @@ -1735,7 +1722,7 @@ def test_step1_get_authorize_url_without_login_hint(self):
'access_type': 'offline',
'response_type': 'code',
}
expected = client._update_query_params(flow.auth_uri, query_params)
expected = _helpers.update_query_params(flow.auth_uri, query_params)
assertUrisEqual(self, expected, result)

def test_step1_get_device_and_user_codes_wo_device_uri(self):
Expand Down