Skip to content

Commit

Permalink
lint!
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu committed May 11, 2023
1 parent 0d4684f commit bf022a4
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 29 deletions.
29 changes: 22 additions & 7 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,13 @@ class Authenticator(object):
Base authenticator for all authentication flows
"""

def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None, http_proxy_url: typing.Optional[str] = None,):
def __init__(
self,
endpoint: str,
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
Expand Down Expand Up @@ -163,7 +169,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None
http_proxy_url: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -189,7 +195,9 @@ def refresh_credentials(self):
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url)
token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand All @@ -209,7 +217,7 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None
http_proxy_url: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -222,11 +230,16 @@ def __init__(
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint), http_proxy_url=http_proxy_url
endpoint=endpoint,
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url)
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
Expand All @@ -235,7 +248,9 @@ def refresh_credentials(self):
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url)
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
Expand Down
9 changes: 6 additions & 3 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta
from flytekit import logger

import requests

from flytekit import logger
from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"
Expand Down Expand Up @@ -132,7 +133,9 @@ def get_device_code(
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None) -> typing.Tuple[str, int]:
def poll_token_endpoint(
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
Expand All @@ -143,7 +146,7 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url
http_proxy_url=http_proxy_url,
)
print("Authentication successful!")
return access_token, expires_in
Expand Down
6 changes: 4 additions & 2 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url
http_proxy_url=cfg.http_proxy_url,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -83,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url)
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
4 changes: 1 addition & 3 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ class Platform(object):
CA_CERT_FILE_PATH = ConfigEntry(
LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
)
HTTP_PROXY_URL = ConfigEntry(
LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL")
)
HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL"))


class LocalSDK(object):
Expand Down
13 changes: 6 additions & 7 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock):
@patch("flytekit.clients.auth.token_client.requests")
def test_client_creds_authenticator(mock_requests):
authn = ClientCredentialsAuthenticator(
ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store, http_proxy_url="https://my-proxy:31111"
ENDPOINT,
client_id="client",
client_secret="secret",
cfg_store=static_cfg_store,
http_proxy_url="https://my-proxy:31111",
)

response = MagicMock()
Expand Down Expand Up @@ -103,12 +107,7 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(
ENDPOINT,
cfg_store,
audience="x",
http_proxy_url="http://my-proxy:9000"
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")

device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
Expand Down
12 changes: 5 additions & 7 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_get_token(mock_requests):
response.status_code = 200
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000")
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
)
assert access == "abc"
assert expiration == 60

Expand Down Expand Up @@ -62,19 +64,15 @@ def test_poll_token_endpoint(mock_requests):
response.json.return_value = {"error": error_auth_pending}
mock_requests.post.return_value = response

r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1
)
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0
)
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
assert t
assert e

0 comments on commit bf022a4

Please sign in to comment.