From 910af27ea093bc779e93daab44b9af85839ea03f Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 7 Sep 2019 23:10:42 +0200 Subject: [PATCH] Refactor basic auth header building --- httpx/middleware/basic_auth.py | 20 ++++++++++---------- httpx/utils.py | 4 ++++ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/httpx/middleware/basic_auth.py b/httpx/middleware/basic_auth.py index d85ebfd27b..faffba2fd1 100644 --- a/httpx/middleware/basic_auth.py +++ b/httpx/middleware/basic_auth.py @@ -2,6 +2,7 @@ from base64 import b64encode from ..models import AsyncRequest, AsyncResponse +from ..utils import to_bytes from .base import BaseMiddleware @@ -9,19 +10,18 @@ class BasicAuthMiddleware(BaseMiddleware): def __init__( self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] ): - if isinstance(username, str): - username = username.encode("latin1") - - if isinstance(password, str): - password = password.encode("latin1") - - userpass = b":".join((username, password)) - token = b64encode(userpass).decode().strip() - - self.authorization_header = f"Basic {token}" + self.authorization_header = build_basic_auth_header(username, password) async def __call__( self, request: AsyncRequest, get_response: typing.Callable ) -> AsyncResponse: request.headers["Authorization"] = self.authorization_header return await get_response(request) + + +def build_basic_auth_header( + username: typing.Union[str, bytes], password: typing.Union[str, bytes] +) -> str: + userpass = b":".join((to_bytes(username), to_bytes(password))) + token = b64encode(userpass).decode().strip() + return f"Basic {token}" diff --git a/httpx/utils.py b/httpx/utils.py index 0357ea2bd2..75e8235f48 100644 --- a/httpx/utils.py +++ b/httpx/utils.py @@ -169,3 +169,7 @@ def get_logger(name: str) -> logging.Logger: logger.addHandler(handler) return logging.getLogger(name) + + +def to_bytes(value: typing.Union[str, bytes]) -> bytes: + return value.encode() if isinstance(value, str) else value