Skip to content

Commit

Permalink
cleanup keyring code
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan committed Dec 6, 2019
1 parent 053e03e commit 946a058
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 40 deletions.
20 changes: 0 additions & 20 deletions flytekit/clients/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import keyring as _keyring

from flytekit.clis.auth import credentials as _credentials_access

# Identifies the service used for storing passwords in keyring
_keyring_service_name = "flytecli"
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
# suggest, we are storing a user's oidc.
_keyring_storage_key = "access_token"


def iterate_node_executions(
Expand Down Expand Up @@ -85,17 +79,3 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte
break
token = next_token


# Fetches an existing authorization access token if it exists in keyring or sets if it's unassigned.
def get_global_access_token():
access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key)
if access_token is None:
access_token = set_global_access_token()
return access_token


# Assigns and returns the authorization access token in keyring.
def set_global_access_token():
credentials = _credentials_access.get_client().credentials
_keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token)
return credentials.access_token
16 changes: 5 additions & 11 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials
from flyteidl.service import admin_pb2_grpc as _admin_service
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.configuration.platform import AUTH as _AUTH
from flytekit.configuration.creds import (
CLIENT_ID as _CLIENT_ID,
CLIENT_CREDENTIALS_SCOPE as _SCOPE,
Expand All @@ -14,9 +15,6 @@
from flytekit.configuration import creds as _creds_config, platform as _platform_config

from flytekit.clis.auth import credentials as _credentials_access
from flytekit.clients.helpers import (
get_global_access_token as _get_global_access_token, set_global_access_token as _set_global_access_token
)


def _refresh_credentials_standard(flyte_client):
Expand All @@ -28,14 +26,7 @@ def _refresh_credentials_standard(flyte_client):
"""

_credentials_access.get_client().refresh_access_token()
_set_global_access_token()

if not _platform_config.AUTH.get():
# nothing to do
return

access_token = _get_global_access_token()
flyte_client.set_access_token(access_token)
flyte_client.set_access_token(_credentials_access.get_client().credentials.access_token)


def _refresh_credentials_basic(flyte_client):
Expand Down Expand Up @@ -134,6 +125,8 @@ def __init__(self, url, insecure=False, credentials=None, options=None):
)
self._stub = _admin_service.AdminServiceStub(self._channel)
self._metadata = None
if _AUTH.get():
self.force_auth_flow()

def set_access_token(self, access_token):
self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))]
Expand Down Expand Up @@ -279,6 +272,7 @@ def list_workflow_ids_paginated(self, identifier_list_request):
:rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList
:raises: TODO
"""
_logging.warn("hi katrina, metadata is {}".format(self._metadata))
return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata)

@_handle_rpc_error
Expand Down
43 changes: 39 additions & 4 deletions flytekit/clis/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64 as _base64
import hashlib as _hashlib
import keyring as _keyring
import os as _os
import re as _re
import requests as _requests
Expand Down Expand Up @@ -31,6 +32,15 @@
_utf_8 = 'utf-8'


# Identifies the service used for storing passwords in keyring
_keyring_service_name = "flyteauth"
# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs
# suggest, we are storing a user's oidc.
_keyring_access_token_storage_key = "access_token"
_keyring_id_token_storage_key = "id_token"
_keyring_refresh_token_storage_key = "refresh_token"


def _generate_code_verifier():
"""
Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1
Expand Down Expand Up @@ -148,6 +158,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi
self._credentials = None
self._refresh_token = None
self._headers = {'content-type': "application/x-www-form-urlencoded"}
self._expired = False

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand All @@ -160,7 +171,15 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi
"code_challenge_method": "S256",
}

# Initiate token request flow
# Prefer to use already-fetched token values when they've been set globally.
self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key)
access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key)
id_token = _keyring.get_password(_keyring_service_name, _keyring_id_token_storage_key)
if access_token and id_token:
self._credentials = Credentials(access_token=access_token, id_token=id_token)
return

# In the absence of globally-set token values, initiate the token request flow
q = _Queue()
# First prepare the callback server in the background
server = self._create_callback_server(q)
Expand Down Expand Up @@ -203,7 +222,14 @@ def _initialize_credentials(self, auth_token_resp):
if "refresh_token" in response_body:
self._refresh_token = response_body["refresh_token"]

self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"])
access_token = response_body["access_token"]
id_token = response_body["id_token"]
refresh_token = response_body["refresh_token"]

_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)
_keyring.set_password(_keyring_service_name, _keyring_id_token_storage_key, id_token)
_keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token)
self._credentials = Credentials(access_token=access_token, id_token=id_token)

def request_access_token(self, auth_code):
if self._state != auth_code.state:
Expand Down Expand Up @@ -239,8 +265,10 @@ def refresh_access_token(self):
allow_redirects=False
)
if resp.status_code != _StatusCodes.OK:
raise Exception('Failed to request access token with response: [{}] {}'.format(
resp.status_code, resp.content))
self._expired = True
# In the absence of a successful response, assume the refresh token is expired. This should indicate
# to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized.
return
self._initialize_credentials(resp)

@property
Expand All @@ -249,3 +277,10 @@ def credentials(self):
:return flytekit.clis.auth.auth.Credentials:
"""
return self._credentials

@property
def expired(self):
"""
:return bool:
"""
return self._expired
7 changes: 2 additions & 5 deletions flytekit/clis/auth/credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import absolute_import

from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient
from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient

Expand Down Expand Up @@ -28,11 +27,9 @@ def _get_discovery_endpoint():

def get_client():
global _authorization_client
if _authorization_client is not None:
if _authorization_client is not None and not _authorization_client.expired:
return _authorization_client
discovery_endpoint = _get_discovery_endpoint()
discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint)
authorization_endpoints = discovery_client.get_authorization_endpoints()
authorization_endpoints = get_authorization_endpoints()

_authorization_client =\
_AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(),
Expand Down

0 comments on commit 946a058

Please sign in to comment.