From 27a5feb83ce3adbb50f1546ff3ba1822ab9b1ac6 Mon Sep 17 00:00:00 2001 From: Theron Luhn Date: Tue, 19 Jan 2016 08:14:48 -0800 Subject: [PATCH] Implement BasicAuth decode classmethod. --- aiohttp/helpers.py | 36 ++++++++++++++++++++++++++++++++++++ tests/test_helpers.py | 16 ++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 5ddc141d8fb..81dac3bfd8f 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -1,5 +1,6 @@ """Various helper functions""" import base64 +import binascii import datetime import functools import io @@ -31,6 +32,41 @@ def __new__(cls, login, password='', encoding='latin1'): return super().__new__(cls, login, password, encoding) + @classmethod + def decode(cls, auth_header, encoding='latin1'): + """Create a :class:`BasicAuth` object from an ``Authorization`` HTTP + header. + + :param auth_header: The value of the ``Authorization`` header. + :type auth_header: str + :param encoding: The character encoding used on the password. + :type encoding: str + + :returns: The decoded authentication. + :rtype: :class:`BasicAuth` + + :raises ValueError: if the headers are unable to be decoded. + + """ + split = auth_header.strip().split(' ') + if len(split) == 2: + if split[0].strip().lower() != 'basic': + raise ValueError('Unknown authorization method %s' % split[0]) + to_decode = split[1] + elif len(split) == 1: + to_decode = split[0] + else: + raise ValueError('Could not parse authorization header.') + + try: + username, _, password = base64.b64decode( + to_decode.encode('ascii') + ).decode(encoding).partition(':') + except binascii.Error: + raise ValueError('Invalid base64 encoding.') + + return cls(username, password) + def encode(self): """Encode credentials.""" creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 12e2ae65edb..99d15f167c0 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -71,6 +71,22 @@ def test_basic_auth4(): assert auth.encode() == 'Basic bmtpbTpwd2Q=' +def test_basic_auth_decode(): + auth = helpers.BasicAuth.decode('Basic bmtpbTpwd2Q=') + assert auth.login == 'nkim' + assert auth.password == 'pwd' + + +def test_basic_auth_decode_not_basic(): + with pytest.raises(ValueError): + helpers.BasicAuth.decode('Complex bmtpbTpwd2Q=') + + +def test_basic_auth_decode_bad_base64(): + with pytest.raises(ValueError): + helpers.BasicAuth.decode('Basic bmtpbTpwd2Q') + + def test_invalid_formdata_params(): with pytest.raises(TypeError): helpers.FormData('asdasf')