Skip to content

Commit

Permalink
Types and linting.
Browse files Browse the repository at this point in the history
Closes pulp#926
  • Loading branch information
decko committed Jul 26, 2024
1 parent 2d173fe commit 88d7682
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 42 deletions.
63 changes: 33 additions & 30 deletions pulp-glue/pulp_glue/common/authentication.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,53 @@
import click
import json
import requests

from typing import NoReturn
import typing as t
from datetime import datetime, timedelta
from pathlib import Path

import click
import requests


class OAuth2Auth(requests.auth.AuthBase):

def __init__(self, *args, **kwargs):
self.client_id: str = kwargs.get("username")
self.client_secret: str = kwargs.get("password")
self.flow: list = kwargs.get("flow")
self.token_url: str = self.flow["flows"]["clientCredentials"]["tokenUrl"]
self.scope: str = [*self.flow["flows"]["clientCredentials"]["scopes"]][0]
self.token: dict = {}
def __init__(self, *args: t.List[t.Any], **kwargs: t.Dict[t.Any, t.Any]):
self.client_id = kwargs.get("username")
self.client_secret = kwargs.get("password")
self.flow: t.Dict[t.Any, t.Any] = kwargs["flow"]
self.token_url = self.flow["flows"]["clientCredentials"]["tokenUrl"]
self.scope = [*self.flow["flows"]["clientCredentials"]["scopes"]][0]
self.token: t.Dict[t.Any, t.Any] = {}

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
self.retrieve_local_token()

access_token = self.token.get("access_token")
request.headers["Authorization"] = f"Bearer {access_token}"

request.register_hook("response", self.handle401)
request.register_hook("response", self.handle401) # type: ignore

return request

def handle401(self, request: requests.PreparedRequest, **kwargs) -> requests.PreparedRequest:
if request.status_code != 401:
return request
def handle401(
self, response: requests.Response, **kwargs: t.Dict[t.Any, t.Any]
) -> requests.Response:
if response.status_code != 401:
return response

self.retrieve_local_token()
if self.is_token_expired():
self.retrieve_token()

request.content
prep = request.request.copy()
response.content
prep = response.request.copy()

access_token = self.token.get('access_token')
access_token = self.token.get("access_token")
prep.headers["Authorization"] = f"Bearer {access_token}"

_request = request.connection.send(prep, **kwargs)
_request.history.append(request)
_request.request = prep
_response: requests.Response = response.connection.send(prep, **kwargs) # type: ignore
_response.history.append(response)
_response.request = prep

return _request
return _response

def is_token_expired(self) -> bool:
if self.token:
Expand All @@ -58,19 +60,20 @@ def is_token_expired(self) -> bool:

return True

def store_local_token(self) -> NoReturn:
TOKEN_LOCATION = (Path(click.utils.get_app_dir("pulp"), "token.json"))
def store_local_token(self) -> None:
TOKEN_LOCATION = Path(click.utils.get_app_dir("pulp"), "token.json")
with Path(TOKEN_LOCATION).open("w") as token_file:
token = json.dumps(self.token)
token_file.write(token)

def retrieve_local_token(self) -> NoReturn:
TOKEN_LOCATION = (Path(click.utils.get_app_dir("pulp"), "token.json"))
with Path(TOKEN_LOCATION).open("r") as token_file:
token_json = token_file.read()
self.token = json.loads(token_json)
def retrieve_local_token(self) -> None:
token_file = Path(click.utils.get_app_dir("pulp"), "token.json")
if token_file.exists():
with token_file.open("r") as tf:
token_json = tf.read()
self.token = json.loads(token_json)

def retrieve_token(self) -> NoReturn:
def retrieve_token(self) -> None:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
Expand Down
22 changes: 12 additions & 10 deletions pulp-glue/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from pulp_glue.common import __version__
from pulp_glue.common.i18n import get_translation
from pulp_glue.common.authentication import OAuth2Auth

translation = get_translation(__package__)
_ = translation.gettext
Expand Down Expand Up @@ -59,8 +58,11 @@ def basic_auth(self) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.Auth
"""Implement this to provide means of http basic auth."""
return None

def auth(self) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.AuthBase]]:
return
def auth(
self, flow: t.Dict[t.Any, t.Any]
) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.AuthBase]]:
"""Implement this to provide other authentication methods."""
return None

def __call__(
self,
Expand All @@ -79,11 +81,12 @@ def __call__(
if "oauth2" in authorized_schemes_types:
oauth_flow = [flow for flow in authorized_schemes if flow["type"] == "oauth2"][0]
result = self.auth(oauth_flow)
if result:
return result
elif "http" in authorized_schemes_types:
result = self.basic_auth()

if result:
return result
if result:
return result
raise OpenAPIError(_("No suitable auth scheme found."))


Expand All @@ -106,7 +109,7 @@ def __init__(self, username: str, password: str):
self.client_id = username
self.client_secret = password

def auth(self, oauth_payload: dict) -> requests.auth.AuthBase:
def auth(self, flow: dict[t.Any, t.Any]) -> t.Optional[requests.auth.AuthBase]:
pass


Expand Down Expand Up @@ -628,14 +631,13 @@ def render_request(
security: t.List[t.Dict[str, t.List[str]]] = method_spec.get(
"security", self.api_spec.get("security", {})
)
auth: t.Optional[t.Union[tuple[str, str], requests.auth.AuthBase]] = None
if security and self.auth_provider:
if "Authorization" in self._session.headers:
# Bad idea, but you wanted it that way.
auth = None
else:
auth: AuthProviderBase = self.auth_provider(
security, self.api_spec["components"]["securitySchemes"]
)
auth = self.auth_provider(security, self.api_spec["components"]["securitySchemes"])
else:
# No auth required? Don't provide it.
# No auth_provider available? Hope for the best (should do the trick for cert auth).
Expand Down
4 changes: 2 additions & 2 deletions pulpcore/cli/common/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import requests
import schema as s
import yaml
from pulp_glue.common.authentication import OAuth2Auth
from pulp_glue.common.context import (
DATETIME_FORMATS,
DEFAULT_LIMIT,
Expand All @@ -31,7 +32,6 @@
)
from pulp_glue.common.i18n import get_translation
from pulp_glue.common.openapi import AuthProviderBase
from pulp_glue.common.authentication import OAuth2Auth

try:
from pygments import highlight
Expand Down Expand Up @@ -232,7 +232,7 @@ def basic_auth(self) -> t.Optional[t.Union[t.Tuple[str, str], requests.auth.Auth
self.pulp_ctx.password = click.prompt("Password", hide_input=True)
return (self.pulp_ctx.username, self.pulp_ctx.password)

def auth(self, flow):
def auth(self, flow: t.Dict[t.Any, t.Any]) -> t.Optional[requests.auth.AuthBase]:
if self.pulp_ctx.username is None:
self.pulp_ctx.username = click.prompt("Username/ClientID")
if self.pulp_ctx.password is None:
Expand Down

0 comments on commit 88d7682

Please sign in to comment.