Skip to content

Commit

Permalink
Refactor middleware (#325)
Browse files Browse the repository at this point in the history
* Split middleware into a subpackage

* Refactor basic auth header building

* Add encoding parameter to to_bytes()
  • Loading branch information
florimondmanca authored Sep 7, 2019
1 parent 0f34f3b commit 71648ec
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 51 deletions.
10 changes: 4 additions & 6 deletions httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
from .dispatch.threaded import ThreadedDispatcher
from .dispatch.wsgi import WSGIDispatch
from .exceptions import HTTPError, InvalidURL
from .middleware import (
BaseMiddleware,
BasicAuthMiddleware,
CustomAuthMiddleware,
RedirectMiddleware,
)
from .middleware.base import BaseMiddleware
from .middleware.basic_auth import BasicAuthMiddleware
from .middleware.custom_auth import CustomAuthMiddleware
from .middleware.redirect import RedirectMiddleware
from .models import (
URL,
AsyncRequest,
Expand Down
Empty file added httpx/middleware/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions httpx/middleware/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import typing

from ..models import AsyncRequest, AsyncResponse


class BaseMiddleware:
async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
raise NotImplementedError # pragma: no cover
27 changes: 27 additions & 0 deletions httpx/middleware/basic_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import typing
from base64 import b64encode

from ..models import AsyncRequest, AsyncResponse
from ..utils import to_bytes
from .base import BaseMiddleware


class BasicAuthMiddleware(BaseMiddleware):
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
):
self.authorization_header = build_basic_auth_header(username, password)

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request.headers["Authorization"] = self.authorization_header
return await get_response(request)


def build_basic_auth_header(
username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> str:
userpass = b":".join((to_bytes(username), to_bytes(password)))
token = b64encode(userpass).decode().strip()
return f"Basic {token}"
15 changes: 15 additions & 0 deletions httpx/middleware/custom_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import typing

from ..models import AsyncRequest, AsyncResponse
from .base import BaseMiddleware


class CustomAuthMiddleware(BaseMiddleware):
def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
self.auth = auth

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request = self.auth(request)
return await get_response(request)
50 changes: 5 additions & 45 deletions httpx/middleware.py → httpx/middleware/redirect.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,11 @@
import functools
import typing
from base64 import b64encode

from .config import DEFAULT_MAX_REDIRECTS
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
from .status_codes import codes


class BaseMiddleware:
async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
raise NotImplementedError # pragma: no cover


class BasicAuthMiddleware(BaseMiddleware):
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
):
if isinstance(username, str):
username = username.encode("latin1")

if isinstance(password, str):
password = password.encode("latin1")

userpass = b":".join((username, password))
token = b64encode(userpass).decode().strip()

self.authorization_header = f"Basic {token}"

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request.headers["Authorization"] = self.authorization_header
return await get_response(request)


class CustomAuthMiddleware(BaseMiddleware):
def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
self.auth = auth

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request = self.auth(request)
return await get_response(request)
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
from ..status_codes import codes
from .base import BaseMiddleware


class RedirectMiddleware(BaseMiddleware):
Expand Down
4 changes: 4 additions & 0 deletions httpx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,7 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)

return logging.getLogger(name)


def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
return value.encode(encoding) if isinstance(value, str) else value

0 comments on commit 71648ec

Please sign in to comment.