diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index cff09a4916..bde25602e0 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -48,7 +48,13 @@ class Authenticator(object): Base authenticator for all authentication flows """ - def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None, http_proxy_url: typing.Optional[str] = None,): + def __init__( + self, + endpoint: str, + header_key: str, + credentials: Credentials = None, + http_proxy_url: typing.Optional[str] = None, + ): self._endpoint = endpoint self._creds = credentials self._header_key = header_key if header_key else "authorization" @@ -163,7 +169,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, scopes: typing.Optional[typing.List[str]] = None, - http_proxy_url: typing.Optional[str] = None + http_proxy_url: typing.Optional[str] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -189,7 +195,9 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url) + token, expires_in = token_client.get_token( + token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url + ) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) @@ -209,7 +217,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, audience: typing.Optional[str] = None, - http_proxy_url: typing.Optional[str] = None + http_proxy_url: typing.Optional[str] = None, ): self._audience = audience cfg = cfg_store.get_client_config() @@ -222,11 +230,16 @@ def __init__( "Device Authentication is not available on the Flyte backend / authentication server" ) super().__init__( - endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint), http_proxy_url=http_proxy_url + endpoint=endpoint, + header_key=header_key or cfg.header_key, + credentials=KeyringStore.retrieve(endpoint), + http_proxy_url=http_proxy_url, ) def refresh_credentials(self): - resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url) + resp = token_client.get_device_code( + self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url + ) print( f""" To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} @@ -235,7 +248,9 @@ def refresh_credentials(self): try: # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that # access tokens can be refreshed for once authenticated machines - token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url) + token, expires_in = token_client.poll_token_endpoint( + resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url + ) self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) KeyringStore.store(self._creds) except Exception: diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 452f179fac..728a56b5f7 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -6,9 +6,10 @@ import urllib.parse from dataclasses import dataclass from datetime import datetime, timedelta -from flytekit import logger + import requests +from flytekit import logger from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending utf_8 = "utf-8" @@ -132,7 +133,9 @@ def get_device_code( return DeviceCodeResponse.from_json_response(resp.json()) -def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None) -> typing.Tuple[str, int]: +def poll_token_endpoint( + resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None +) -> typing.Tuple[str, int]: tick = datetime.now() interval = timedelta(seconds=resp.interval) end_time = tick + timedelta(seconds=resp.expires_in) @@ -143,7 +146,7 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id grant_type=GrantType.DEVICE_CODE, client_id=client_id, device_code=resp.device_code, - http_proxy_url=http_proxy_url + http_proxy_url=http_proxy_url, ) print("Authentication successful!") return access_token, expires_in diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 991945bf2c..a068e487cf 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -72,7 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, scopes=cfg.scopes, - http_proxy_url=cfg.http_proxy_url + http_proxy_url=cfg.http_proxy_url, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -83,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth header_key=client_cfg.header_key if client_cfg else None, ) elif cfg_auth == AuthType.DEVICEFLOW: - return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url) + return DeviceCodeAuthenticator( + endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url + ) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index e00f699d80..4f993b4e11 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -111,9 +111,7 @@ class Platform(object): CA_CERT_FILE_PATH = ConfigEntry( LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath") ) - HTTP_PROXY_URL = ConfigEntry( - LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL") - ) + HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL")) class LocalSDK(object): diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 231c9c8f0b..495f4648ac 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock): @patch("flytekit.clients.auth.token_client.requests") def test_client_creds_authenticator(mock_requests): authn = ClientCredentialsAuthenticator( - ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store, http_proxy_url="https://my-proxy:31111" + ENDPOINT, + client_id="client", + client_secret="secret", + cfg_store=static_cfg_store, + http_proxy_url="https://my-proxy:31111", ) response = MagicMock() @@ -103,12 +107,7 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, device_authorization_endpoint="dev", ) ) - authn = DeviceCodeAuthenticator( - ENDPOINT, - cfg_store, - audience="x", - http_proxy_url="http://my-proxy:9000" - ) + authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000") device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0) poll_mock.return_value = ("access", 100) diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py index 50699396c0..f0c10b16d1 100644 --- a/tests/flytekit/unit/clients/auth/test_token_client.py +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -28,7 +28,9 @@ def test_get_token(mock_requests): response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response - access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000") + access, expiration = get_token( + "https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000" + ) assert access == "abc" assert expiration == 60 @@ -62,9 +64,7 @@ def test_poll_token_endpoint(mock_requests): response.json.return_value = {"error": error_auth_pending} mock_requests.post.return_value = response - r = DeviceCodeResponse( - device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1 - ) + r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1) with pytest.raises(AuthenticationError): poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000") @@ -72,9 +72,7 @@ def test_poll_token_endpoint(mock_requests): response.ok = True response.json.return_value = {"access_token": "abc", "expires_in": 60} mock_requests.post.return_value = response - r = DeviceCodeResponse( - device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0 - ) + r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0) t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000") assert t assert e