diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index bee5b46d8c..10640b6d74 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,12 +1,6 @@ -import keyring as _keyring from flytekit.clis.auth import credentials as _credentials_access -# Identifies the service used for storing passwords in keyring -_keyring_service_name = "flytecli" -# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs -# suggest, we are storing a user's oidc. -_keyring_storage_key = "access_token" def iterate_node_executions( @@ -85,17 +79,3 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte break token = next_token - -# Fetches an existing authorization access token if it exists in keyring or sets if it's unassigned. -def get_global_access_token(): - access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key) - if access_token is None: - access_token = set_global_access_token() - return access_token - - -# Assigns and returns the authorization access token in keyring. -def set_global_access_token(): - credentials = _credentials_access.get_client().credentials - _keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token) - return credentials.access_token diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 38a7d495a0..e58b4a6e88 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -4,6 +4,7 @@ StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration.platform import AUTH as _AUTH from flytekit.configuration.creds import ( CLIENT_ID as _CLIENT_ID, CLIENT_CREDENTIALS_SCOPE as _SCOPE, @@ -14,9 +15,6 @@ from flytekit.configuration import creds as _creds_config, platform as _platform_config from flytekit.clis.auth import credentials as _credentials_access -from flytekit.clients.helpers import ( - get_global_access_token as _get_global_access_token, set_global_access_token as _set_global_access_token -) def _refresh_credentials_standard(flyte_client): @@ -28,14 +26,7 @@ def _refresh_credentials_standard(flyte_client): """ _credentials_access.get_client().refresh_access_token() - _set_global_access_token() - - if not _platform_config.AUTH.get(): - # nothing to do - return - - access_token = _get_global_access_token() - flyte_client.set_access_token(access_token) + flyte_client.set_access_token(_credentials_access.get_client().credentials.access_token) def _refresh_credentials_basic(flyte_client): @@ -134,6 +125,8 @@ def __init__(self, url, insecure=False, credentials=None, options=None): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None + if _AUTH.get(): + self.force_auth_flow() def set_access_token(self, access_token): self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] @@ -279,6 +272,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ + _logging.warn("hi katrina, metadata is {}".format(self._metadata)) return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 42bc93ab8f..005371d180 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -1,5 +1,6 @@ import base64 as _base64 import hashlib as _hashlib +import keyring as _keyring import os as _os import re as _re import requests as _requests @@ -31,6 +32,15 @@ _utf_8 = 'utf-8' +# Identifies the service used for storing passwords in keyring +_keyring_service_name = "flyteauth" +# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs +# suggest, we are storing a user's oidc. +_keyring_access_token_storage_key = "access_token" +_keyring_id_token_storage_key = "id_token" +_keyring_refresh_token_storage_key = "refresh_token" + + def _generate_code_verifier(): """ Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 @@ -148,6 +158,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._credentials = None self._refresh_token = None self._headers = {'content-type': "application/x-www-form-urlencoded"} + self._expired = False self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. @@ -160,7 +171,15 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi "code_challenge_method": "S256", } - # Initiate token request flow + # Prefer to use already-fetched token values when they've been set globally. + self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) + access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) + id_token = _keyring.get_password(_keyring_service_name, _keyring_id_token_storage_key) + if access_token and id_token: + self._credentials = Credentials(access_token=access_token, id_token=id_token) + return + + # In the absence of globally-set token values, initiate the token request flow q = _Queue() # First prepare the callback server in the background server = self._create_callback_server(q) @@ -203,7 +222,14 @@ def _initialize_credentials(self, auth_token_resp): if "refresh_token" in response_body: self._refresh_token = response_body["refresh_token"] - self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + access_token = response_body["access_token"] + id_token = response_body["id_token"] + refresh_token = response_body["refresh_token"] + + _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) + _keyring.set_password(_keyring_service_name, _keyring_id_token_storage_key, id_token) + _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) + self._credentials = Credentials(access_token=access_token, id_token=id_token) def request_access_token(self, auth_code): if self._state != auth_code.state: @@ -239,8 +265,10 @@ def refresh_access_token(self): allow_redirects=False ) if resp.status_code != _StatusCodes.OK: - raise Exception('Failed to request access token with response: [{}] {}'.format( - resp.status_code, resp.content)) + self._expired = True + # In the absence of a successful response, assume the refresh token is expired. This should indicate + # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. + return self._initialize_credentials(resp) @property @@ -249,3 +277,10 @@ def credentials(self): :return flytekit.clis.auth.auth.Credentials: """ return self._credentials + + @property + def expired(self): + """ + :return bool: + """ + return self._expired diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 7b945987fd..f2d3744a8b 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -1,5 +1,4 @@ from __future__ import absolute_import - from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient @@ -28,11 +27,9 @@ def _get_discovery_endpoint(): def get_client(): global _authorization_client - if _authorization_client is not None: + if _authorization_client is not None and not _authorization_client.expired: return _authorization_client - discovery_endpoint = _get_discovery_endpoint() - discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) - authorization_endpoints = discovery_client.get_authorization_endpoints() + authorization_endpoints = get_authorization_endpoints() _authorization_client =\ _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(),