From d0715bbe0bab50602982d22b236da41df9e8f82d Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 11 Oct 2019 16:46:43 -0700 Subject: [PATCH 01/40] WIP pkce auth flow --- flytekit/clis/auth/__init__.py | 0 flytekit/clis/auth/auth.py | 234 +++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 flytekit/clis/auth/__init__.py create mode 100644 flytekit/clis/auth/auth.py diff --git a/flytekit/clis/auth/__init__.py b/flytekit/clis/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py new file mode 100644 index 0000000000..3fee40b975 --- /dev/null +++ b/flytekit/clis/auth/auth.py @@ -0,0 +1,234 @@ +import webbrowser +import base64 +import hashlib +import os +import re +import requests +from multiprocessing import Process, Queue + +try: # Python 3.5+ + from http import HTTPStatus as StatusCodes +except ImportError: + try: # Python 3 + from http import client as StatusCodes + except ImportError: # Python 2 + import httplib as StatusCodes +from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler + +from urlparse import urlparse, urljoin, parse_qsl + +code_verifier_length = 64 + + +def generate_code_verifier(): + """ + Generates a 'code_verifier' as described in section 4.1 of RFC 7636. + Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. + :return str: + """ + code_verifier = base64.urlsafe_b64encode(os.urandom(code_verifier_length)).decode('utf-8') + # Eliminate invalid characters. + code_verifier = re.sub('[^a-zA-Z0-9]+', '', code_verifier) + if len(code_verifier) < 43: + raise ValueError("Verifier too short. number of bytes must be > 30.") + elif len(code_verifier) > 128: + raise ValueError("Verifier too long. number of bytes must be < 97.") + return code_verifier + + +# TODO(katrogan): Figure out how random this needs to be +def generate_state_parameter(): + state = base64.urlsafe_b64encode(os.urandom(40)).decode('utf-8') + # Eliminate invalid characters. + code_verifier = re.sub('[^a-zA-Z0-9-_.,]+', '', state) + return code_verifier + + +def create_code_challenge(code_verifier): + """ + Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. + :param str code_verifier: represents a code verifier generated by generate_code_verifier() + :return str: urlsafe base64-encoded sha256 hash digest + """ + code_challenge = hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge).decode('utf-8') + # Eliminate invalid characters + code_challenge = code_challenge.replace('=', '') + return code_challenge + + +class AuthorizationCode(object): + def __init__(self, code, state): + self._code = code + self._state = state + + @property + def code(self): + return self._code + + @property + def state(self): + return self._state + + def __repr__(self): + return "[{}, {}]".format(self.code, self.state) + + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """ + A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an + authorization token. + """ + + def do_GET(self): + url = urlparse(self.path) + if url.path == self.server.redirect_path: + self.send_response(StatusCodes.OK) + self.end_headers() + self.handle_login(dict(parse_qsl(url.query))) + else: + self.send_response(404) + + def handle_login(self, data): + self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) + + +class OAuthHTTPServer(HTTPServer): + """ + A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling + authorization code callbacks. + """ + def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, + redirect_path=None, queue=None): + HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) + self._redirect_path = redirect_path + self._auth_code = None + self._queue = queue + + @property + def redirect_path(self): + return self._redirect_path + + def handle_authorization_code(self, auth_code): + self._queue.put(auth_code) + + +class Credentials(object): + # TODO(katrogan): Also add expires_in handling. + def __init__(self, access_token=None, id_token=None): + self._access_token = access_token + self._id_token = id_token + + @property + def access_token(self): + return self._access_token + + @property + def id_token(self): + return self._id_token + + +# TODO: +# do we need to support initiate login URI? https://devforum.okta.com/t/initiate-login-uri-for-all-subdomain-urls/3766 + + +class AuthorizationClient(object): + def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_path=None): + self._auth_endpoint = auth_endpoint + self._token_endpoint = token_endpoint + self._client_id = client_id + self._redirect_path = redirect_path + self._code_verifier = generate_code_verifier() + code_challenge = create_code_challenge(self._code_verifier) + self._code_challenge = code_challenge + state = generate_state_parameter() + self._state = state + self._credentials = None + + self._params = { + # TODO: need an audience param here? + "client_id": client_id, # This must match the Client ID of the OAuth application. + "response_type": "code", # Indicates the authorization code grant + "scope": "openid", # ensures that the /token endpoint returns an ID token + # callback location where the user-agent will be directed to. + "redirect_uri": urljoin("http://localhost:8088", self._redirect_path), + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + # Initialize token request flow + self._request_authorization_code() + # Start a server to handle the callback url. + self._start_callback_server() + + def _start_callback_server(self): + # TODO: change okta application port + server_address = ('localhost', 8088) + q = Queue() + server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=self._redirect_path, queue=q) + server_process = Process(target=server.handle_request) + + server_process.start() + auth_code = q.get() + server_process.terminate() + self.request_access_token(auth_code) + + def _request_authorization_code(self): + # Spin up a background local http server to receive the callback request containing the authorization code + resp = requests.get( + url=self._auth_endpoint, + params=self._params, + allow_redirects=False + ) + if resp.status_code == StatusCodes.FOUND: + # Follow the redirect + redirect_location = resp.headers['Location'] + if redirect_location is None: + raise ValueError('Received a 302 but no follow up location was provided in headers') + webbrowser.open_new_tab(redirect_location) + + def request_access_token(self, auth_code): + if self._state != auth_code.state: + raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state)) + self._params.update({ + "code": auth_code.code, + "code_verifier": self._code_verifier, + "grant_type": "authorization_code", + }) + resp = requests.post( + url=self._token_endpoint, + data=self._params, + headers={'content-type': "application/x-www-form-urlencoded"}, + allow_redirects=False + ) + if resp.status_code != StatusCodes.OK: + # TODO: handle expected (?) error cases: + # https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses + raise Exception('Failed to request access token with response: [{}] {}'.format( + resp.status_code, resp.content)) + + """ + The response body is of the form: + { + "access_token": "foo", + "refresh_token": "bar", + "id_token": "baz", + "token_type": "Bearer" + } + """ + response_body = resp.json() + if "access_token" not in response_body: + raise ValueError('Expected "access_token" in response from oauth server') + + self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + + def credentials(self): + return self._credentials + + +if __name__ == '__main__': + client = AuthorizationClient(redirect_path="/callback", client_id="my_client", + auth_endpoint="https://myoauth.com/oauth2/default/v1/authorize", + token_endpoint="https://myoauth.com/oauth2/default/v1/token") + client.credentials() From f5761298868d9ef3fa5547cdd455ac512b439d53 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 22 Oct 2019 17:13:15 -0700 Subject: [PATCH 02/40] Add discovery --- flytekit/clis/auth/access.py | 34 +++++++++++++++++ flytekit/clis/auth/auth.py | 27 +++++++------- flytekit/clis/auth/discovery.py | 66 +++++++++++++++++++++++++++++++++ flytekit/configuration/creds.py | 22 +++++++++++ 4 files changed, 135 insertions(+), 14 deletions(-) create mode 100644 flytekit/clis/auth/access.py create mode 100644 flytekit/clis/auth/discovery.py create mode 100644 flytekit/configuration/creds.py diff --git a/flytekit/clis/auth/access.py b/flytekit/clis/auth/access.py new file mode 100644 index 0000000000..a30ab5b76d --- /dev/null +++ b/flytekit/clis/auth/access.py @@ -0,0 +1,34 @@ +from auth import AuthorizationClient +from discovery import DiscoveryClient + +from flytekit.configuration.platform import URL +from flytekit.configuration.creds import DISCOVERY_ENDPOINT, REDIRECT_URI, CLIENT_ID + + +try: # Python 3 + from urllib.parse import urlparse, urljoin +except ImportError: # Python 2 + from urlparse import urlparse, urljoin + + +def _is_absolute(url): + return bool(urlparse(url).netloc) + + +def get_credentials(): + discovery_endpoint = DISCOVERY_ENDPOINT.get() + if not _is_absolute(discovery_endpoint): + discovery_endpoint = urljoin(URL.get(), discovery_endpoint) + discovery_client = DiscoveryClient(discovery_url=discovery_endpoint) + authorization_endpoints = discovery_client.get_authorization_endpoints() + + client = AuthorizationClient(redirect_uri=REDIRECT_URI.get(), + client_id=CLIENT_ID.get(), + auth_endpoint=authorization_endpoints.auth_endpoint, + token_endpoint=authorization_endpoints.token_endpoint) + return client.credentials() + + +if __name__ == '__main__': + get_credentials() + diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 3fee40b975..a356defc45 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -15,7 +15,10 @@ import httplib as StatusCodes from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from urlparse import urlparse, urljoin, parse_qsl +try: # Python 3 + from urllib.parse import urlparse, urljoin, parse_qsl +except ImportError: # Python 2 + from urlparse import urlparse, urljoin, parse_qsl code_verifier_length = 64 @@ -133,11 +136,11 @@ def id_token(self): class AuthorizationClient(object): - def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_path=None): + def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None): self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint self._client_id = client_id - self._redirect_path = redirect_path + self._redirect_uri = redirect_uri self._code_verifier = generate_code_verifier() code_challenge = create_code_challenge(self._code_verifier) self._code_challenge = code_challenge @@ -151,7 +154,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi "response_type": "code", # Indicates the authorization code grant "scope": "openid", # ensures that the /token endpoint returns an ID token # callback location where the user-agent will be directed to. - "redirect_uri": urljoin("http://localhost:8088", self._redirect_path), + "redirect_uri": self._redirect_uri, "state": state, "code_challenge": code_challenge, "code_challenge_method": "S256", @@ -163,10 +166,10 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._start_callback_server() def _start_callback_server(self): - # TODO: change okta application port - server_address = ('localhost', 8088) + server_url = urlparse(self._redirect_uri) + server_address = (server_url.hostname, server_url.port) q = Queue() - server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=self._redirect_path, queue=q) + server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) server_process = Process(target=server.handle_request) server_process.start() @@ -224,11 +227,7 @@ def request_access_token(self, auth_code): self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) def credentials(self): + """ + :return flytekit.clis.auth.Credentials: + """ return self._credentials - - -if __name__ == '__main__': - client = AuthorizationClient(redirect_path="/callback", client_id="my_client", - auth_endpoint="https://myoauth.com/oauth2/default/v1/authorize", - token_endpoint="https://myoauth.com/oauth2/default/v1/token") - client.credentials() diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py new file mode 100644 index 0000000000..286b3e965a --- /dev/null +++ b/flytekit/clis/auth/discovery.py @@ -0,0 +1,66 @@ +import requests + +try: # Python 3.5+ + from http import HTTPStatus as StatusCodes +except ImportError: + try: # Python 3 + from http import client as StatusCodes + except ImportError: # Python 2 + import httplib as StatusCodes + +# These response keys are defined in https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. +authorization_endpoint_key = "authorization_endpoint" +token_endpoint_key = "token_endpoint" + + +class AuthorizationEndpoints(object): + """ + A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. + """ + def __init__(self, auth_endpoint=None, token_endpoint=None): + self._auth_endpoint = auth_endpoint + self._token_endpoint = token_endpoint + + @property + def auth_endpoint(self): + return self._auth_endpoint + + @property + def token_endpoint(self): + return self._token_endpoint + + +class DiscoveryClient(object): + """ + Discovers + """ + + def __init__(self, discovery_url=None): + self._discovery_url = discovery_url + self._authorization_endpoints = None + + def get_authorization_endpoints(self): + if self._authorization_endpoints is not None: + return self._authorization_endpoints + resp = requests.get( + url=self._discovery_url, + ) + # Follow at most one redirect. + if resp.status_code == StatusCodes.FOUND: + redirect_location = resp.headers['Location'] + if redirect_location is None: + raise ValueError('Received a 302 but no follow up location was provided in headers') + resp = requests.get( + url=redirect_location, + ) + + response_body = resp.json() + if response_body[authorization_endpoint_key] is None: + raise ValueError('Unable to discover authorization endpoint') + + if response_body[token_endpoint_key] is None: + raise ValueError('Unable to discover token endpoint') + + self._authorization_endpoints = AuthorizationEndpoints(auth_endpoint=response_body[authorization_endpoint_key], + token_endpoint=response_body[token_endpoint_key]) + return self._authorization_endpoints diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py new file mode 100644 index 0000000000..dd95e53bad --- /dev/null +++ b/flytekit/configuration/creds.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import + +from flytekit.configuration import common as _config_common + +DISCOVERY_ENDPOINT = _config_common.FlyteStringConfigurationEntry('credentials', 'discovery_endpoint', default=None) +""" +This endpoint fetches authorization server metadata as describe in this proposal: +https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. +The endpoint path can be relative or absolute. +""" + +CLIENT_ID = _config_common.FlyteStringConfigurationEntry('credentials', 'client_id', default=None) +""" +This is the public identifier for the app which handles authorization for a Flyte deployment. +More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. +""" + +REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', default=None) +""" +This is the redirect uri registered with the app which handles authorization for a Flyte deployment. +More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. +""" From c7eba16ef3d149d2f4afd5bea148b5d0af832193 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 23 Oct 2019 13:53:38 -0700 Subject: [PATCH 03/40] pass authentication metadata in gRPC calls --- flytekit/clients/raw.py | 57 ++++++++++--------- flytekit/clis/auth/access.py | 34 ----------- flytekit/clis/auth/auth.py | 90 ++++++++++++++---------------- flytekit/clis/auth/credentials.py | 35 ++++++++++++ flytekit/clis/auth/discovery.py | 42 ++++++++------ flytekit/clis/flyte_cli/main.py | 62 +++++++++++++------- flytekit/configuration/creds.py | 8 +++ flytekit/configuration/platform.py | 4 ++ 8 files changed, 188 insertions(+), 144 deletions(-) delete mode 100644 flytekit/clis/auth/access.py create mode 100644 flytekit/clis/auth/credentials.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 2babcf8072..48a29daa10 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -26,7 +26,7 @@ class RawSynchronousFlyteClient(object): be explicit as opposed to inferred from the environment or a configuration file. """ - def __init__(self, url, insecure=False, credentials=None, options=None): + def __init__(self, url, insecure=False, credentials=None, options=None, metadata=None): """ Initializes a gRPC channel to the given Flyte Admin service. @@ -35,6 +35,8 @@ def __init__(self, url, insecure=False, credentials=None, options=None): :param Text credentials: [Optional] If provided, a secure channel will be opened with the Flyte Admin Service. :param dict[Text, Text] options: [Optional] A dict of key-value string pairs for configuring the gRPC core runtime. + :param [(Text, Text)] metadata: [Optional] metadata pairs to be transmitted to the + service-side of the RPC. """ self._channel = None @@ -48,6 +50,7 @@ def __init__(self, url, insecure=False, credentials=None, options=None): options=list((options or {}).items()) ) self._stub = _admin_service.AdminServiceStub(self._channel) + self._metadata = metadata #################################################################################################################### # @@ -74,7 +77,7 @@ def create_task(self, task_create_request): task is already registered. :raises grpc.RpcError: """ - return self._stub.CreateTask(task_create_request) + return self._stub.CreateTask(task_create_request, metadata=self._metadata) @_handle_rpc_error def list_task_ids_paginated(self, identifier_list_request): @@ -100,7 +103,7 @@ def list_task_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ - return self._stub.ListTaskIds(identifier_list_request) + return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_tasks_paginated(self, resource_list_request): @@ -122,7 +125,7 @@ def list_tasks_paginated(self, resource_list_request): :rtype: flyteidl.admin.task_pb2.TaskList :raises: TODO """ - return self._stub.ListTasks(resource_list_request) + return self._stub.ListTasks(resource_list_request, metadata=self._metadata) @_handle_rpc_error def get_task(self, get_object_request): @@ -133,7 +136,7 @@ def get_task(self, get_object_request): :rtype: flyteidl.admin.task_pb2.Task :raises: TODO """ - return self._stub.GetTask(get_object_request) + return self._stub.GetTask(get_object_request, metadata=self._metadata) #################################################################################################################### # @@ -160,7 +163,7 @@ def create_workflow(self, workflow_create_request): identical workflow is already registered. :raises grpc.RpcError: """ - return self._stub.CreateWorkflow(workflow_create_request) + return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata) @_handle_rpc_error def list_workflow_ids_paginated(self, identifier_list_request): @@ -186,7 +189,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ - return self._stub.ListWorkflowIds(identifier_list_request) + return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_workflows_paginated(self, resource_list_request): @@ -208,7 +211,7 @@ def list_workflows_paginated(self, resource_list_request): :rtype: flyteidl.admin.workflow_pb2.WorkflowList :raises: TODO """ - return self._stub.ListWorkflows(resource_list_request) + return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata) @_handle_rpc_error def get_workflow(self, get_object_request): @@ -219,7 +222,7 @@ def get_workflow(self, get_object_request): :rtype: flyteidl.admin.workflow_pb2.Workflow :raises: TODO """ - return self._stub.GetWorkflow(get_object_request) + return self._stub.GetWorkflow(get_object_request, metadata=self._metadata) #################################################################################################################### # @@ -247,7 +250,7 @@ def create_launch_plan(self, launch_plan_create_request): the identical launch plan is already registered. :raises grpc.RpcError: """ - return self._stub.CreateLaunchPlan(launch_plan_create_request) + return self._stub.CreateLaunchPlan(launch_plan_create_request, metadata=self._metadata) # TODO: List endpoints when they come in @@ -259,7 +262,7 @@ def get_launch_plan(self, object_get_request): :param flyteidl.admin.common_pb2.ObjectGetRequest object_get_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan """ - return self._stub.GetLaunchPlan(object_get_request) + return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata) @_handle_rpc_error def get_active_launch_plan(self, active_launch_plan_request): @@ -269,7 +272,7 @@ def get_active_launch_plan(self, active_launch_plan_request): :param flyteidl.admin.common_pb2.ActiveLaunchPlanRequest active_launch_plan_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan """ - return self._stub.GetActiveLaunchPlan(active_launch_plan_request) + return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata) @_handle_rpc_error def update_launch_plan(self, update_request): @@ -280,7 +283,7 @@ def update_launch_plan(self, update_request): :param flyteidl.admin.launch_plan_pb2.LaunchPlanUpdateRequest update_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanUpdateResponse """ - return self._stub.UpdateLaunchPlan(update_request) + return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata) @_handle_rpc_error def list_launch_plan_ids_paginated(self, identifier_list_request): @@ -290,7 +293,7 @@ def list_launch_plan_ids_paginated(self, identifier_list_request): :param: flyteidl.admin.common_pb2.NamedEntityIdentifierListRequest identifier_list_request: :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList """ - return self._stub.ListLaunchPlanIds(identifier_list_request) + return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_launch_plans_paginated(self, resource_list_request): @@ -300,7 +303,7 @@ def list_launch_plans_paginated(self, resource_list_request): :param: flyteidl.admin.common_pb2.ResourceListRequest resource_list_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanList """ - return self._stub.ListLaunchPlans(resource_list_request) + return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata) @_handle_rpc_error def list_active_launch_plans_paginated(self, active_launch_plan_list_request): @@ -310,7 +313,7 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request): :param: flyteidl.admin.common_pb2.ActiveLaunchPlanListRequest active_launch_plan_list_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanList """ - return self._stub.ListActiveLaunchPlans(active_launch_plan_list_request) + return self._stub.ListActiveLaunchPlans(active_launch_plan_list_request, metadata=self._metadata) #################################################################################################################### # @@ -325,7 +328,7 @@ def create_execution(self, create_execution_request): :param flyteidl.admin.execution_pb2.ExecutionCreateRequest create_execution_request: :rtype: flyteidl.admin.execution_pb2.ExecutionCreateResponse """ - return self._stub.CreateExecution(create_execution_request) + return self._stub.CreateExecution(create_execution_request, metadata=self._metadata) @_handle_rpc_error def get_execution(self, get_object_request): @@ -335,7 +338,7 @@ def get_execution(self, get_object_request): :param flyteidl.admin.execution_pb2.WorkflowExecutionGetRequest get_object_request: :rtype: flyteidl.admin.execution_pb2.Execution """ - return self._stub.GetExecution(get_object_request) + return self._stub.GetExecution(get_object_request, metadata=self._metadata) @_handle_rpc_error def list_executions_paginated(self, resource_list_request): @@ -345,7 +348,7 @@ def list_executions_paginated(self, resource_list_request): :param flyteidl.admin.common_pb2.ResourceListRequest resource_list_request: :rtype: flyteidl.admin.execution_pb2.ExecutionList """ - return self._stub.ListExecutions(resource_list_request) + return self._stub.ListExecutions(resource_list_request, metadata=self._metadata) @_handle_rpc_error def terminate_execution(self, terminate_execution_request): @@ -353,7 +356,7 @@ def terminate_execution(self, terminate_execution_request): :param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request: :rtype: flyteidl.admin.execution_pb2.TerminateExecutionResponse """ - return self._stub.TerminateExecution(terminate_execution_request) + return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata) #################################################################################################################### # @@ -366,21 +369,21 @@ def get_node_execution(self, node_execution_request): :param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecution """ - return self._stub.GetNodeExecution(node_execution_request) + return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata) def list_node_executions_paginated(self, node_execution_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionList """ - return self._stub.ListNodeExecutions(node_execution_list_request) + return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata) def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionList """ - return self._stub.ListNodeExecutionsForTask(node_execution_for_task_list_request) + return self._stub.ListNodeExecutionsForTask(node_execution_for_task_list_request, metadata=self._metadata) #################################################################################################################### # @@ -393,14 +396,14 @@ def get_task_execution(self, task_execution_request): :param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request: :rtype: flyteidl.admin.task_execution_pb2.TaskExecution """ - return self._stub.GetTaskExecution(task_execution_request) + return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata) def list_task_executions_paginated(self, task_execution_list_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request: :rtype: flyteidl.admin.task_execution_pb2.TaskExecutionList """ - return self._stub.ListTaskExecutions(task_execution_list_request) + return self._stub.ListTaskExecutions(task_execution_list_request, metadata=self._metadata) #################################################################################################################### # @@ -415,7 +418,7 @@ def list_projects(self, project_list_request): :param flyteidl.admin.project_pb2.ProjectListRequest project_list_request: :rtype: flyteidl.admin.project_pb2.Projects """ - return self._stub.ListProjects(project_list_request) + return self._stub.ListProjects(project_list_request, metadata=self._metadata) @_handle_rpc_error def register_project(self, project_register_request): @@ -424,7 +427,7 @@ def register_project(self, project_register_request): :param flyteidl.admin.project_pb2.ProjectRegisterRequest project_register_request: :rtype: flyteidl.admin.project_pb2.ProjectRegisterResponse """ - return self._stub.RegisterProject(project_register_request) + return self._stub.RegisterProject(project_register_request, metadata=self._metadata) #################################################################################################################### # diff --git a/flytekit/clis/auth/access.py b/flytekit/clis/auth/access.py deleted file mode 100644 index a30ab5b76d..0000000000 --- a/flytekit/clis/auth/access.py +++ /dev/null @@ -1,34 +0,0 @@ -from auth import AuthorizationClient -from discovery import DiscoveryClient - -from flytekit.configuration.platform import URL -from flytekit.configuration.creds import DISCOVERY_ENDPOINT, REDIRECT_URI, CLIENT_ID - - -try: # Python 3 - from urllib.parse import urlparse, urljoin -except ImportError: # Python 2 - from urlparse import urlparse, urljoin - - -def _is_absolute(url): - return bool(urlparse(url).netloc) - - -def get_credentials(): - discovery_endpoint = DISCOVERY_ENDPOINT.get() - if not _is_absolute(discovery_endpoint): - discovery_endpoint = urljoin(URL.get(), discovery_endpoint) - discovery_client = DiscoveryClient(discovery_url=discovery_endpoint) - authorization_endpoints = discovery_client.get_authorization_endpoints() - - client = AuthorizationClient(redirect_uri=REDIRECT_URI.get(), - client_id=CLIENT_ID.get(), - auth_endpoint=authorization_endpoints.auth_endpoint, - token_endpoint=authorization_endpoints.token_endpoint) - return client.credentials() - - -if __name__ == '__main__': - get_credentials() - diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index a356defc45..9c0b4fbffe 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -1,37 +1,38 @@ -import webbrowser -import base64 -import hashlib -import os -import re -import requests -from multiprocessing import Process, Queue +import base64 as _base64 +import hashlib as _hashlib +import os as _os +import re as _re +import requests as _requests +import webbrowser as _webbrowser + +from multiprocessing import Process as _Process, Queue as _Queue try: # Python 3.5+ - from http import HTTPStatus as StatusCodes + from http import HTTPStatus as _StatusCodes except ImportError: try: # Python 3 - from http import client as StatusCodes + from http import client as _StatusCodes except ImportError: # Python 2 - import httplib as StatusCodes -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler + import httplib as _StatusCodes +from BaseHTTPServer import HTTPServer as _HTTPServer, BaseHTTPRequestHandler as _BaseHTTPRequestHandler try: # Python 3 - from urllib.parse import urlparse, urljoin, parse_qsl + from urllib.parse import urlparse as _urlparse, parse_qsl as _parse_qsl except ImportError: # Python 2 - from urlparse import urlparse, urljoin, parse_qsl + from urlparse import urlparse as _urlparse, parse_qsl as _parse_qsl -code_verifier_length = 64 +_code_verifier_length = 64 -def generate_code_verifier(): +def _generate_code_verifier(): """ Generates a 'code_verifier' as described in section 4.1 of RFC 7636. Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. :return str: """ - code_verifier = base64.urlsafe_b64encode(os.urandom(code_verifier_length)).decode('utf-8') + code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode('utf-8') # Eliminate invalid characters. - code_verifier = re.sub('[^a-zA-Z0-9]+', '', code_verifier) + code_verifier = _re.sub('[^a-zA-Z0-9]+', '', code_verifier) if len(code_verifier) < 43: raise ValueError("Verifier too short. number of bytes must be > 30.") elif len(code_verifier) > 128: @@ -39,22 +40,21 @@ def generate_code_verifier(): return code_verifier -# TODO(katrogan): Figure out how random this needs to be -def generate_state_parameter(): - state = base64.urlsafe_b64encode(os.urandom(40)).decode('utf-8') +def _generate_state_parameter(): + state = _base64.urlsafe_b64encode(_os.urandom(40)).decode('utf-8') # Eliminate invalid characters. - code_verifier = re.sub('[^a-zA-Z0-9-_.,]+', '', state) + code_verifier = _re.sub('[^a-zA-Z0-9-_.,]+', '', state) return code_verifier -def create_code_challenge(code_verifier): +def _create_code_challenge(code_verifier): """ Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. :param str code_verifier: represents a code verifier generated by generate_code_verifier() :return str: urlsafe base64-encoded sha256 hash digest """ - code_challenge = hashlib.sha256(code_verifier.encode('utf-8')).digest() - code_challenge = base64.urlsafe_b64encode(code_challenge).decode('utf-8') + code_challenge = _hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = _base64.urlsafe_b64encode(code_challenge).decode('utf-8') # Eliminate invalid characters code_challenge = code_challenge.replace('=', '') return code_challenge @@ -73,22 +73,19 @@ def code(self): def state(self): return self._state - def __repr__(self): - return "[{}, {}]".format(self.code, self.state) - -class OAuthCallbackHandler(BaseHTTPRequestHandler): +class OAuthCallbackHandler(_BaseHTTPRequestHandler): """ A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an authorization token. """ def do_GET(self): - url = urlparse(self.path) + url = _urlparse(self.path) if url.path == self.server.redirect_path: - self.send_response(StatusCodes.OK) + self.send_response(_StatusCodes.OK) self.end_headers() - self.handle_login(dict(parse_qsl(url.query))) + self.handle_login(dict(_parse_qsl(url.query))) else: self.send_response(404) @@ -96,14 +93,14 @@ def handle_login(self, data): self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) -class OAuthHTTPServer(HTTPServer): +class OAuthHTTPServer(_HTTPServer): """ A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling authorization code callbacks. """ def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, redirect_path=None, queue=None): - HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) + _HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) self._redirect_path = redirect_path self._auth_code = None self._queue = queue @@ -141,15 +138,14 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._token_endpoint = token_endpoint self._client_id = client_id self._redirect_uri = redirect_uri - self._code_verifier = generate_code_verifier() - code_challenge = create_code_challenge(self._code_verifier) + self._code_verifier = _generate_code_verifier() + code_challenge = _create_code_challenge(self._code_verifier) self._code_challenge = code_challenge - state = generate_state_parameter() + state = _generate_state_parameter() self._state = state self._credentials = None self._params = { - # TODO: need an audience param here? "client_id": client_id, # This must match the Client ID of the OAuth application. "response_type": "code", # Indicates the authorization code grant "scope": "openid", # ensures that the /token endpoint returns an ID token @@ -166,11 +162,11 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._start_callback_server() def _start_callback_server(self): - server_url = urlparse(self._redirect_uri) + server_url = _urlparse(self._redirect_uri) server_address = (server_url.hostname, server_url.port) - q = Queue() + q = _Queue() server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) - server_process = Process(target=server.handle_request) + server_process = _Process(target=server.handle_request) server_process.start() auth_code = q.get() @@ -178,18 +174,17 @@ def _start_callback_server(self): self.request_access_token(auth_code) def _request_authorization_code(self): - # Spin up a background local http server to receive the callback request containing the authorization code - resp = requests.get( + resp = _requests.get( url=self._auth_endpoint, params=self._params, allow_redirects=False ) - if resp.status_code == StatusCodes.FOUND: + if resp.status_code == _StatusCodes.FOUND: # Follow the redirect redirect_location = resp.headers['Location'] if redirect_location is None: raise ValueError('Received a 302 but no follow up location was provided in headers') - webbrowser.open_new_tab(redirect_location) + _webbrowser.open_new_tab(redirect_location) def request_access_token(self, auth_code): if self._state != auth_code.state: @@ -199,13 +194,13 @@ def request_access_token(self, auth_code): "code_verifier": self._code_verifier, "grant_type": "authorization_code", }) - resp = requests.post( + resp = _requests.post( url=self._token_endpoint, data=self._params, headers={'content-type': "application/x-www-form-urlencoded"}, allow_redirects=False ) - if resp.status_code != StatusCodes.OK: + if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) error cases: # https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses raise Exception('Failed to request access token with response: [{}] {}'.format( @@ -226,8 +221,9 @@ def request_access_token(self, auth_code): self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + @property def credentials(self): """ - :return flytekit.clis.auth.Credentials: + :return flytekit.clis.auth.auth.Credentials: """ return self._credentials diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py new file mode 100644 index 0000000000..93ab895633 --- /dev/null +++ b/flytekit/clis/auth/credentials.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import + +from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient +from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient + +from flytekit.configuration.creds import ( + DISCOVERY_ENDPOINT as _DISCOVERY_ENDPOINT, + REDIRECT_URI as _REDIRECT_URI, + CLIENT_ID as _CLIENT_ID +) +from flytekit.configuration.platform import URL as _URL + + +try: # Python 3 + from urllib.parse import urlparse, urljoin +except ImportError: # Python 2 + from urlparse import urlparse, urljoin + + +def _is_absolute(url): + return bool(urlparse(url).netloc) + + +def get_credentials(): + discovery_endpoint = _DISCOVERY_ENDPOINT.get() + if not _is_absolute(discovery_endpoint): + discovery_endpoint = urljoin(_URL.get(), discovery_endpoint) + discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) + authorization_endpoints = discovery_client.get_authorization_endpoints() + + client = _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), + client_id=_CLIENT_ID.get(), + auth_endpoint=authorization_endpoints.auth_endpoint, + token_endpoint=authorization_endpoints.token_endpoint) + return client.credentials diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index 286b3e965a..68fb92512d 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -1,16 +1,16 @@ -import requests +import requests as _requests try: # Python 3.5+ - from http import HTTPStatus as StatusCodes + from http import HTTPStatus as _StatusCodes except ImportError: try: # Python 3 - from http import client as StatusCodes + from http import client as _StatusCodes except ImportError: # Python 2 - import httplib as StatusCodes + import httplib as _StatusCodes # These response keys are defined in https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. -authorization_endpoint_key = "authorization_endpoint" -token_endpoint_key = "token_endpoint" +_authorization_endpoint_key = "authorization_endpoint" +_token_endpoint_key = "token_endpoint" class AuthorizationEndpoints(object): @@ -32,35 +32,43 @@ def token_endpoint(self): class DiscoveryClient(object): """ - Discovers + Discovers well known OpenID configuration and parses out authorization endpoints required for initiating the PKCE + auth flow. """ def __init__(self, discovery_url=None): self._discovery_url = discovery_url self._authorization_endpoints = None + @property + def authorization_endpoints(self): + """ + :rtype: flytekit.clis.auth.discovery.AuthorizationEndpoints: + """ + return self._authorization_endpoints + def get_authorization_endpoints(self): - if self._authorization_endpoints is not None: - return self._authorization_endpoints - resp = requests.get( + if self.authorization_endpoints is not None: + return self.authorization_endpoints + resp = _requests.get( url=self._discovery_url, ) # Follow at most one redirect. - if resp.status_code == StatusCodes.FOUND: + if resp.status_code == _StatusCodes.FOUND: redirect_location = resp.headers['Location'] if redirect_location is None: raise ValueError('Received a 302 but no follow up location was provided in headers') - resp = requests.get( + resp = _requests.get( url=redirect_location, ) response_body = resp.json() - if response_body[authorization_endpoint_key] is None: + if response_body[_authorization_endpoint_key] is None: raise ValueError('Unable to discover authorization endpoint') - if response_body[token_endpoint_key] is None: + if response_body[_token_endpoint_key] is None: raise ValueError('Unable to discover token endpoint') - self._authorization_endpoints = AuthorizationEndpoints(auth_endpoint=response_body[authorization_endpoint_key], - token_endpoint=response_body[token_endpoint_key]) - return self._authorization_endpoints + self._authorization_endpoints = AuthorizationEndpoints(auth_endpoint=response_body[_authorization_endpoint_key], + token_endpoint=response_body[_token_endpoint_key]) + return self.authorization_endpoints diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index cbe796f0a4..9170fee4d6 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import importlib as _importlib +import keyring as _keyring import os as _os import sys as _sys import stat as _stat @@ -11,6 +12,7 @@ from flytekit import __version__ from flytekit.clients import friendly as _friendly_client +from flytekit.clis.auth import credentials as _credentials_access from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map, \ construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map, \ parse_args_into_dict as _parse_args_into_dict, str2bool as _str2bool @@ -18,7 +20,7 @@ from flytekit.common.core import identifier as _identifier from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file -from flytekit.configuration import platform as _platform_config +from flytekit.configuration import creds as _creds_config, platform as _platform_config from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import common as _common_models, filters as _filters, launch_plan as _launch_plan, literals as \ _literals @@ -30,6 +32,11 @@ _tt = _six.text_type +# Identifies the 'metadata' service used for storing passwords in keyring +_metadata = "metadata" +# Identifies the access token username for storing and fetching access token values in keyring. +_metadata_access_token = "access_token" + def _welcome_message(): _click.secho("Welcome to Flyte CLI! Version: {}".format(_tt(__version__)), bold=True) @@ -203,7 +210,7 @@ def _terminate_one_execution(client, urn, cause, shouldPrint=True): def _update_one_launch_plan(urn, host, insecure, state): - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) if state == "active": state = _launch_plan.LaunchPlanState.ACTIVE @@ -229,6 +236,23 @@ def _render_schedule_expr(lp): return "{:30}".format(sched_expr) +def _fetch_metadata(): + """ + Initializes gRPC metadata according to parameters set in the flyte config. + Currently this is used to pass security credentials when authentication is enabled. + :return [(Text, Text)]: metadata pairs to be transmitted to the service-side of the RPC. + """ + if not _platform_config.AUTH.get(): + # nothing to do + return None + access_token = _keyring.get_password(_metadata, _metadata_access_token) + if access_token is None: + credentials = _credentials_access.get_credentials() + _keyring.set_password(_metadata, _metadata_access_token, credentials.access_token) + access_token = credentials.access_token + return [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), access_token)] + + _HOST_URL_ENV = _os.environ.get(_platform_config.URL.env_var, None) _INSECURE_ENV = _os.environ.get(_platform_config.INSECURE.env_var, None) _PROJECT_FLAGS = ["-p", "--project"] @@ -518,7 +542,7 @@ def list_task_names(project, domain, host, insecure, token, limit, show_all, sor a specific project and domain. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Task Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -560,7 +584,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show versions of that particular task (identifiable by {Project, Domain, Name}). """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Task Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) _click.echo("{:50} {:40}".format('Version', 'Urn')) @@ -599,7 +623,7 @@ def get_task(urn, host, insecure): The URN of the versioned task is in the form of ``tsk::::``. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) t = client.get_task(_identifier.Identifier.from_python_std(urn)) _click.echo(_tt(t)) _click.echo("") @@ -625,7 +649,7 @@ def list_workflow_names(project, domain, host, insecure, token, limit, show_all, List the names of the workflows under a scope specified by ``{project, domain}``. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Workflow Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -667,7 +691,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, versions of that particular workflow (identifiable by ``{project, domain, name}``). """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Workflow Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) _click.echo("{:50} {:40}".format('Version', 'Urn')) @@ -706,7 +730,7 @@ def get_workflow(urn, host, insecure): ``wf::::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo(client.get_workflow(_identifier.Identifier.from_python_std(urn))) # TODO: Print workflow pretty _click.echo("") @@ -732,7 +756,7 @@ def list_launch_plan_names(project, domain, host, insecure, token, limit, show_a List the names of the launch plans under the scope specified by {project, domain}. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Launch Plan Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -776,7 +800,7 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show _click.echo("Active Launch Plan Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:30} {:50} {:80}".format('Schedule', 'Version', 'Urn')) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) while True: active_lps, next_token = client.list_active_launch_plans_paginated( @@ -836,7 +860,7 @@ def list_launch_plan_versions(project, domain, name, host, insecure, token, limi _click.echo("Launch Plan Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) _click.echo("{:50} {:80} {:30} {:15}".format('Version', 'Urn', "Schedule", "Schedule State")) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) while True: lp_list, next_token = client.list_launch_plans_paginated( @@ -894,7 +918,7 @@ def get_launch_plan(urn, host, insecure): The URN of a launch plan is in the form of ``lp::::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo(_tt(client.get_launch_plan(_identifier.Identifier.from_python_std(urn)))) # TODO: Print launch plan pretty _click.echo("") @@ -911,7 +935,7 @@ def get_active_launch_plan(project, domain, name, host, insecure): List the versions of all the launch plans under the scope specified by {project, domain}. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) lp = client.get_active_launch_plan( _common_models.NamedEntityIdentifier( @@ -1030,7 +1054,7 @@ def relaunch_execution(project, domain, name, host, insecure, urn, principal, ve Users should use the get-execution and get-launch-plan commands to ascertain the names of inputs to use. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Relaunching execution {}\n".format(_tt(urn))) existing_workflow_execution_identifier = _identifier.WorkflowExecutionIdentifier.from_python_std(urn) @@ -1107,7 +1131,7 @@ def terminate_execution(host, insecure, cause, urn=None): -u lp:flyteexamples:development:some-execution:abc123 """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) _click.echo("Killing the following executions:\n") _click.echo("{:100} {:40}".format("Urn", "Cause")) @@ -1159,7 +1183,7 @@ def list_executions(project, domain, host, insecure, token, limit, show_all, fil _click.echo("Executions Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:100} {:40} {:10}".format("Urn", "Name", "Status")) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) while True: exec_ids, next_token = client.list_executions_paginated( @@ -1404,7 +1428,7 @@ def get_execution(urn, host, insecure, show_io, verbose): The URN of an execution is in the form of ``ex:::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) e = client.get_execution(_identifier.WorkflowExecutionIdentifier.from_python_std(urn)) node_execs = _get_all_node_executions(client, workflow_execution_identifier=e.id) _render_node_executions(client, node_execs, show_io, verbose, host, insecure, wf_execution=e) @@ -1418,7 +1442,7 @@ def get_execution(urn, host, insecure, show_io, verbose): @_verbose_option def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) node_execs = _get_all_node_executions( client, task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn) @@ -1437,7 +1461,7 @@ def register_project(identifier, name, host, insecure): """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) client.register_project(_Project(identifier, name)) _click.echo("Registered project [id: {}, name: {}]".format(identifier, name)) diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index dd95e53bad..29137bb046 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -20,3 +20,11 @@ This is the redirect uri registered with the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. """ + + +AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry('credentials', 'authorization_metadata_key', + default="authorization") +""" +The authorization metadata key used for passing access tokens in gRPC requests. +Traditionally this value is 'authorization' however it is made configurable. +""" diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index 06440dcfd8..0f922f0e17 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -4,3 +4,7 @@ URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url') INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False) +AUTH = _config_common.FlyteBoolConfigurationEntry('platform', 'auth', default=False) +""" +Whether to use auth when communicating with the Flyte platform. +""" From 1bda5da6965eb6e97664332f8f213454a1d7a736 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 23 Oct 2019 14:49:09 -0700 Subject: [PATCH 04/40] add credential type, misc clean-up --- flytekit/clis/auth/auth.py | 4 ++-- flytekit/clis/flyte_cli/main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 9c0b4fbffe..980608a9c0 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -87,7 +87,7 @@ def do_GET(self): self.end_headers() self.handle_login(dict(_parse_qsl(url.query))) else: - self.send_response(404) + self.send_response(_StatusCodes.NOT_FOUND) def handle_login(self, data): self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) @@ -156,7 +156,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi "code_challenge_method": "S256", } - # Initialize token request flow + # Initiate token request flow self._request_authorization_code() # Start a server to handle the callback url. self._start_callback_server() diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 9170fee4d6..045854d6de 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -250,7 +250,7 @@ def _fetch_metadata(): credentials = _credentials_access.get_credentials() _keyring.set_password(_metadata, _metadata_access_token, credentials.access_token) access_token = credentials.access_token - return [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), access_token)] + return [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] _HOST_URL_ENV = _os.environ.get(_platform_config.URL.env_var, None) From 2a350c7ff0ba5899e70003f59301472786225552 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 13 Nov 2019 10:34:45 -0800 Subject: [PATCH 05/40] removing a redirect follow, comments, adding a sample default discovery endpoint --- flytekit/clis/auth/auth.py | 5 ++++- flytekit/clis/auth/discovery.py | 8 -------- flytekit/configuration/creds.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index d869ceb4fe..4ecaf26306 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -14,7 +14,10 @@ from http import client as _StatusCodes except ImportError: # Python 2 import httplib as _StatusCodes -import BaseHTTPServer as _BaseHTTPServer +try: # Python 3 + import http.server as _BaseHTTPServer +except ImportError: # Python 2 + import BaseHTTPServer as _BaseHTTPServer try: # Python 3 import urllib.parse as _urlparse diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index 68fb92512d..14fec8e6b9 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -53,14 +53,6 @@ def get_authorization_endpoints(self): resp = _requests.get( url=self._discovery_url, ) - # Follow at most one redirect. - if resp.status_code == _StatusCodes.FOUND: - redirect_location = resp.headers['Location'] - if redirect_location is None: - raise ValueError('Received a 302 but no follow up location was provided in headers') - resp = _requests.get( - url=redirect_location, - ) response_body = resp.json() if response_body[_authorization_endpoint_key] is None: diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 29137bb046..7973197de7 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -2,10 +2,10 @@ from flytekit.configuration import common as _config_common -DISCOVERY_ENDPOINT = _config_common.FlyteStringConfigurationEntry('credentials', 'discovery_endpoint', default=None) +DISCOVERY_ENDPOINT = _config_common.FlyteStringConfigurationEntry('credentials', 'discovery_endpoint', default='https://company.idp.com/.well-known/oauth-authorization-server') """ -This endpoint fetches authorization server metadata as describe in this proposal: -https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. +This endpoint fetches authorization server metadata as described in: +https://tools.ietf.org/html/rfc8414 The endpoint path can be relative or absolute. """ @@ -15,13 +15,16 @@ More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. """ -REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', default=None) +REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', default="http://localhost:53593/callback") """ -This is the redirect uri registered with the app which handles authorization for a Flyte deployment. +This is the callback uri registered with the app which handles authorization for a Flyte deployment. +Please note the hardcoded port number. Ideally we would not do this, but some IDPs do not allow wildcards for +the URL, which means we have to use the same port every time. This is the only reason this is a configuration option, +otherwise, we'd just hardcode the callback path as a constant. +FYI, to see if a given port is already in use, run `sudo lsof -i :` if on a Linux system. More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. """ - AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry('credentials', 'authorization_metadata_key', default="authorization") """ From 1cb1620cf038991df5150a4ea6e3425a553884b3 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 13 Nov 2019 14:48:40 -0800 Subject: [PATCH 06/40] checkpoint --- flytekit/clis/auth/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 4ecaf26306..85dbd6283d 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -167,7 +167,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._start_callback_server() def _start_callback_server(self): - server_url = _urlparse(self._redirect_uri) + server_url = _urlparse.urlparse(self._redirect_uri) server_address = (server_url.hostname, server_url.port) q = _Queue() server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) From 2b9d1785242ef51d868b7622ab603f724fb7173f Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 13 Nov 2019 15:03:39 -0800 Subject: [PATCH 07/40] proper sequence of events --- flytekit/clis/auth/auth.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 85dbd6283d..faed954821 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -162,22 +162,25 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi } # Initiate token request flow - self._request_authorization_code() - # Start a server to handle the callback url. - self._start_callback_server() - - def _start_callback_server(self): - server_url = _urlparse.urlparse(self._redirect_uri) - server_address = (server_url.hostname, server_url.port) q = _Queue() - server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) + # First prepare the callback server in the background + server = self._create_callback_server(q) server_process = _Process(target=server.handle_request) - server_process.start() + + # Send the call to request the authorization code + self._request_authorization_code() + + # Request the access token once the auth code has been received. auth_code = q.get() server_process.terminate() self.request_access_token(auth_code) + def _create_callback_server(self, q): + server_url = _urlparse.urlparse(self._redirect_uri) + server_address = (server_url.hostname, server_url.port) + return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) + def _request_authorization_code(self): resp = _requests.get( url=self._auth_endpoint, From 1cc9999167bc45b9e205906fe4a0454132f81b6a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 13 Nov 2019 15:04:37 -0800 Subject: [PATCH 08/40] adding config file read, sample config file, and a test that will be deleted, only using for debugging at the moment --- flytekit/clis/auth/auth.py | 2 +- flytekit/clis/flyte_cli/example.config | 11 +++++++++++ flytekit/clis/flyte_cli/main.py | 17 +++++++++++++++++ tests/flytekit/unit/cli/auth/test_auth.py | 21 +++++++++++++++++++++ 4 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 flytekit/clis/flyte_cli/example.config create mode 100644 tests/flytekit/unit/cli/auth/test_auth.py diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 4ecaf26306..85dbd6283d 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -167,7 +167,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._start_callback_server() def _start_callback_server(self): - server_url = _urlparse(self._redirect_uri) + server_url = _urlparse.urlparse(self._redirect_uri) server_address = (server_url.hostname, server_url.port) q = _Queue() server = OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) diff --git a/flytekit/clis/flyte_cli/example.config b/flytekit/clis/flyte_cli/example.config new file mode 100644 index 0000000000..cd0bc07223 --- /dev/null +++ b/flytekit/clis/flyte_cli/example.config @@ -0,0 +1,11 @@ +# This is an example of what a config file may look like. +# Place this in ~/.flyte/config and flyte-cli will pick it up automatically + +[platform] +url=flyte.company.com +auth=true + +[credentials] +discovery_endpoint=http://corp.idp.com/.well-known/oauth-authorization-server +client_id=123abc123 +redirect_uri=http://localhost:53593/callback diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 045854d6de..391f9ece13 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -8,6 +8,7 @@ import click as _click import six as _six + from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit import __version__ @@ -21,6 +22,7 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file from flytekit.configuration import creds as _creds_config, platform as _platform_config +from flytekit.configuration import set_flyte_config_file from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import common as _common_models, filters as _filters, launch_plan as _launch_plan, literals as \ _literals @@ -36,10 +38,25 @@ _metadata = "metadata" # Identifies the access token username for storing and fetching access token values in keyring. _metadata_access_token = "access_token" +# Similar to how kubectl has a config file in the users home directory, this Flyte CLI will also look for one. +# The format of this config file is the same as a workflow's config file, except that the relevant fields are different. +# Please see the example.config file +_default_config_file_path = ".flyte/config" def _welcome_message(): _click.secho("Welcome to Flyte CLI! Version: {}".format(_tt(__version__)), bold=True) + _detect_default_config_file() + + +def _detect_default_config_file(): + home = _os.path.expanduser("~") + config_file = _os.path.join(home, _default_config_file_path) + if home and _os.path.exists(config_file): + _click.secho("Using default config file at {}".format(_tt(config_file)), fg='blue') + set_flyte_config_file(config_file_path=config_file) + else: + _click.secho("Config file not found at default location, relying on environment variables instead", fg='blue') def _get_io_string(literal_map, verbose=False): diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py new file mode 100644 index 0000000000..fc1bfc647d --- /dev/null +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -0,0 +1,21 @@ +from flytekit.clis.auth.discovery import DiscoveryClient +from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient +from flytekit.configuration.creds import ( + DISCOVERY_ENDPOINT as _DISCOVERY_ENDPOINT, + REDIRECT_URI as _REDIRECT_URI, + CLIENT_ID as _CLIENT_ID +) +from flytekit.configuration.platform import URL as _URL + + +def test_discovery_client(): + discovery_endpoint = _DISCOVERY_ENDPOINT.get() + discovery_client = DiscoveryClient(discovery_url=discovery_endpoint) + authorization_endpoints = discovery_client.get_authorization_endpoints() + print("///////////////////////////////////////|||||||||||||||||||||||||||||||||||||||||") + print(authorization_endpoints.auth_endpoint) + print(authorization_endpoints.token_endpoint) + # client = _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), + # client_id=_CLIENT_ID.get(), + # auth_endpoint=authorization_endpoints.auth_endpoint, + # token_endpoint=authorization_endpoints.token_endpoint) From 1cbcf3a64461b790bfe304ceaef5bafa167e6c6e Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 13 Nov 2019 15:50:45 -0800 Subject: [PATCH 09/40] let ur browser handle redirects --- flytekit/clis/auth/auth.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index faed954821..132c1bf12b 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -19,6 +19,7 @@ except ImportError: # Python 2 import BaseHTTPServer as _BaseHTTPServer +import urllib as _urllib try: # Python 3 import urllib.parse as _urlparse except ImportError: # Python 2 @@ -182,17 +183,10 @@ def _create_callback_server(self, q): return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) def _request_authorization_code(self): - resp = _requests.get( - url=self._auth_endpoint, - params=self._params, - allow_redirects=False - ) - if resp.status_code == _StatusCodes.FOUND: - # Follow the redirect - redirect_location = resp.headers['Location'] - if redirect_location is None: - raise ValueError('Received a 302 but no follow up location was provided in headers') - _webbrowser.open_new_tab(redirect_location) + scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) + query = _urllib.urlencode(self._params) + endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) + _webbrowser.open_new_tab(endpoint) def request_access_token(self, auth_code): if self._state != auth_code.state: From 6ccb3272e75262b433c80c7f58cdd4b01bc9959d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 14 Nov 2019 09:48:04 -0800 Subject: [PATCH 10/40] pull out urlencode separately, doesn't work for python3 --- flytekit/clis/auth/auth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 132c1bf12b..a88db84d7c 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -19,11 +19,12 @@ except ImportError: # Python 2 import BaseHTTPServer as _BaseHTTPServer -import urllib as _urllib try: # Python 3 import urllib.parse as _urlparse + from _urlparse import urlencode as _urlencode except ImportError: # Python 2 import urlparse as _urlparse + from urllib import urlencode as _urlencode _code_verifier_length = 64 _random_seed_length = 40 @@ -184,7 +185,7 @@ def _create_callback_server(self, q): def _request_authorization_code(self): scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) - query = _urllib.urlencode(self._params) + query = _urlencode(self._params) endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) _webbrowser.open_new_tab(endpoint) From 58f97e9503c0df9bd856340911d668964f43eb41 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 14 Nov 2019 10:02:59 -0800 Subject: [PATCH 11/40] fixing python3 import, changing keyring var names --- flytekit/clis/auth/auth.py | 2 +- flytekit/clis/flyte_cli/main.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index a88db84d7c..cd18cd2e6f 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -21,7 +21,7 @@ try: # Python 3 import urllib.parse as _urlparse - from _urlparse import urlencode as _urlencode + from urllib.parse import urlencode as _urlencode except ImportError: # Python 2 import urlparse as _urlparse from urllib import urlencode as _urlencode diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 391f9ece13..d94702b60d 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -34,10 +34,11 @@ _tt = _six.text_type -# Identifies the 'metadata' service used for storing passwords in keyring -_metadata = "metadata" -# Identifies the access token username for storing and fetching access token values in keyring. -_metadata_access_token = "access_token" +# Identifies the service used for storing passwords in keyring +_keyring_service_name = "flytecli" +# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs +# suggest, we are storing a user's oidc. +_keyring_storage_key = "access_token" # Similar to how kubectl has a config file in the users home directory, this Flyte CLI will also look for one. # The format of this config file is the same as a workflow's config file, except that the relevant fields are different. # Please see the example.config file @@ -262,10 +263,10 @@ def _fetch_metadata(): if not _platform_config.AUTH.get(): # nothing to do return None - access_token = _keyring.get_password(_metadata, _metadata_access_token) + access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key) if access_token is None: credentials = _credentials_access.get_credentials() - _keyring.set_password(_metadata, _metadata_access_token, credentials.access_token) + _keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token) access_token = credentials.access_token return [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] From 986d49b9515fb3033532e16d7fd82ba72d1a8aee Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 15 Nov 2019 13:49:25 -0800 Subject: [PATCH 12/40] some super pythonic code for handling token refresh --- flytekit/clients/helpers.py | 24 ++++++++++++ flytekit/clients/raw.py | 44 ++++++++++++++++++++-- flytekit/clis/auth/auth.py | 54 +++++++++++++++++++-------- flytekit/clis/auth/credentials.py | 19 +++++++--- flytekit/clis/flyte_cli/main.py | 62 ++++++++++--------------------- 5 files changed, 136 insertions(+), 67 deletions(-) diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 783d801df0..bee5b46d8c 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,3 +1,12 @@ +import keyring as _keyring + +from flytekit.clis.auth import credentials as _credentials_access + +# Identifies the service used for storing passwords in keyring +_keyring_service_name = "flytecli" +# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs +# suggest, we are storing a user's oidc. +_keyring_storage_key = "access_token" def iterate_node_executions( @@ -75,3 +84,18 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte if not next_token: break token = next_token + + +# Fetches an existing authorization access token if it exists in keyring or sets if it's unassigned. +def get_global_access_token(): + access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key) + if access_token is None: + access_token = set_global_access_token() + return access_token + + +# Assigns and returns the authorization access token in keyring. +def set_global_access_token(): + credentials = _credentials_access.get_client().credentials + _keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token) + return credentials.access_token diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index d0adaca27f..dcda8bfba6 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,17 +1,46 @@ from __future__ import absolute_import + from grpc import insecure_channel as _insecure_channel, secure_channel as _secure_channel, RpcError as _RpcError, \ StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions import six as _six +from flytekit.configuration import creds as _creds_config, platform as _platform_config + +from flytekit.clis.auth import credentials as _credentials_access +from flytekit.clients.helpers import ( + get_global_access_token as _get_global_access_token, set_global_access_token as _set_global_access_token +) + + +def _try_three_times(fn): + def handler(*args, **kwargs): + attempt = 0 + while True: + try: + attempt += 1 + return fn(*args, **kwargs) + except Exception as e: + if attempt >= 3: + raise e + else: + print('retrying') + return handler def _handle_rpc_error(fn): + @_try_three_times def handler(*args, **kwargs): try: return fn(*args, **kwargs) except _RpcError as e: - if e.code() == _GrpcStatusCode.ALREADY_EXISTS: + if e.code() == _GrpcStatusCode.UNAUTHENTICATED: + _credentials_access.get_client().refresh_access_token() + _set_global_access_token() + flyte_client = args[0] + flyte_client.refresh_metadata() + raise + elif e.code() == _GrpcStatusCode.ALREADY_EXISTS: raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) else: raise @@ -25,8 +54,9 @@ class RawSynchronousFlyteClient(object): This client should be usable regardless of environment in which this is used. In other words, configurations should be explicit as opposed to inferred from the environment or a configuration file. """ + authentication_client = None - def __init__(self, url, insecure=False, credentials=None, options=None, metadata=None): + def __init__(self, url, insecure=False, credentials=None, options=None): """ Initializes a gRPC channel to the given Flyte Admin service. @@ -50,7 +80,15 @@ def __init__(self, url, insecure=False, credentials=None, options=None, metadata options=list((options or {}).items()) ) self._stub = _admin_service.AdminServiceStub(self._channel) - self._metadata = metadata + self._metadata = None + self.refresh_metadata() + + def refresh_metadata(self): + if not _platform_config.AUTH.get(): + # nothing to do + self._metadata = None + access_token = _get_global_access_token() + self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] #################################################################################################################### # diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index cd18cd2e6f..4d3dfedf6b 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -151,11 +151,13 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi state = _generate_state_parameter() self._state = state self._credentials = None + self._refresh_token = None + self._headers = {'content-type': "application/x-www-form-urlencoded"} self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. "response_type": "code", # Indicates the authorization code grant - "scope": "openid", # ensures that the /token endpoint returns an ID token + "scope": "openid offline_access", # ensures that the /token endpoint returns an ID and refresh token # callback location where the user-agent will be directed to. "redirect_uri": self._redirect_uri, "state": state, @@ -189,6 +191,25 @@ def _request_authorization_code(self): endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) _webbrowser.open_new_tab(endpoint) + def _initialize_credentials(self, auth_token_resp): + + """ + The auth_token_resp body is of the form: + { + "access_token": "foo", + "refresh_token": "bar", + "id_token": "baz", + "token_type": "Bearer" + } + """ + response_body = auth_token_resp.json() + if "access_token" not in response_body: + raise ValueError('Expected "access_token" in response from oauth server') + if "refresh_token" in response_body: + self._refresh_token = response_body["refresh_token"] + + self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + def request_access_token(self, auth_code): if self._state != auth_code.state: raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state)) @@ -200,7 +221,7 @@ def request_access_token(self, auth_code): resp = _requests.post( url=self._token_endpoint, data=self._params, - headers={'content-type': "application/x-www-form-urlencoded"}, + headers=self._headers, allow_redirects=False ) if resp.status_code != _StatusCodes.OK: @@ -208,21 +229,24 @@ def request_access_token(self, auth_code): # https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses raise Exception('Failed to request access token with response: [{}] {}'.format( resp.status_code, resp.content)) + self._initialize_credentials(resp) - """ - The response body is of the form: - { - "access_token": "foo", - "refresh_token": "bar", - "id_token": "baz", - "token_type": "Bearer" - } - """ - response_body = resp.json() - if "access_token" not in response_body: - raise ValueError('Expected "access_token" in response from oauth server') + def refresh_access_token(self): + if self._refresh_token is None: + raise ValueError("no refresh token available with which to refresh authorization credentials") - self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + resp = _requests.post( + url=self._token_endpoint, + data={'grant_type': 'refresh_token', + 'client_id': self._client_id, + 'refresh_token': self._refresh_token}, + headers=self._headers, + allow_redirects=False + ) + if resp.status_code != _StatusCodes.OK: + raise Exception('Failed to request access token with response: [{}] {}'.format( + resp.status_code, resp.content)) + self._initialize_credentials(resp) @property def credentials(self): diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 2009d3a73d..fcf4d1cf76 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -20,15 +20,22 @@ def _is_absolute(url): return bool(_urlparse.urlparse(url).netloc) -def get_credentials(): +# Lazy initialized authorization client singleton +_authorization_client = None + + +def get_client(): + global _authorization_client + if _authorization_client is not None: + return _authorization_client discovery_endpoint = _DISCOVERY_ENDPOINT.get() if not _is_absolute(discovery_endpoint): discovery_endpoint = _urlparse.urljoin(_URL.get(), discovery_endpoint) discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) authorization_endpoints = discovery_client.get_authorization_endpoints() - client = _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), - client_id=_CLIENT_ID.get(), - auth_endpoint=authorization_endpoints.auth_endpoint, - token_endpoint=authorization_endpoints.token_endpoint) - return client.credentials + _authorization_client =\ + _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(), + auth_endpoint=authorization_endpoints.auth_endpoint, + token_endpoint=authorization_endpoints.token_endpoint) + return _authorization_client diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index d94702b60d..6dc9b34c40 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -1,7 +1,6 @@ from __future__ import absolute_import import importlib as _importlib -import keyring as _keyring import os as _os import sys as _sys import stat as _stat @@ -13,7 +12,6 @@ from flytekit import __version__ from flytekit.clients import friendly as _friendly_client -from flytekit.clis.auth import credentials as _credentials_access from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map, \ construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map, \ parse_args_into_dict as _parse_args_into_dict, str2bool as _str2bool @@ -21,7 +19,7 @@ from flytekit.common.core import identifier as _identifier from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file -from flytekit.configuration import creds as _creds_config, platform as _platform_config +from flytekit.configuration import platform as _platform_config from flytekit.configuration import set_flyte_config_file from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import common as _common_models, filters as _filters, launch_plan as _launch_plan, literals as \ @@ -34,11 +32,6 @@ _tt = _six.text_type -# Identifies the service used for storing passwords in keyring -_keyring_service_name = "flytecli" -# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs -# suggest, we are storing a user's oidc. -_keyring_storage_key = "access_token" # Similar to how kubectl has a config file in the users home directory, this Flyte CLI will also look for one. # The format of this config file is the same as a workflow's config file, except that the relevant fields are different. # Please see the example.config file @@ -228,7 +221,7 @@ def _terminate_one_execution(client, urn, cause, shouldPrint=True): def _update_one_launch_plan(urn, host, insecure, state): - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) if state == "active": state = _launch_plan.LaunchPlanState.ACTIVE @@ -254,23 +247,6 @@ def _render_schedule_expr(lp): return "{:30}".format(sched_expr) -def _fetch_metadata(): - """ - Initializes gRPC metadata according to parameters set in the flyte config. - Currently this is used to pass security credentials when authentication is enabled. - :return [(Text, Text)]: metadata pairs to be transmitted to the service-side of the RPC. - """ - if not _platform_config.AUTH.get(): - # nothing to do - return None - access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key) - if access_token is None: - credentials = _credentials_access.get_credentials() - _keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token) - access_token = credentials.access_token - return [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] - - _HOST_URL_ENV = _os.environ.get(_platform_config.URL.env_var, None) _INSECURE_ENV = _os.environ.get(_platform_config.INSECURE.env_var, None) _PROJECT_FLAGS = ["-p", "--project"] @@ -560,7 +536,7 @@ def list_task_names(project, domain, host, insecure, token, limit, show_all, sor a specific project and domain. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Task Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -602,7 +578,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show versions of that particular task (identifiable by {Project, Domain, Name}). """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Task Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) _click.echo("{:50} {:40}".format('Version', 'Urn')) @@ -641,7 +617,7 @@ def get_task(urn, host, insecure): The URN of the versioned task is in the form of ``tsk::::``. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) t = client.get_task(_identifier.Identifier.from_python_std(urn)) _click.echo(_tt(t)) _click.echo("") @@ -667,7 +643,7 @@ def list_workflow_names(project, domain, host, insecure, token, limit, show_all, List the names of the workflows under a scope specified by ``{project, domain}``. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Workflow Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -709,7 +685,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, versions of that particular workflow (identifiable by ``{project, domain, name}``). """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Workflow Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) _click.echo("{:50} {:40}".format('Version', 'Urn')) @@ -748,7 +724,7 @@ def get_workflow(urn, host, insecure): ``wf::::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo(client.get_workflow(_identifier.Identifier.from_python_std(urn))) # TODO: Print workflow pretty _click.echo("") @@ -774,7 +750,7 @@ def list_launch_plan_names(project, domain, host, insecure, token, limit, show_a List the names of the launch plans under the scope specified by {project, domain}. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Launch Plan Names Found in {}:{}\n".format(_tt(project), _tt(domain))) while True: @@ -818,7 +794,7 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show _click.echo("Active Launch Plan Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:30} {:50} {:80}".format('Schedule', 'Version', 'Urn')) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) while True: active_lps, next_token = client.list_active_launch_plans_paginated( @@ -878,7 +854,7 @@ def list_launch_plan_versions(project, domain, name, host, insecure, token, limi _click.echo("Launch Plan Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) _click.echo("{:50} {:80} {:30} {:15}".format('Version', 'Urn', "Schedule", "Schedule State")) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) while True: lp_list, next_token = client.list_launch_plans_paginated( @@ -936,7 +912,7 @@ def get_launch_plan(urn, host, insecure): The URN of a launch plan is in the form of ``lp::::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo(_tt(client.get_launch_plan(_identifier.Identifier.from_python_std(urn)))) # TODO: Print launch plan pretty _click.echo("") @@ -953,7 +929,7 @@ def get_active_launch_plan(project, domain, name, host, insecure): List the versions of all the launch plans under the scope specified by {project, domain}. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) lp = client.get_active_launch_plan( _common_models.NamedEntityIdentifier( @@ -1072,7 +1048,7 @@ def relaunch_execution(project, domain, name, host, insecure, urn, principal, ve Users should use the get-execution and get-launch-plan commands to ascertain the names of inputs to use. """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Relaunching execution {}\n".format(_tt(urn))) existing_workflow_execution_identifier = _identifier.WorkflowExecutionIdentifier.from_python_std(urn) @@ -1149,7 +1125,7 @@ def terminate_execution(host, insecure, cause, urn=None): -u lp:flyteexamples:development:some-execution:abc123 """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) _click.echo("Killing the following executions:\n") _click.echo("{:100} {:40}".format("Urn", "Cause")) @@ -1201,7 +1177,7 @@ def list_executions(project, domain, host, insecure, token, limit, show_all, fil _click.echo("Executions Found in {}:{}\n".format(_tt(project), _tt(domain))) _click.echo("{:100} {:40} {:10}".format("Urn", "Name", "Status")) - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) while True: exec_ids, next_token = client.list_executions_paginated( @@ -1446,7 +1422,7 @@ def get_execution(urn, host, insecure, show_io, verbose): The URN of an execution is in the form of ``ex:::`` """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) e = client.get_execution(_identifier.WorkflowExecutionIdentifier.from_python_std(urn)) node_execs = _get_all_node_executions(client, workflow_execution_identifier=e.id) _render_node_executions(client, node_execs, show_io, verbose, host, insecure, wf_execution=e) @@ -1460,7 +1436,7 @@ def get_execution(urn, host, insecure, show_io, verbose): @_verbose_option def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) node_execs = _get_all_node_executions( client, task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn) @@ -1479,7 +1455,7 @@ def register_project(identifier, name, host, insecure): """ _welcome_message() - client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure, metadata=_fetch_metadata()) + client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) client.register_project(_Project(identifier, name)) _click.echo("Registered project [id: {}, name: {}]".format(identifier, name)) From ff842e26bf3dc8d388d12315024c113945d3a140 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 15 Nov 2019 16:41:32 -0800 Subject: [PATCH 13/40] use retry library --- flytekit/clients/raw.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index dcda8bfba6..c9ee3f7b7f 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -5,6 +5,7 @@ from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions import six as _six +import retry as _retry from flytekit.configuration import creds as _creds_config, platform as _platform_config from flytekit.clis.auth import credentials as _credentials_access @@ -13,23 +14,8 @@ ) -def _try_three_times(fn): - def handler(*args, **kwargs): - attempt = 0 - while True: - try: - attempt += 1 - return fn(*args, **kwargs) - except Exception as e: - if attempt >= 3: - raise e - else: - print('retrying') - return handler - - def _handle_rpc_error(fn): - @_try_three_times + @_retry.retry(tries=3) def handler(*args, **kwargs): try: return fn(*args, **kwargs) From 292841792d5ec8b97af36d2472e35fb9a2c3c855 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 22 Nov 2019 11:40:37 -0800 Subject: [PATCH 14/40] customizable rpc error callback fn --- flytekit/clients/raw.py | 110 ++++++++++++++++------------- flytekit/common/exceptions/user.py | 4 ++ 2 files changed, 66 insertions(+), 48 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index c9ee3f7b7f..027b1d6a37 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -5,7 +5,6 @@ from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions import six as _six -import retry as _retry from flytekit.configuration import creds as _creds_config, platform as _platform_config from flytekit.clis.auth import credentials as _credentials_access @@ -14,23 +13,38 @@ ) -def _handle_rpc_error(fn): - @_retry.retry(tries=3) - def handler(*args, **kwargs): - try: - return fn(*args, **kwargs) - except _RpcError as e: - if e.code() == _GrpcStatusCode.UNAUTHENTICATED: - _credentials_access.get_client().refresh_access_token() - _set_global_access_token() - flyte_client = args[0] - flyte_client.refresh_metadata() - raise - elif e.code() == _GrpcStatusCode.ALREADY_EXISTS: - raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) - else: - raise - return handler +def _refresh_credentials(flyte_client): + _credentials_access.get_client().refresh_access_token() + _set_global_access_token() + flyte_client.refresh_metadata() + + +def _handle_rpc_error(raw_cli_fn=None, callback_function=None): + def wrapper(raw_cli_fn): + def handler(*args, **kwargs): + retries = 3 + try: + for i in range(retries): + try: + return raw_cli_fn(*args, **kwargs) + except _RpcError as e: + if e.code() == _GrpcStatusCode.UNAUTHENTICATED: + if i == (retries - 1): + # Exit the loop and wrap the authentication error. + raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e)) + callback_function(args[0]) + else: + raise + except _RpcError as e: + if e.code() == _GrpcStatusCode.ALREADY_EXISTS: + raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) + else: + raise + return handler + if raw_cli_fn: + return wrapper(raw_cli_fn) + else: + return wrapper class RawSynchronousFlyteClient(object): @@ -82,7 +96,7 @@ def refresh_metadata(self): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error() def create_task(self, task_create_request): """ This will create a task definition in the Admin database. Once successful, the task object can be @@ -103,7 +117,7 @@ def create_task(self, task_create_request): """ return self._stub.CreateTask(task_create_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_task_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the tasks for a given project and domain. Filters can also be @@ -129,7 +143,7 @@ def list_task_ids_paginated(self, identifier_list_request): """ return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_tasks_paginated(self, resource_list_request): """ This returns a page of task metadata for tasks in a given project and domain. Optionally, @@ -151,7 +165,7 @@ def list_tasks_paginated(self, resource_list_request): """ return self._stub.ListTasks(resource_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_task(self, get_object_request): """ This returns a single task for a given identifier. @@ -168,7 +182,7 @@ def get_task(self, get_object_request): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def create_workflow(self, workflow_create_request): """ This will create a workflow definition in the Admin database. Once successful, the workflow object can be @@ -189,7 +203,7 @@ def create_workflow(self, workflow_create_request): """ return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_workflow_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the workflows for a given project and domain. Filters can also be @@ -215,7 +229,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): """ return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_workflows_paginated(self, resource_list_request): """ This returns a page of workflow meta-information for workflows in a given project and domain. Optionally, @@ -237,7 +251,7 @@ def list_workflows_paginated(self, resource_list_request): """ return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_workflow(self, get_object_request): """ This returns a single workflow for a given identifier. @@ -254,7 +268,7 @@ def get_workflow(self, get_object_request): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def create_launch_plan(self, launch_plan_create_request): """ This will create a launch plan definition in the Admin database. Once successful, the launch plan object can be @@ -278,7 +292,7 @@ def create_launch_plan(self, launch_plan_create_request): # TODO: List endpoints when they come in - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_launch_plan(self, object_get_request): """ Retrieves a launch plan entity. @@ -288,7 +302,7 @@ def get_launch_plan(self, object_get_request): """ return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_active_launch_plan(self, active_launch_plan_request): """ Retrieves a launch plan entity. @@ -298,7 +312,7 @@ def get_active_launch_plan(self, active_launch_plan_request): """ return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def update_launch_plan(self, update_request): """ Allows updates to a launch plan at a given identifier. Currently, a launch plan may only have it's state @@ -309,7 +323,7 @@ def update_launch_plan(self, update_request): """ return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_launch_plan_ids_paginated(self, identifier_list_request): """ Lists launch plan named identifiers for a given project and domain. @@ -319,7 +333,7 @@ def list_launch_plan_ids_paginated(self, identifier_list_request): """ return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_launch_plans_paginated(self, resource_list_request): """ Lists Launch Plans for a given Identifer (project, domain, name) @@ -329,7 +343,7 @@ def list_launch_plans_paginated(self, resource_list_request): """ return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_active_launch_plans_paginated(self, active_launch_plan_list_request): """ Lists Active Launch Plans for a given (project, domain) @@ -345,7 +359,7 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def create_execution(self, create_execution_request): """ This will create an execution for the given execution spec. @@ -354,7 +368,7 @@ def create_execution(self, create_execution_request): """ return self._stub.CreateExecution(create_execution_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_execution(self, get_object_request): """ Returns an execution of a workflow entity. @@ -364,7 +378,7 @@ def get_execution(self, get_object_request): """ return self._stub.GetExecution(get_object_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_execution_data(self, get_execution_data_request): """ Returns signed URLs to LiteralMap blobs for an execution's inputs and outputs (when available). @@ -374,7 +388,7 @@ def get_execution_data(self, get_execution_data_request): """ return self._stub.GetExecutionData(get_execution_data_request) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_executions_paginated(self, resource_list_request): """ Lists the executions for a given identifier. @@ -384,7 +398,7 @@ def list_executions_paginated(self, resource_list_request): """ return self._stub.ListExecutions(resource_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def terminate_execution(self, terminate_execution_request): """ :param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request: @@ -392,7 +406,7 @@ def terminate_execution(self, terminate_execution_request): """ return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def relaunch_execution(self, relaunch_execution_request): """ :param flyteidl.admin.execution_pb2.ExecutionRelaunchRequest relaunch_execution_request: @@ -406,7 +420,7 @@ def relaunch_execution(self, relaunch_execution_request): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_node_execution(self, node_execution_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request: @@ -414,7 +428,7 @@ def get_node_execution(self, node_execution_request): """ return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_node_execution_data(self, get_node_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available). @@ -424,7 +438,7 @@ def get_node_execution_data(self, get_node_execution_data_request): """ return self._stub.GetNodeExecutionData(get_node_execution_data_request) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_node_executions_paginated(self, node_execution_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request: @@ -432,7 +446,7 @@ def list_node_executions_paginated(self, node_execution_list_request): """ return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request: @@ -446,7 +460,7 @@ def list_node_executions_for_task_paginated(self, node_execution_for_task_list_r # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_task_execution(self, task_execution_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request: @@ -454,7 +468,7 @@ def get_task_execution(self, task_execution_request): """ return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def get_task_execution_data(self, get_task_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a task execution's inputs and outputs (when available). @@ -464,7 +478,7 @@ def get_task_execution_data(self, get_task_execution_data_request): """ return self._stub.GetTaskExecutionData(get_task_execution_data_request) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_task_executions_paginated(self, task_execution_list_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request: @@ -478,7 +492,7 @@ def list_task_executions_paginated(self, task_execution_list_request): # #################################################################################################################### - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def list_projects(self, project_list_request): """ This will return a list of the projects registered with the Flyte Admin Service @@ -487,7 +501,7 @@ def list_projects(self, project_list_request): """ return self._stub.ListProjects(project_list_request, metadata=self._metadata) - @_handle_rpc_error + @_handle_rpc_error(callback_function=_refresh_credentials) def register_project(self, project_register_request): """ Registers a project along with a set of domains. diff --git a/flytekit/common/exceptions/user.py b/flytekit/common/exceptions/user.py index d050c05667..d68cb1985f 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/common/exceptions/user.py @@ -72,3 +72,7 @@ class FlyteTimeout(FlyteAssertion): class FlyteRecoverableException(FlyteUserException, _Recoverable): _ERROR_CODE = "USER:Recoverable" + + +class FlyteAuthenticationException(FlyteAssertion): + _ERROR_CODE = "USER:AuthenticationError" From e50df71b51e06cb5e08c8c82c37783be8d30500a Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 25 Nov 2019 10:32:30 -0800 Subject: [PATCH 15/40] configurable refresh handlers --- flytekit/clients/raw.py | 123 +++++++++++++++++--------------- flytekit/configuration/creds.py | 10 +++ 2 files changed, 77 insertions(+), 56 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 027b1d6a37..1f0491d59f 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -13,38 +13,49 @@ ) -def _refresh_credentials(flyte_client): +def _refresh_credentials_standard(flyte_client): _credentials_access.get_client().refresh_access_token() _set_global_access_token() flyte_client.refresh_metadata() -def _handle_rpc_error(raw_cli_fn=None, callback_function=None): - def wrapper(raw_cli_fn): - def handler(*args, **kwargs): - retries = 3 - try: - for i in range(retries): - try: - return raw_cli_fn(*args, **kwargs) - except _RpcError as e: - if e.code() == _GrpcStatusCode.UNAUTHENTICATED: - if i == (retries - 1): - # Exit the loop and wrap the authentication error. - raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e)) - callback_function(args[0]) - else: - raise - except _RpcError as e: - if e.code() == _GrpcStatusCode.ALREADY_EXISTS: - raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) - else: - raise - return handler - if raw_cli_fn: - return wrapper(raw_cli_fn) +def _refresh_credentials_basic(flyte_client): + # TODO(wild-endeavor): fill me in + pass + + +def _get_refresh_handler(auth_mode): + if auth_mode == "standard": + return _refresh_credentials_standard + elif auth_mode == "basic": + return _refresh_credentials_basic else: - return wrapper + raise ValueError( + "Invalid auth mode [{}] specified. Please update the creds config to use a valid value".format(auth_mode)) + + +def _handle_rpc_error(fn): + def handler(*args, **kwargs): + retries = 3 + try: + for i in range(retries): + try: + return fn(*args, **kwargs) + except _RpcError as e: + if e.code() == _GrpcStatusCode.UNAUTHENTICATED: + if i == (retries - 1): + # Exit the loop and wrap the authentication error. + raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e)) + refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) + refresh_handler_fn(args[0]) + else: + raise + except _RpcError as e: + if e.code() == _GrpcStatusCode.ALREADY_EXISTS: + raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) + else: + raise + return handler class RawSynchronousFlyteClient(object): @@ -96,7 +107,7 @@ def refresh_metadata(self): # #################################################################################################################### - @_handle_rpc_error() + @_handle_rpc_error def create_task(self, task_create_request): """ This will create a task definition in the Admin database. Once successful, the task object can be @@ -117,7 +128,7 @@ def create_task(self, task_create_request): """ return self._stub.CreateTask(task_create_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_task_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the tasks for a given project and domain. Filters can also be @@ -143,7 +154,7 @@ def list_task_ids_paginated(self, identifier_list_request): """ return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_tasks_paginated(self, resource_list_request): """ This returns a page of task metadata for tasks in a given project and domain. Optionally, @@ -165,7 +176,7 @@ def list_tasks_paginated(self, resource_list_request): """ return self._stub.ListTasks(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_task(self, get_object_request): """ This returns a single task for a given identifier. @@ -182,7 +193,7 @@ def get_task(self, get_object_request): # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def create_workflow(self, workflow_create_request): """ This will create a workflow definition in the Admin database. Once successful, the workflow object can be @@ -203,7 +214,7 @@ def create_workflow(self, workflow_create_request): """ return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_workflow_ids_paginated(self, identifier_list_request): """ This returns a page of identifiers for the workflows for a given project and domain. Filters can also be @@ -229,7 +240,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): """ return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_workflows_paginated(self, resource_list_request): """ This returns a page of workflow meta-information for workflows in a given project and domain. Optionally, @@ -251,7 +262,7 @@ def list_workflows_paginated(self, resource_list_request): """ return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_workflow(self, get_object_request): """ This returns a single workflow for a given identifier. @@ -268,7 +279,7 @@ def get_workflow(self, get_object_request): # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def create_launch_plan(self, launch_plan_create_request): """ This will create a launch plan definition in the Admin database. Once successful, the launch plan object can be @@ -292,7 +303,7 @@ def create_launch_plan(self, launch_plan_create_request): # TODO: List endpoints when they come in - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_launch_plan(self, object_get_request): """ Retrieves a launch plan entity. @@ -302,7 +313,7 @@ def get_launch_plan(self, object_get_request): """ return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_active_launch_plan(self, active_launch_plan_request): """ Retrieves a launch plan entity. @@ -312,7 +323,7 @@ def get_active_launch_plan(self, active_launch_plan_request): """ return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def update_launch_plan(self, update_request): """ Allows updates to a launch plan at a given identifier. Currently, a launch plan may only have it's state @@ -323,7 +334,7 @@ def update_launch_plan(self, update_request): """ return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_launch_plan_ids_paginated(self, identifier_list_request): """ Lists launch plan named identifiers for a given project and domain. @@ -333,7 +344,7 @@ def list_launch_plan_ids_paginated(self, identifier_list_request): """ return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_launch_plans_paginated(self, resource_list_request): """ Lists Launch Plans for a given Identifer (project, domain, name) @@ -343,7 +354,7 @@ def list_launch_plans_paginated(self, resource_list_request): """ return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_active_launch_plans_paginated(self, active_launch_plan_list_request): """ Lists Active Launch Plans for a given (project, domain) @@ -359,7 +370,7 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request): # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def create_execution(self, create_execution_request): """ This will create an execution for the given execution spec. @@ -368,7 +379,7 @@ def create_execution(self, create_execution_request): """ return self._stub.CreateExecution(create_execution_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_execution(self, get_object_request): """ Returns an execution of a workflow entity. @@ -378,7 +389,7 @@ def get_execution(self, get_object_request): """ return self._stub.GetExecution(get_object_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_execution_data(self, get_execution_data_request): """ Returns signed URLs to LiteralMap blobs for an execution's inputs and outputs (when available). @@ -388,7 +399,7 @@ def get_execution_data(self, get_execution_data_request): """ return self._stub.GetExecutionData(get_execution_data_request) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_executions_paginated(self, resource_list_request): """ Lists the executions for a given identifier. @@ -398,7 +409,7 @@ def list_executions_paginated(self, resource_list_request): """ return self._stub.ListExecutions(resource_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def terminate_execution(self, terminate_execution_request): """ :param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request: @@ -406,7 +417,7 @@ def terminate_execution(self, terminate_execution_request): """ return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def relaunch_execution(self, relaunch_execution_request): """ :param flyteidl.admin.execution_pb2.ExecutionRelaunchRequest relaunch_execution_request: @@ -420,7 +431,7 @@ def relaunch_execution(self, relaunch_execution_request): # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_node_execution(self, node_execution_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request: @@ -428,7 +439,7 @@ def get_node_execution(self, node_execution_request): """ return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_node_execution_data(self, get_node_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available). @@ -438,7 +449,7 @@ def get_node_execution_data(self, get_node_execution_data_request): """ return self._stub.GetNodeExecutionData(get_node_execution_data_request) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_node_executions_paginated(self, node_execution_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request: @@ -446,7 +457,7 @@ def list_node_executions_paginated(self, node_execution_list_request): """ return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request): """ :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request: @@ -460,7 +471,7 @@ def list_node_executions_for_task_paginated(self, node_execution_for_task_list_r # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_task_execution(self, task_execution_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request: @@ -468,7 +479,7 @@ def get_task_execution(self, task_execution_request): """ return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def get_task_execution_data(self, get_task_execution_data_request): """ Returns signed URLs to LiteralMap blobs for a task execution's inputs and outputs (when available). @@ -478,7 +489,7 @@ def get_task_execution_data(self, get_task_execution_data_request): """ return self._stub.GetTaskExecutionData(get_task_execution_data_request) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_task_executions_paginated(self, task_execution_list_request): """ :param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request: @@ -492,7 +503,7 @@ def list_task_executions_paginated(self, task_execution_list_request): # #################################################################################################################### - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def list_projects(self, project_list_request): """ This will return a list of the projects registered with the Flyte Admin Service @@ -501,7 +512,7 @@ def list_projects(self, project_list_request): """ return self._stub.ListProjects(project_list_request, metadata=self._metadata) - @_handle_rpc_error(callback_function=_refresh_credentials) + @_handle_rpc_error def register_project(self, project_register_request): """ Registers a project along with a set of domains. diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 7973197de7..7f0d66bc9f 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -31,3 +31,13 @@ The authorization metadata key used for passing access tokens in gRPC requests. Traditionally this value is 'authorization' however it is made configurable. """ + +# TODO(katrogan) Make this an enum rather than a string config entry +AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") +""" +The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: +- 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials + access. +- 'basic' This uses cert-based auth in which the end user enters his/her username and password and public key encryption + is used to facilitate authentication. +""" From 41c6c2b992dcb28065a6a5e2eb53f9de22134592 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 25 Nov 2019 10:53:53 -0800 Subject: [PATCH 16/40] expose set_access_token --- flytekit/clients/raw.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 1f0491d59f..c1f4085ec7 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -94,12 +94,15 @@ def __init__(self, url, insecure=False, credentials=None, options=None): self._metadata = None self.refresh_metadata() + def set_access_token(self, access_token): + self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] + def refresh_metadata(self): if not _platform_config.AUTH.get(): # nothing to do self._metadata = None access_token = _get_global_access_token() - self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] + self.set_access_token(access_token) #################################################################################################################### # From fd1083eed046456618b69cfe25340728599404ff Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Mon, 25 Nov 2019 16:19:44 -0800 Subject: [PATCH 17/40] remove refresh_metadata() --- flytekit/clients/raw.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index c1f4085ec7..def4dd10f4 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -16,7 +16,12 @@ def _refresh_credentials_standard(flyte_client): _credentials_access.get_client().refresh_access_token() _set_global_access_token() - flyte_client.refresh_metadata() + + if not _platform_config.AUTH.get(): + # nothing to do + return + access_token = _get_global_access_token() + flyte_client.set_access_token(access_token) def _refresh_credentials_basic(flyte_client): @@ -92,18 +97,10 @@ def __init__(self, url, insecure=False, credentials=None, options=None): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None - self.refresh_metadata() def set_access_token(self, access_token): self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] - def refresh_metadata(self): - if not _platform_config.AUTH.get(): - # nothing to do - self._metadata = None - access_token = _get_global_access_token() - self.set_access_token(access_token) - #################################################################################################################### # # Task Endpoints From 33e15865bd2a80818b7cfd89a50b211048c93b77 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 25 Nov 2019 16:54:49 -0800 Subject: [PATCH 18/40] pyflyte component of auth (#52) This merges in the pyflyte components of the auth change into the main flytekit auth PR. Will continue to work on tests and such from that PR. --- flytekit/clients/raw.py | 24 +++++- flytekit/clis/auth/credentials.py | 8 ++ flytekit/clis/auth/discovery.py | 2 + flytekit/clis/sdk_in_container/basic_auth.py | 83 +++++++++++++++++++ flytekit/clis/sdk_in_container/register.py | 2 +- flytekit/configuration/creds.py | 26 +++++- flytekit/configuration/platform.py | 4 +- flytekit/engines/flyte/engine.py | 12 ++- tests/flytekit/unit/cli/auth/test_auth.py | 26 ++---- .../unit/cli/pyflyte/test_basic_auth.py | 33 ++++++++ 10 files changed, 189 insertions(+), 31 deletions(-) create mode 100644 flytekit/clis/sdk_in_container/basic_auth.py create mode 100644 tests/flytekit/unit/cli/pyflyte/test_basic_auth.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index def4dd10f4..ea33a3902e 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -4,6 +4,13 @@ StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration.creds import ( + CLIENT_ID as _CLIENT_ID, + CLIENT_CREDENTIALS_SECRET_LOCATION as _CREDENTIALS_SECRET_FILE, + CLIENT_CREDENTIALS_SCOPE as _SCOPE, +) +from flytekit.clis.sdk_in_container import basic_auth +import logging import six as _six from flytekit.configuration import creds as _creds_config, platform as _platform_config @@ -25,8 +32,14 @@ def _refresh_credentials_standard(flyte_client): def _refresh_credentials_basic(flyte_client): - # TODO(wild-endeavor): fill me in - pass + auth_endpoints = _credentials_access.get_authorization_endpoints() + token_endpoint = auth_endpoints.token_endpoint + client_secret = basic_auth.get_secret() + logging.debug('Basic authorization flow with client id {} scope {}', _CLIENT_ID.get(), _SCOPE.get()) + authorization_header = basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) + token, expires_in = basic_auth.get_token(token_endpoint, authorization_header, _SCOPE.get()) + logging.info('Retrieved new token, expires in {}'.format(expires_in)) + flyte_client.set_access_token(token) def _get_refresh_handler(auth_mode): @@ -41,7 +54,7 @@ def _get_refresh_handler(auth_mode): def _handle_rpc_error(fn): def handler(*args, **kwargs): - retries = 3 + retries = 2 try: for i in range(retries): try: @@ -70,7 +83,6 @@ class RawSynchronousFlyteClient(object): This client should be usable regardless of environment in which this is used. In other words, configurations should be explicit as opposed to inferred from the environment or a configuration file. """ - authentication_client = None def __init__(self, url, insecure=False, credentials=None, options=None): """ @@ -101,6 +113,10 @@ def __init__(self, url, insecure=False, credentials=None, options=None): def set_access_token(self, access_token): self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] + def force_auth_flow(self): + refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) + refresh_handler_fn(self) + #################################################################################################################### # # Task Endpoints diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index fcf4d1cf76..e0dde5cc76 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -39,3 +39,11 @@ def get_client(): auth_endpoint=authorization_endpoints.auth_endpoint, token_endpoint=authorization_endpoints.token_endpoint) return _authorization_client + + +def get_authorization_endpoints(): + discovery_endpoint = _DISCOVERY_ENDPOINT.get() + if not _is_absolute(discovery_endpoint): + discovery_endpoint = _urlparse.urljoin(_URL.get(), discovery_endpoint) + discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) + return discovery_client.get_authorization_endpoints() diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index 14fec8e6b9..fce6988da2 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -1,4 +1,5 @@ import requests as _requests +import logging try: # Python 3.5+ from http import HTTPStatus as _StatusCodes @@ -37,6 +38,7 @@ class DiscoveryClient(object): """ def __init__(self, discovery_url=None): + logging.debug("Initializing discovery client with {}".format(discovery_url)) self._discovery_url = discovery_url self._authorization_endpoints = None diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py new file mode 100644 index 0000000000..6f936df838 --- /dev/null +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import + +import base64 as _base64 +import logging as _logging + +import requests as _requests + +from flytekit.common.exceptions.base import FlyteException as _FlyteException +from flytekit.configuration.creds import ( + CLIENT_CREDENTIALS_SECRET_LOCATION as _CREDENTIALS_SECRET_FILE, + CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, +) + +_utf_8 = 'utf-8' + + +class FlyteAuthenticationException(_FlyteException): + _ERROR_CODE = "FlyteAuthenticationFailed" + + +def get_file_contents(location): + """ + This reads an input file, and returns the string contents, and should be used for reading credentials. + This function will also strip newlines. + + :param Text location: The file path holding the client id or secret + :rtype: Text + """ + with open(location, 'r') as f: + return f.read().replace('\n', '') + + +def get_secret(): + """ + This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION + config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. + :rtype: Text + """ + if _CREDENTIALS_SECRET_FILE.get(): + return get_file_contents(_CREDENTIALS_SECRET_FILE.get()) + elif _CREDENTIALS_SECRET.get(): + return _CREDENTIALS_SECRET.get() + raise FlyteAuthenticationException('No secret could be found in either {} or the {} env variable'.format( + _CREDENTIALS_SECRET_FILE.get(), _CREDENTIALS_SECRET.env_var)) + + +def get_basic_authorization_header(client_id, client_secret): + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. + :param Text client_id: + :param Text client_secret: + :rtype: Text + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format(str(_base64.b64encode(concated.encode(_utf_8)), _utf_8)) + + +def get_token(token_endpoint, authorization_header, scope): + """ + :param token_endpoint: + :param authorization_header: + :param scope: + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + 'Authorization': authorization_header, + 'Cache-Control': 'no-cache', + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded' + } + body = { + 'grant_type': 'client_credentials', + 'scope': scope, + } + response = _requests.post(token_endpoint, data=body, headers=headers) + if response.status_code != 200: + _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) + raise FlyteAuthenticationException('Non-200 received from IDP') + + response = response.json() + return response['access_token'], response['expires_in'] diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index b2e465c959..199405bca3 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -47,7 +47,7 @@ def register_tasks_only(project, domain, pkgs, test, version): @click.group('register') # --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead -@click.option('--pkgs', multiple=True, hidden=True) +@click.option('--pkgs', multiple=True) @click.option('--test', is_flag=True, help='Dry run, do not actually register with Admin') @click.pass_context def register(ctx, pkgs=None, test=None): diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 7f0d66bc9f..876fa727e9 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -15,7 +15,8 @@ More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. """ -REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', default="http://localhost:53593/callback") +REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', + default="http://localhost:12345/callback") """ This is the callback uri registered with the app which handles authorization for a Flyte deployment. Please note the hardcoded port number. Ideally we would not do this, but some IDPs do not allow wildcards for @@ -32,7 +33,28 @@ Traditionally this value is 'authorization' however it is made configurable. """ -# TODO(katrogan) Make this an enum rather than a string config entry +CLIENT_CREDENTIALS_SECRET_LOCATION = _config_common.FlyteStringConfigurationEntry( + 'credentials', 'client_secret_location', default=None) +""" +Used for basic auth, which is automatically called during pyflyte. This is the location to look for the password. The +client id config setting is shared across the basic and standard auth flows. +""" + +CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry('credentials', 'client_secret', default=None) +""" +Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the +password directly from the environment variable. Note that this is less secure! Please only use this if mounting the +secret as a file is impossible. +""" + + +CLIENT_CREDENTIALS_SCOPE = _config_common.FlyteStringConfigurationEntry('credentials', 'scope', default=None) +""" +Used for basic auth, which is automatically called during pyflyte. This is the scope that will be requested. Because +there is no user explicitly in this auth flow, certain IDPs require a custom scope for basic auth in the configuration +of the authorization server. +""" + AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") """ The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index ffa652a61d..5ea0bd7185 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -12,5 +12,7 @@ AUTH = _config_common.FlyteBoolConfigurationEntry('platform', 'auth', default=False) """ -Whether to use auth when communicating with the Flyte platform. +This config setting should not normally be filled in. Whether or not an admin server requires authentication should be +something published by the admin server itself (typically by returning a 401). However, to help with migration, this +config object is here to force the SDK to attempt the auth flow even without prompting by Admin. """ diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 53bf34e708..a326274b6a 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -6,11 +6,12 @@ from datetime import datetime as _datetime import six as _six +from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions, iterate_task_executions as \ - _iterate_task_executions from flytekit import __version__ as _api_version from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient +from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions, iterate_task_executions as \ + _iterate_task_executions from flytekit.common import utils as _common_utils, constants as _constants from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes from flytekit.configuration import platform as _platform_config, internal as _internal_config, sdk as _sdk_config @@ -21,7 +22,6 @@ literals as _literals, common as _common_models from flytekit.models.admin import workflow as _workflow_model from flytekit.models.core import errors as _error_models, identifier as _identifier -from flyteidl.core import literals_pb2 as _literals_pb2 class _FlyteClientManager(object): @@ -32,7 +32,11 @@ def __init__(self, *args, **kwargs): # TODO: React to changing configs. For now this is frozen for the lifetime of the process, which covers most # TODO: use cases. if type(self)._CLIENT is None: - type(self)._CLIENT = _SynchronousFlyteClient(*args, **kwargs) + c = _SynchronousFlyteClient(*args, **kwargs) + if _platform_config.AUTH.get(): + # Force authentication + c.force_auth_flow() + type(self)._CLIENT = c @property def client(self): diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py index fc1bfc647d..178eec1115 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -1,21 +1,9 @@ -from flytekit.clis.auth.discovery import DiscoveryClient -from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient -from flytekit.configuration.creds import ( - DISCOVERY_ENDPOINT as _DISCOVERY_ENDPOINT, - REDIRECT_URI as _REDIRECT_URI, - CLIENT_ID as _CLIENT_ID -) -from flytekit.configuration.platform import URL as _URL +from __future__ import absolute_import +from flytekit.clis.auth import auth as _auth -def test_discovery_client(): - discovery_endpoint = _DISCOVERY_ENDPOINT.get() - discovery_client = DiscoveryClient(discovery_url=discovery_endpoint) - authorization_endpoints = discovery_client.get_authorization_endpoints() - print("///////////////////////////////////////|||||||||||||||||||||||||||||||||||||||||") - print(authorization_endpoints.auth_endpoint) - print(authorization_endpoints.token_endpoint) - # client = _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), - # client_id=_CLIENT_ID.get(), - # auth_endpoint=authorization_endpoints.auth_endpoint, - # token_endpoint=authorization_endpoints.token_endpoint) + +def test_generate_code_verifier(): + verifier = _auth._generate_code_verifier() + # TODO: Write test later + assert verifier is not None diff --git a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py new file mode 100644 index 0000000000..7c4436113b --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import +import json + +from mock import MagicMock, patch, PropertyMock +from flytekit.clis.flyte_cli.main import _welcome_message +from flytekit.clis.sdk_in_container import basic_auth +from flytekit.configuration.creds import ( + CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, +) + +_welcome_message() + + +def test_get_secret(): + import os + os.environ[_CREDENTIALS_SECRET.env_var] = "abc" + assert basic_auth.get_secret() == "abc" + + +def test_get_basic_authorization_header(): + header = basic_auth.get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + +@patch('flytekit.clis.sdk_in_container.basic_auth._requests') +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = basic_auth.get_token("https://corp.idp.net", "abc123", "my_scope") + assert access == "abc" + assert expiration == 60 From e380b7eafdd1eb2bea1ae2f6c327e7b2dc9735e0 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 26 Nov 2019 10:51:13 -0800 Subject: [PATCH 19/40] update the character range for code verifier to include - _ . ~ --- flytekit/clis/auth/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 4d3dfedf6b..d97e194744 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -33,13 +33,13 @@ def _generate_code_verifier(): """ - Generates a 'code_verifier' as described in section 4.1 of RFC 7636. + Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. :return str: """ code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8) # Eliminate invalid characters. - code_verifier = _re.sub('[^a-zA-Z0-9]+', '', code_verifier) + code_verifier = _re.sub('[^a-zA-Z0-9_\-.~]+', '', code_verifier) if len(code_verifier) < 43: raise ValueError("Verifier too short. number of bytes must be > 30.") elif len(code_verifier) > 128: From 1ecf29f1d9a5207b5587981159476e4f25b01d25 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 26 Nov 2019 11:04:42 -0800 Subject: [PATCH 20/40] adding test for set token --- flytekit/clients/raw.py | 1 - flytekit/clis/auth/auth.py | 2 +- tests/flytekit/unit/clients/__init__.py | 0 tests/flytekit/unit/clients/test_raw.py | 12 ++++++++++++ 4 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 tests/flytekit/unit/clients/__init__.py create mode 100644 tests/flytekit/unit/clients/test_raw.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index ea33a3902e..04643b2913 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -98,7 +98,6 @@ def __init__(self, url, insecure=False, credentials=None, options=None): """ self._channel = None - # TODO: Revert all the for loops below if insecure: self._channel = _insecure_channel(url, options=list((options or {}).items())) else: diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index d97e194744..c3bb32f0c8 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -39,7 +39,7 @@ def _generate_code_verifier(): """ code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8) # Eliminate invalid characters. - code_verifier = _re.sub('[^a-zA-Z0-9_\-.~]+', '', code_verifier) + code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier) if len(code_verifier) < 43: raise ValueError("Verifier too short. number of bytes must be > 30.") elif len(code_verifier) > 128: diff --git a/tests/flytekit/unit/clients/__init__.py b/tests/flytekit/unit/clients/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py new file mode 100644 index 0000000000..487cd697e1 --- /dev/null +++ b/tests/flytekit/unit/clients/test_raw.py @@ -0,0 +1,12 @@ +from __future__ import absolute_import +from flytekit.clients.raw import RawSynchronousFlyteClient +import mock + +@mock.patch('flytekit.clients.raw._admin_service') +@mock.patch('flytekit.clients.raw._insecure_channel') +def test_client_set_token(mock_channel, mock_admin): + mock_channel.return_value = True + mock_admin.AdminServiceStub.return_value = True + client = RawSynchronousFlyteClient(url='a.b.com', insecure=True) + client.set_access_token('abc') + assert client._metadata[0][1] == 'Bearer abc' From aeb65297e16a4206cd1fe16fe8e6df0b941c769a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 26 Nov 2019 11:48:14 -0800 Subject: [PATCH 21/40] unit test for basic auth handler --- flytekit/clients/raw.py | 10 ++++++++- tests/flytekit/unit/clients/test_raw.py | 27 +++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 04643b2913..c343f66c83 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -6,7 +6,6 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.configuration.creds import ( CLIENT_ID as _CLIENT_ID, - CLIENT_CREDENTIALS_SECRET_LOCATION as _CREDENTIALS_SECRET_FILE, CLIENT_CREDENTIALS_SCOPE as _SCOPE, ) from flytekit.clis.sdk_in_container import basic_auth @@ -32,6 +31,15 @@ def _refresh_credentials_standard(flyte_client): def _refresh_credentials_basic(flyte_client): + """ + This function is used by the _handle_rpc_error decorator, depending on the AUTH_MODE config object. This handler + is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, + like when waiting for another workflow to complete from within a task). This function uses basic auth, which means + the credentials for basic auth must be present from wherever this code is running. + + :param flyte_client: RawSynchronousFlyteClient + :return: + """ auth_endpoints = _credentials_access.get_authorization_endpoints() token_endpoint = auth_endpoints.token_endpoint client_secret = basic_auth.get_secret() diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 487cd697e1..5d37b5f590 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -1,12 +1,35 @@ from __future__ import absolute_import -from flytekit.clients.raw import RawSynchronousFlyteClient +from flytekit.clients.raw import ( + RawSynchronousFlyteClient as _RawSynchronousFlyteClient, + _refresh_credentials_basic +) +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET +from flytekit.clis.auth.discovery import AuthorizationEndpoints as _AuthorizationEndpoints import mock +import os +import json + @mock.patch('flytekit.clients.raw._admin_service') @mock.patch('flytekit.clients.raw._insecure_channel') def test_client_set_token(mock_channel, mock_admin): mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True - client = RawSynchronousFlyteClient(url='a.b.com', insecure=True) + client = _RawSynchronousFlyteClient(url='a.b.com', insecure=True) client.set_access_token('abc') assert client._metadata[0][1] == 'Bearer abc' + + +@mock.patch('flytekit.clis.sdk_in_container.basic_auth._requests') +@mock.patch('flytekit.clients.raw._credentials_access') +def test_refresh_credentials_basic(mock_credentials_access, mock_requests): + mock_credentials_access.get_authorization_endpoints.return_value = _AuthorizationEndpoints('auth', 'token') + response = mock.MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + os.environ[_CREDENTIALS_SECRET.env_var] = "asdf12345" + + mock_client = mock.MagicMock() + _refresh_credentials_basic(mock_client) + mock_client.set_access_token.assert_called_with('abc') From b27ae352353e3f39d0a0ec4e4c2cd7eabe769af1 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 26 Nov 2019 11:52:33 -0800 Subject: [PATCH 22/40] add preliminary auth and credentials tests --- flytekit/clients/raw.py | 6 ++++ flytekit/clis/auth/auth.py | 1 - setup.py | 1 + tests/flytekit/unit/cli/auth/test_auth.py | 31 ++++++++++++++++++- .../unit/cli/auth/test_credentials.py | 6 ++++ 5 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 tests/flytekit/unit/cli/auth/test_credentials.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index ea33a3902e..cf67bce43c 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -54,6 +54,12 @@ def _get_refresh_handler(auth_mode): def _handle_rpc_error(fn): def handler(*args, **kwargs): + """ + Wraps rpc errors as Flyte exceptions and handles authentication the client. + :param args: + :param kwargs: + :return: + """ retries = 2 try: for i in range(retries): diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 4d3dfedf6b..2f2c0558ce 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -121,7 +121,6 @@ def handle_authorization_code(self, auth_code): class Credentials(object): - # TODO(katrogan): Also add expires_in handling. def __init__(self, access_token=None, id_token=None): self._access_token = access_token self._id_token = id_token diff --git a/setup.py b/setup.py index a82c835f5c..969279f530 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "protobuf>=3.6.1,<4", "pytimeparse>=1.1.8,<2.0.0", "pytz>=2017.2,<2018.5", + "keyring>=18.0.1", "requests>=2.18.4,<3.0.0", "six>=1.9.0,<2.0.0", "sortedcontainers>=1.5.9,<2.0.0", diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py index 178eec1115..dd71234c58 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -1,9 +1,38 @@ from __future__ import absolute_import +import re + from flytekit.clis.auth import auth as _auth +from multiprocessing import Queue as _Queue +try: # Python 3 + import http.server as _BaseHTTPServer +except ImportError: # Python 2 + import BaseHTTPServer as _BaseHTTPServer + def test_generate_code_verifier(): verifier = _auth._generate_code_verifier() - # TODO: Write test later assert verifier is not None + assert 43 < len(verifier) < 128 + assert not re.search(r'[^a-zA-Z0-9]+', verifier) + + +def test_generate_state_parameter(): + param = _auth._generate_state_parameter() + assert not re.search(r'[^a-zA-Z0-9-_.,]+', param) + + +def test_create_code_challenge(): + test_code_verifier = "test_code_verifier" + assert _auth._create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4" + + +def test_oauth_http_server(): + queue = _Queue() + server = _auth.OAuthHTTPServer(("localhost", 9000), _BaseHTTPServer.BaseHTTPRequestHandler, queue=queue) + test_auth_code = "auth_code" + server.handle_authorization_code(test_auth_code) + auth_code = queue.get() + assert test_auth_code == auth_code + diff --git a/tests/flytekit/unit/cli/auth/test_credentials.py b/tests/flytekit/unit/cli/auth/test_credentials.py new file mode 100644 index 0000000000..69864ea791 --- /dev/null +++ b/tests/flytekit/unit/cli/auth/test_credentials.py @@ -0,0 +1,6 @@ +from flytekit.clis.auth import credentials as _credentials + + +def test_is_absolute(): + assert _credentials._is_absolute("http://localhost:9000/my_endpoint") + assert _credentials._is_absolute("/my_endpoint") is False From d2ddf4a541382876cfff883a2812ef618c106644 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 26 Nov 2019 13:32:11 -0800 Subject: [PATCH 23/40] add discovery client tests --- flytekit/clients/raw.py | 13 +++++-- flytekit/clis/auth/discovery.py | 2 +- requirements.txt | 1 + .../flytekit/unit/cli/auth/test_discovery.py | 39 +++++++++++++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 tests/flytekit/unit/cli/auth/test_discovery.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 22985aedc1..78a7eecf68 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -20,12 +20,19 @@ def _refresh_credentials_standard(flyte_client): - _credentials_access.get_client().refresh_access_token() - _set_global_access_token() - + """ + This function is used when the configuration value for AUTH_MODE is set to 'standard'. + This either fetches the existing access token or initiates the flow to request a valid access token and store it. + :param flyte_client: RawSynchronousFlyteClient + :return: + """ if not _platform_config.AUTH.get(): # nothing to do return + + _credentials_access.get_client().refresh_access_token() + _set_global_access_token() + access_token = _get_global_access_token() flyte_client.set_access_token(access_token) diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index fce6988da2..32493ee203 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -19,7 +19,7 @@ class AuthorizationEndpoints(object): A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. """ def __init__(self, auth_endpoint=None, token_endpoint=None): - self._auth_endpoint = auth_endpoint + self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint @property diff --git a/requirements.txt b/requirements.txt index 1e4d65baa2..b9203d0682 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pytest==4.6.6 mock==3.0.5 six==1.12.0 +responses==0.10.7 \ No newline at end of file diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py new file mode 100644 index 0000000000..4055d18b35 --- /dev/null +++ b/tests/flytekit/unit/cli/auth/test_discovery.py @@ -0,0 +1,39 @@ +import pytest +import responses + +from flytekit.clis.auth import discovery as _discovery + + +@responses.activate +def test_get_authorization_endpoints(): + discovery_url = "http://flyte-admin.com/discovery" + + auth_endpoint = "http://flyte-admin.com/authorization" + token_endpoint = "http://flyte-admin.com/token" + responses.add(responses.GET, discovery_url, + json={'authorization_endpoint': auth_endpoint, + 'token_endpoint': token_endpoint}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + assert discovery_client.get_authorization_endpoints().auth_endpoint == auth_endpoint + assert discovery_client.get_authorization_endpoints().token_endpoint == token_endpoint + + +@responses.activate +def test_get_authorization_endpoints_missing_authorization_endpoint(): + discovery_url = "http://flyte-admin.com/discovery" + responses.add(responses.GET, discovery_url, json={'token_endpoint': "http://flyte-admin.com/token"}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + with pytest.raises(Exception): + discovery_client.get_authorization_endpoints() + + +@responses.activate +def test_get_authorization_endpoints_missing_token_endpoint(): + discovery_url = "http://flyte-admin.com/discovery" + responses.add(responses.GET, discovery_url, json={'authorization_endpoint': "http://flyte-admin.com/authorization"}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + with pytest.raises(Exception): + discovery_client.get_authorization_endpoints() From 18180e962b7d62f83a7c8a7e7d267e9c7cb07e05 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Tue, 26 Nov 2019 13:44:43 -0800 Subject: [PATCH 24/40] fix test for updated regex --- tests/flytekit/unit/cli/auth/test_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py index dd71234c58..757e6f4797 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -15,7 +15,7 @@ def test_generate_code_verifier(): verifier = _auth._generate_code_verifier() assert verifier is not None assert 43 < len(verifier) < 128 - assert not re.search(r'[^a-zA-Z0-9]+', verifier) + assert not re.search(r'[^a-zA-Z0-9_\-.~]+', verifier) def test_generate_state_parameter(): From 83c91e12c03cb99d4ebd492a27225321e6555fc8 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 26 Nov 2019 14:22:46 -0800 Subject: [PATCH 25/40] change str to decode because python2 --- flytekit/clis/sdk_in_container/basic_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py index 6f936df838..c711418263 100644 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -53,7 +53,7 @@ def get_basic_authorization_header(client_id, client_secret): :rtype: Text """ concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format(str(_base64.b64encode(concated.encode(_utf_8)), _utf_8)) + return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8)) def get_token(token_endpoint, authorization_header, scope): From 77ce838f6c182c09f1a6c3df9bba92d68cb2d57b Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 27 Nov 2019 10:39:07 -0800 Subject: [PATCH 26/40] no auth mode, rm comment --- flytekit/clients/raw.py | 10 ++++++++-- flytekit/clis/auth/auth.py | 4 ---- flytekit/configuration/creds.py | 3 ++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 78a7eecf68..ca6bf67c57 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -26,7 +26,7 @@ def _refresh_credentials_standard(flyte_client): :param flyte_client: RawSynchronousFlyteClient :return: """ - if not _platform_config.AUTH.get(): + if not _platform_config.AUTH.get() or not _creds_config.AUTH_MODE.get(): # nothing to do return @@ -57,8 +57,14 @@ def _refresh_credentials_basic(flyte_client): flyte_client.set_access_token(token) +def _refresh_credentials_noop(flyte_client): + pass + + def _get_refresh_handler(auth_mode): - if auth_mode == "standard": + if not auth_mode: + return _refresh_credentials_noop + elif auth_mode == "standard": return _refresh_credentials_standard elif auth_mode == "basic": return _refresh_credentials_basic diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index dd995ca473..42bc93ab8f 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -134,10 +134,6 @@ def id_token(self): return self._id_token -# TODO: -# do we need to support initiate login URI? https://devforum.okta.com/t/initiate-login-uri-for-all-subdomain-urls/3766 - - class AuthorizationClient(object): def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None): self._auth_endpoint = auth_endpoint diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 876fa727e9..64a32db8b8 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -55,11 +55,12 @@ of the authorization server. """ -AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") +AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default=None) """ The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: - 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials access. - 'basic' This uses cert-based auth in which the end user enters his/her username and password and public key encryption is used to facilitate authentication. +- None: No auth will be attempted. """ From bd7324ef2eb97a19fdb4a4addab26b9a849f05c1 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 27 Nov 2019 10:47:20 -0800 Subject: [PATCH 27/40] nevermind --- flytekit/clients/raw.py | 6 ++---- flytekit/configuration/creds.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index ca6bf67c57..7639628b4e 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -26,7 +26,7 @@ def _refresh_credentials_standard(flyte_client): :param flyte_client: RawSynchronousFlyteClient :return: """ - if not _platform_config.AUTH.get() or not _creds_config.AUTH_MODE.get(): + if not _platform_config.AUTH.get(): # nothing to do return @@ -62,9 +62,7 @@ def _refresh_credentials_noop(flyte_client): def _get_refresh_handler(auth_mode): - if not auth_mode: - return _refresh_credentials_noop - elif auth_mode == "standard": + if auth_mode == "standard": return _refresh_credentials_standard elif auth_mode == "basic": return _refresh_credentials_basic diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 64a32db8b8..558917d651 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -55,7 +55,7 @@ of the authorization server. """ -AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default=None) +AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") """ The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: - 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials From 7004c5998080eceaae17b0e7cf64947855644a90 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 27 Nov 2019 10:48:18 -0800 Subject: [PATCH 28/40] one more revert --- flytekit/clients/raw.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 7639628b4e..6a941f5709 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -26,13 +26,14 @@ def _refresh_credentials_standard(flyte_client): :param flyte_client: RawSynchronousFlyteClient :return: """ - if not _platform_config.AUTH.get(): - # nothing to do - return _credentials_access.get_client().refresh_access_token() _set_global_access_token() + if not _platform_config.AUTH.get(): + # nothing to do + return + access_token = _get_global_access_token() flyte_client.set_access_token(access_token) From 244d032a2cc24c0630733127a299e6881b4a1578 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 27 Nov 2019 11:49:07 -0800 Subject: [PATCH 29/40] change handling around default home directory config file loading so that if the host is specified in the user's ~/.flyte/config file, you don't need to specify it in the flyte-cli command --- flytekit/clis/flyte_cli/main.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 6dc9b34c40..2fbc9bf660 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -40,7 +40,6 @@ def _welcome_message(): _click.secho("Welcome to Flyte CLI! Version: {}".format(_tt(__version__)), bold=True) - _detect_default_config_file() def _detect_default_config_file(): @@ -53,6 +52,10 @@ def _detect_default_config_file(): _click.secho("Config file not found at default location, relying on environment variables instead", fg='blue') +# Run this as the module is loading to pick up settings that click can then use when constructing the commands +_detect_default_config_file() + + def _get_io_string(literal_map, verbose=False): """ :param flytekit.models.literals.LiteralMap literal_map: @@ -426,7 +429,6 @@ class _FlyteSubCommand(_click.Command): 'project': _PROJECT_FLAGS[0], 'domain': _DOMAIN_FLAGS[0], 'name': _NAME_FLAGS[0], - 'host': _HOST_FLAGS[0] } _PASSABLE_FLAGS = { @@ -441,6 +443,16 @@ def make_context(self, cmd_name, args, parent=None): parent.params[param.name] is not None: prefix_args.extend([type(self)._PASSABLE_ARGS[param.name], _six.text_type(parent.params[param.name])]) + # Special handling for the host option, because this option can be specified in the user's ~/.flyte/config + # file, and unlike other options, doesn't get picked up before click commands are run + if param.name == 'host': + if param.name in parent.params and parent.params[param.name] is not None: + prefix_args.extend( + [_HOST_FLAGS[0], _six.text_type(parent.params[param.name])]) + elif _platform_config.URL.get(): + prefix_args.extend( + [_HOST_FLAGS[0], _six.text_type(_platform_config.URL.get())]) + # For flags, we don't append the value of the flag, otherwise click will fail to parse if param.name in type(self)._PASSABLE_FLAGS and \ param.name in parent.params and \ From 067a604c5a57dda8e71d18d9809ea7fd1fb5a5e0 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 4 Dec 2019 14:53:30 -0800 Subject: [PATCH 30/40] address review comments --- flytekit/clis/auth/credentials.py | 16 +++++++--------- flytekit/configuration/creds.py | 7 ------- requirements.txt | 3 +-- setup.py | 1 + tests/flytekit/unit/cli/auth/test_credentials.py | 6 ------ 5 files changed, 9 insertions(+), 24 deletions(-) delete mode 100644 tests/flytekit/unit/cli/auth/test_credentials.py diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index e0dde5cc76..7b945987fd 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -4,7 +4,6 @@ from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient from flytekit.configuration.creds import ( - DISCOVERY_ENDPOINT as _DISCOVERY_ENDPOINT, REDIRECT_URI as _REDIRECT_URI, CLIENT_ID as _CLIENT_ID ) @@ -15,9 +14,12 @@ except ImportError: # Python 2 import urlparse as _urlparse +# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. +discovery_endpoint_path = ".well-known/oauth-authorization-server" -def _is_absolute(url): - return bool(_urlparse.urlparse(url).netloc) + +def _get_discovery_endpoint(): + return _urlparse.urljoin(_URL.get(), discovery_endpoint_path) # Lazy initialized authorization client singleton @@ -28,9 +30,7 @@ def get_client(): global _authorization_client if _authorization_client is not None: return _authorization_client - discovery_endpoint = _DISCOVERY_ENDPOINT.get() - if not _is_absolute(discovery_endpoint): - discovery_endpoint = _urlparse.urljoin(_URL.get(), discovery_endpoint) + discovery_endpoint = _get_discovery_endpoint() discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) authorization_endpoints = discovery_client.get_authorization_endpoints() @@ -42,8 +42,6 @@ def get_client(): def get_authorization_endpoints(): - discovery_endpoint = _DISCOVERY_ENDPOINT.get() - if not _is_absolute(discovery_endpoint): - discovery_endpoint = _urlparse.urljoin(_URL.get(), discovery_endpoint) + discovery_endpoint = _get_discovery_endpoint() discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) return discovery_client.get_authorization_endpoints() diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 558917d651..5b859da30c 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -2,13 +2,6 @@ from flytekit.configuration import common as _config_common -DISCOVERY_ENDPOINT = _config_common.FlyteStringConfigurationEntry('credentials', 'discovery_endpoint', default='https://company.idp.com/.well-known/oauth-authorization-server') -""" -This endpoint fetches authorization server metadata as described in: -https://tools.ietf.org/html/rfc8414 -The endpoint path can be relative or absolute. -""" - CLIENT_ID = _config_common.FlyteStringConfigurationEntry('credentials', 'client_id', default=None) """ This is the public identifier for the app which handles authorization for a Flyte deployment. diff --git a/requirements.txt b/requirements.txt index b9203d0682..8fc7f72648 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ pytest==4.6.6 mock==3.0.5 -six==1.12.0 -responses==0.10.7 \ No newline at end of file +six==1.12.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 969279f530..65877e0d34 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ "pytz>=2017.2,<2018.5", "keyring>=18.0.1", "requests>=2.18.4,<3.0.0", + "responses>=0.10.7", "six>=1.9.0,<2.0.0", "sortedcontainers>=1.5.9,<2.0.0", "statsd>=3.0.0,<4.0.0", diff --git a/tests/flytekit/unit/cli/auth/test_credentials.py b/tests/flytekit/unit/cli/auth/test_credentials.py deleted file mode 100644 index 69864ea791..0000000000 --- a/tests/flytekit/unit/cli/auth/test_credentials.py +++ /dev/null @@ -1,6 +0,0 @@ -from flytekit.clis.auth import credentials as _credentials - - -def test_is_absolute(): - assert _credentials._is_absolute("http://localhost:9000/my_endpoint") - assert _credentials._is_absolute("/my_endpoint") is False From 23a14c86bb927ea24c4f18b88b65317252016f99 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 4 Dec 2019 15:25:58 -0800 Subject: [PATCH 31/40] nits --- flytekit/clients/raw.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 6a941f5709..38a7d495a0 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -8,8 +8,8 @@ CLIENT_ID as _CLIENT_ID, CLIENT_CREDENTIALS_SCOPE as _SCOPE, ) -from flytekit.clis.sdk_in_container import basic_auth -import logging +from flytekit.clis.sdk_in_container import basic_auth as _basic_auth +import logging as _logging import six as _six from flytekit.configuration import creds as _creds_config, platform as _platform_config @@ -50,11 +50,11 @@ def _refresh_credentials_basic(flyte_client): """ auth_endpoints = _credentials_access.get_authorization_endpoints() token_endpoint = auth_endpoints.token_endpoint - client_secret = basic_auth.get_secret() - logging.debug('Basic authorization flow with client id {} scope {}', _CLIENT_ID.get(), _SCOPE.get()) - authorization_header = basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) - token, expires_in = basic_auth.get_token(token_endpoint, authorization_header, _SCOPE.get()) - logging.info('Retrieved new token, expires in {}'.format(expires_in)) + client_secret = _basic_auth.get_secret() + _logging.debug('Basic authorization flow with client id {} scope {}', _CLIENT_ID.get(), _SCOPE.get()) + authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) + token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _SCOPE.get()) + _logging.info('Retrieved new token, expires in {}'.format(expires_in)) flyte_client.set_access_token(token) From a52ba077246a792ef8cbcab81c0f8c1f4c676828 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 5 Dec 2019 13:45:12 -0800 Subject: [PATCH 32/40] Env var (from pkce-auth pr) option 2 (#64) making a waterfall for how configuration settings are looked up from the environment variable --- flytekit/clis/sdk_in_container/basic_auth.py | 23 ++----- flytekit/configuration/common.py | 61 ++++++++++++++++--- tests/flytekit/unit/configuration/conftest.py | 4 +- .../unit/configuration/test_waterfall.py | 47 ++++++++++++++ 4 files changed, 107 insertions(+), 28 deletions(-) create mode 100644 tests/flytekit/unit/configuration/test_waterfall.py diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py index c711418263..406f3e4306 100644 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -7,7 +7,6 @@ from flytekit.common.exceptions.base import FlyteException as _FlyteException from flytekit.configuration.creds import ( - CLIENT_CREDENTIALS_SECRET_LOCATION as _CREDENTIALS_SECRET_FILE, CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, ) @@ -18,30 +17,16 @@ class FlyteAuthenticationException(_FlyteException): _ERROR_CODE = "FlyteAuthenticationFailed" -def get_file_contents(location): - """ - This reads an input file, and returns the string contents, and should be used for reading credentials. - This function will also strip newlines. - - :param Text location: The file path holding the client id or secret - :rtype: Text - """ - with open(location, 'r') as f: - return f.read().replace('\n', '') - - def get_secret(): """ This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. :rtype: Text """ - if _CREDENTIALS_SECRET_FILE.get(): - return get_file_contents(_CREDENTIALS_SECRET_FILE.get()) - elif _CREDENTIALS_SECRET.get(): - return _CREDENTIALS_SECRET.get() - raise FlyteAuthenticationException('No secret could be found in either {} or the {} env variable'.format( - _CREDENTIALS_SECRET_FILE.get(), _CREDENTIALS_SECRET.env_var)) + secret = _CREDENTIALS_SECRET.get() + if secret: + return secret + raise FlyteAuthenticationException('No secret could be found') def get_basic_authorization_header(client_id, client_secret): diff --git a/flytekit/configuration/common.py b/flytekit/configuration/common.py index 8ef86286d1..54d9ca7ee8 100644 --- a/flytekit/configuration/common.py +++ b/flytekit/configuration/common.py @@ -118,6 +118,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): del _os.environ[self._config.env_var] +def _get_file_contents(location): + """ + This reads an input file, and returns the string contents, and should be used for reading credentials. + This function will also strip newlines. + + :param Text location: The file path holding the client id or secret + :rtype: Text + """ + if _os.path.isfile(location): + with open(location, 'r') as f: + return f.read().replace('\n', '') + return None + + class _FlyteConfigurationEntry(_six.with_metaclass(_abc.ABCMeta, object)): def __init__(self, section, key, default=None, validator=None, fallback=None): @@ -138,6 +152,37 @@ def env_var(self): def _getter(self): pass + def retrieve_value(self): + """ + The logic in this function changes the lookup behavior for all configuration objects before hitting the + configuration file. + + For a given configuration object ('mysection', 'mysetting'), it will now look at this waterfall: + + i.) The environment variable named 'FLYTE_MYSECTION_MYSETTING' + + ii.) The value of the environment variable that is named the value of the environment variable named + 'FLYTE_MYSECTION_MYSETTING'. That is if os.environ['FLYTE_MYSECTION_MYSETTING'] = 'AAA' and + os.environ['AA'] = 'abc', then 'abc' will be the final value. + + iii.) The contents of the file pointed to by the environment variable named 'FLYTE_MYSECTION_MYSETTING', + assuming the value is a file. + + While it is helpful, this pattern does interrupt the manually specified fallback logic, by effective injecting + two more fallbacks behind the scenes. Just keep this in mind as you are using/creating configuration objects. + :rtype: Text + """ + val = _os.environ.get(self.env_var, None) + if val is None: + referenced_env_var = _os.environ.get("{}_FROM_ENV_VAR".format(self.env_var), None) + if referenced_env_var is not None: + val = _os.environ.get(referenced_env_var, None) + if val is None: + referenced_file = _os.environ.get("{}_FROM_FILE".format(self.env_var), None) + if referenced_file is not None: + val = _get_file_contents(referenced_file) + return val + def get(self): val = self._getter() if val is None and self._fallback is not None: @@ -178,7 +223,7 @@ def _validate_not_null(self, val): class FlyteStringConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) return val @@ -186,7 +231,7 @@ def _getter(self): class FlyteIntegerConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) if val is not None: @@ -196,7 +241,7 @@ def _getter(self): class FlyteBoolConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: return CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) @@ -209,7 +254,7 @@ def _getter(self): class FlyteStringListConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: @@ -219,7 +264,7 @@ def _getter(self): class FlyteRequiredStringConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) return val @@ -227,7 +272,7 @@ def _getter(self): class FlyteRequiredIntegerConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) return int(val) @@ -235,7 +280,7 @@ def _getter(self): class FlyteRequiredBoolConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) return bool(val) @@ -243,7 +288,7 @@ def _getter(self): class FlyteRequiredStringListConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: diff --git a/tests/flytekit/unit/configuration/conftest.py b/tests/flytekit/unit/configuration/conftest.py index 0d7d1b314b..446fb12fac 100644 --- a/tests/flytekit/unit/configuration/conftest.py +++ b/tests/flytekit/unit/configuration/conftest.py @@ -1,10 +1,12 @@ from __future__ import absolute_import from flytekit.configuration import set_flyte_config_file as _set_config import pytest as _pytest - +import os as _os @_pytest.fixture(scope="function", autouse=True) def clear_configs(): _set_config(None) + environment_variables = _os.environ.copy() yield + _os.environ = environment_variables _set_config(None) diff --git a/tests/flytekit/unit/configuration/test_waterfall.py b/tests/flytekit/unit/configuration/test_waterfall.py new file mode 100644 index 0000000000..0de9317e08 --- /dev/null +++ b/tests/flytekit/unit/configuration/test_waterfall.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file, \ + common as _common, \ + TemporaryConfiguration as _TemporaryConfiguration + +from flytekit.common.utils import AutoDeletingTempDir as _AutoDeletingTempDir +import os as _os + + +def test_lookup_waterfall_raw_env_var(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + _os.environ['FLYTE_TEST_SETTING'] = 'lorem' + assert x.get() == 'lorem' + + +def test_lookup_waterfall_referenced_env_var(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + if 'TEMP_PLACEHOLDER' in _os.environ: + del _os.environ['TEMP_PLACEHOLDER'] + _os.environ['TEMP_PLACEHOLDER'] = 'lorem' + _os.environ['FLYTE_TEST_SETTING_FROM_ENV_VAR'] = 'TEMP_PLACEHOLDER' + assert x.get() == 'lorem' + + +def test_lookup_waterfall_referenced_file(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + with _AutoDeletingTempDir("config_testing") as tmp_dir: + with open(tmp_dir.get_named_tempfile('name'), 'w') as fh: + fh.write('secret_password') + + _os.environ['FLYTE_TEST_SETTING_FROM_FILE'] = tmp_dir.get_named_tempfile('name') + assert x.get() == 'secret_password' From 7f54611ebc5d19762f6de58b0a5e35d7d3207b11 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 5 Dec 2019 14:11:12 -0800 Subject: [PATCH 33/40] remove no longer necessary location backup for credentials secret --- flytekit/configuration/creds.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 5b859da30c..2bc4bd7e74 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -26,12 +26,6 @@ Traditionally this value is 'authorization' however it is made configurable. """ -CLIENT_CREDENTIALS_SECRET_LOCATION = _config_common.FlyteStringConfigurationEntry( - 'credentials', 'client_secret_location', default=None) -""" -Used for basic auth, which is automatically called during pyflyte. This is the location to look for the password. The -client id config setting is shared across the basic and standard auth flows. -""" CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry('credentials', 'client_secret', default=None) """ From 053e03ed9c1480fe352d3a143f8def1d18b98ff8 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 5 Dec 2019 14:15:25 -0800 Subject: [PATCH 34/40] use the real auth exception class --- flytekit/clis/sdk_in_container/basic_auth.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py index 406f3e4306..f3211b97d5 100644 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -5,7 +5,7 @@ import requests as _requests -from flytekit.common.exceptions.base import FlyteException as _FlyteException +from flytekit.common.exceptions.user import FlyteAuthenticationException as _FlyteAuthenticationException from flytekit.configuration.creds import ( CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, ) @@ -13,10 +13,6 @@ _utf_8 = 'utf-8' -class FlyteAuthenticationException(_FlyteException): - _ERROR_CODE = "FlyteAuthenticationFailed" - - def get_secret(): """ This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION @@ -26,7 +22,7 @@ def get_secret(): secret = _CREDENTIALS_SECRET.get() if secret: return secret - raise FlyteAuthenticationException('No secret could be found') + raise _FlyteAuthenticationException('No secret could be found') def get_basic_authorization_header(client_id, client_secret): @@ -62,7 +58,7 @@ def get_token(token_endpoint, authorization_header, scope): response = _requests.post(token_endpoint, data=body, headers=headers) if response.status_code != 200: _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise FlyteAuthenticationException('Non-200 received from IDP') + raise _FlyteAuthenticationException('Non-200 received from IDP') response = response.json() return response['access_token'], response['expires_in'] From 0dc31860539bed1652332cecd10b0d36dd805f34 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 5 Dec 2019 16:11:00 -0800 Subject: [PATCH 35/40] get_discovery_endpoint changes --- flytekit/clis/auth/credentials.py | 6 ++++-- flytekit/clis/auth/discovery.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 7b945987fd..f12a075d71 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -7,7 +7,7 @@ REDIRECT_URI as _REDIRECT_URI, CLIENT_ID as _CLIENT_ID ) -from flytekit.configuration.platform import URL as _URL +from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE try: # Python 3 import urllib.parse as _urlparse @@ -19,7 +19,9 @@ def _get_discovery_endpoint(): - return _urlparse.urljoin(_URL.get(), discovery_endpoint_path) + if _INSECURE.get(): + return _urlparse.urljoin('http://{}/'.format(_URL.get()), discovery_endpoint_path) + return _urlparse.urljoin('https://{}/'.format(_URL.get()), discovery_endpoint_path) # Lazy initialized authorization client singleton diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index 32493ee203..fce6988da2 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -19,7 +19,7 @@ class AuthorizationEndpoints(object): A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. """ def __init__(self, auth_endpoint=None, token_endpoint=None): - self._auth_endpoint = auth_endpoint + self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint @property From 946a0582463a13175c1ba767e9d160d5ae7b36a5 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 5 Dec 2019 16:59:48 -0800 Subject: [PATCH 36/40] cleanup keyring code --- flytekit/clients/helpers.py | 20 -------------- flytekit/clients/raw.py | 16 ++++-------- flytekit/clis/auth/auth.py | 43 ++++++++++++++++++++++++++++--- flytekit/clis/auth/credentials.py | 7 ++--- 4 files changed, 46 insertions(+), 40 deletions(-) diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index bee5b46d8c..10640b6d74 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,12 +1,6 @@ -import keyring as _keyring from flytekit.clis.auth import credentials as _credentials_access -# Identifies the service used for storing passwords in keyring -_keyring_service_name = "flytecli" -# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs -# suggest, we are storing a user's oidc. -_keyring_storage_key = "access_token" def iterate_node_executions( @@ -85,17 +79,3 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte break token = next_token - -# Fetches an existing authorization access token if it exists in keyring or sets if it's unassigned. -def get_global_access_token(): - access_token = _keyring.get_password(_keyring_service_name, _keyring_storage_key) - if access_token is None: - access_token = set_global_access_token() - return access_token - - -# Assigns and returns the authorization access token in keyring. -def set_global_access_token(): - credentials = _credentials_access.get_client().credentials - _keyring.set_password(_keyring_service_name, _keyring_storage_key, credentials.access_token) - return credentials.access_token diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 38a7d495a0..e58b4a6e88 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -4,6 +4,7 @@ StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration.platform import AUTH as _AUTH from flytekit.configuration.creds import ( CLIENT_ID as _CLIENT_ID, CLIENT_CREDENTIALS_SCOPE as _SCOPE, @@ -14,9 +15,6 @@ from flytekit.configuration import creds as _creds_config, platform as _platform_config from flytekit.clis.auth import credentials as _credentials_access -from flytekit.clients.helpers import ( - get_global_access_token as _get_global_access_token, set_global_access_token as _set_global_access_token -) def _refresh_credentials_standard(flyte_client): @@ -28,14 +26,7 @@ def _refresh_credentials_standard(flyte_client): """ _credentials_access.get_client().refresh_access_token() - _set_global_access_token() - - if not _platform_config.AUTH.get(): - # nothing to do - return - - access_token = _get_global_access_token() - flyte_client.set_access_token(access_token) + flyte_client.set_access_token(_credentials_access.get_client().credentials.access_token) def _refresh_credentials_basic(flyte_client): @@ -134,6 +125,8 @@ def __init__(self, url, insecure=False, credentials=None, options=None): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None + if _AUTH.get(): + self.force_auth_flow() def set_access_token(self, access_token): self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] @@ -279,6 +272,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ + _logging.warn("hi katrina, metadata is {}".format(self._metadata)) return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 42bc93ab8f..005371d180 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -1,5 +1,6 @@ import base64 as _base64 import hashlib as _hashlib +import keyring as _keyring import os as _os import re as _re import requests as _requests @@ -31,6 +32,15 @@ _utf_8 = 'utf-8' +# Identifies the service used for storing passwords in keyring +_keyring_service_name = "flyteauth" +# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs +# suggest, we are storing a user's oidc. +_keyring_access_token_storage_key = "access_token" +_keyring_id_token_storage_key = "id_token" +_keyring_refresh_token_storage_key = "refresh_token" + + def _generate_code_verifier(): """ Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 @@ -148,6 +158,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._credentials = None self._refresh_token = None self._headers = {'content-type': "application/x-www-form-urlencoded"} + self._expired = False self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. @@ -160,7 +171,15 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi "code_challenge_method": "S256", } - # Initiate token request flow + # Prefer to use already-fetched token values when they've been set globally. + self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) + access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) + id_token = _keyring.get_password(_keyring_service_name, _keyring_id_token_storage_key) + if access_token and id_token: + self._credentials = Credentials(access_token=access_token, id_token=id_token) + return + + # In the absence of globally-set token values, initiate the token request flow q = _Queue() # First prepare the callback server in the background server = self._create_callback_server(q) @@ -203,7 +222,14 @@ def _initialize_credentials(self, auth_token_resp): if "refresh_token" in response_body: self._refresh_token = response_body["refresh_token"] - self._credentials = Credentials(access_token=response_body["access_token"], id_token=response_body["id_token"]) + access_token = response_body["access_token"] + id_token = response_body["id_token"] + refresh_token = response_body["refresh_token"] + + _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) + _keyring.set_password(_keyring_service_name, _keyring_id_token_storage_key, id_token) + _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) + self._credentials = Credentials(access_token=access_token, id_token=id_token) def request_access_token(self, auth_code): if self._state != auth_code.state: @@ -239,8 +265,10 @@ def refresh_access_token(self): allow_redirects=False ) if resp.status_code != _StatusCodes.OK: - raise Exception('Failed to request access token with response: [{}] {}'.format( - resp.status_code, resp.content)) + self._expired = True + # In the absence of a successful response, assume the refresh token is expired. This should indicate + # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. + return self._initialize_credentials(resp) @property @@ -249,3 +277,10 @@ def credentials(self): :return flytekit.clis.auth.auth.Credentials: """ return self._credentials + + @property + def expired(self): + """ + :return bool: + """ + return self._expired diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 7b945987fd..f2d3744a8b 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -1,5 +1,4 @@ from __future__ import absolute_import - from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient @@ -28,11 +27,9 @@ def _get_discovery_endpoint(): def get_client(): global _authorization_client - if _authorization_client is not None: + if _authorization_client is not None and not _authorization_client.expired: return _authorization_client - discovery_endpoint = _get_discovery_endpoint() - discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) - authorization_endpoints = discovery_client.get_authorization_endpoints() + authorization_endpoints = get_authorization_endpoints() _authorization_client =\ _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(), From d200b55d0f6815057ff976fe63a74fc3b27a6042 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 5 Dec 2019 17:20:18 -0800 Subject: [PATCH 37/40] add expiration handling --- flytekit/clients/raw.py | 1 - flytekit/clis/auth/auth.py | 20 +++++++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index e58b4a6e88..88c79f8157 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -272,7 +272,6 @@ def list_workflow_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ - _logging.warn("hi katrina, metadata is {}".format(self._metadata)) return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 005371d180..33f201bac1 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -37,7 +37,6 @@ # Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs # suggest, we are storing a user's oidc. _keyring_access_token_storage_key = "access_token" -_keyring_id_token_storage_key = "id_token" _keyring_refresh_token_storage_key = "refresh_token" @@ -131,18 +130,13 @@ def handle_authorization_code(self, auth_code): class Credentials(object): - def __init__(self, access_token=None, id_token=None): + def __init__(self, access_token=None): self._access_token = access_token - self._id_token = id_token @property def access_token(self): return self._access_token - @property - def id_token(self): - return self._id_token - class AuthorizationClient(object): def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None): @@ -174,9 +168,8 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi # Prefer to use already-fetched token values when they've been set globally. self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) - id_token = _keyring.get_password(_keyring_service_name, _keyring_id_token_storage_key) - if access_token and id_token: - self._credentials = Credentials(access_token=access_token, id_token=id_token) + if access_token: + self._credentials = Credentials(access_token=access_token) return # In the absence of globally-set token values, initiate the token request flow @@ -223,13 +216,11 @@ def _initialize_credentials(self, auth_token_resp): self._refresh_token = response_body["refresh_token"] access_token = response_body["access_token"] - id_token = response_body["id_token"] refresh_token = response_body["refresh_token"] _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) - _keyring.set_password(_keyring_service_name, _keyring_id_token_storage_key, id_token) _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) - self._credentials = Credentials(access_token=access_token, id_token=id_token) + self._credentials = Credentials(access_token=access_token) def request_access_token(self, auth_code): if self._state != auth_code.state: @@ -268,6 +259,9 @@ def refresh_access_token(self): self._expired = True # In the absence of a successful response, assume the refresh token is expired. This should indicate # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. + + _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key) + _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key) return self._initialize_credentials(resp) From feb403eb6db15d34d903d2dde3138079617ebaeb Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Thu, 5 Dec 2019 17:21:22 -0800 Subject: [PATCH 38/40] bump flytekit version --- flytekit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index b1dbc2b8aa..0f05c59553 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -1,4 +1,4 @@ from __future__ import absolute_import import flytekit.plugins -__version__ = '0.3.1' +__version__ = '0.4.0b0' From fc56ce6405abc051aa52044bada67eb20a4f632f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 6 Dec 2019 09:10:47 -0800 Subject: [PATCH 39/40] no need to force twice, now that it's been added to the base client --- flytekit/clis/sdk_in_container/basic_auth.py | 9 +++++---- flytekit/engines/flyte/engine.py | 3 --- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py index f3211b97d5..05671752e0 100644 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -39,9 +39,9 @@ def get_basic_authorization_header(client_id, client_secret): def get_token(token_endpoint, authorization_header, scope): """ - :param token_endpoint: - :param authorization_header: - :param scope: + :param Text token_endpoint: + :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') + :param Text scope: :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration in seconds """ @@ -53,8 +53,9 @@ def get_token(token_endpoint, authorization_header, scope): } body = { 'grant_type': 'client_credentials', - 'scope': scope, } + if scope is not None: + body['scope'] = scope response = _requests.post(token_endpoint, data=body, headers=headers) if response.status_code != 200: _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index a326274b6a..bfe76bf24a 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -33,9 +33,6 @@ def __init__(self, *args, **kwargs): # TODO: use cases. if type(self)._CLIENT is None: c = _SynchronousFlyteClient(*args, **kwargs) - if _platform_config.AUTH.get(): - # Force authentication - c.force_auth_flow() type(self)._CLIENT = c @property From bd9bbb6d1f81df9847b95a30f0e0623944c37582 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Fri, 6 Dec 2019 11:39:23 -0800 Subject: [PATCH 40/40] use non-beta release --- flytekit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 0f05c59553..4919d93749 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -1,4 +1,4 @@ from __future__ import absolute_import import flytekit.plugins -__version__ = '0.4.0b0' +__version__ = '0.4.0'