diff --git a/flytekit/__init__.py b/flytekit/__init__.py index b1dbc2b8aa..4919d93749 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -1,4 +1,4 @@ from __future__ import absolute_import import flytekit.plugins -__version__ = '0.3.1' +__version__ = '0.4.0' diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 783d801df0..10640b6d74 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,4 +1,7 @@ +from flytekit.clis.auth import credentials as _credentials_access + + def iterate_node_executions( client, @@ -75,3 +78,4 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte if not next_token: break token = next_token + diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index fc4cc82551..88c79f8157 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,15 +1,90 @@ from __future__ import absolute_import + from grpc import insecure_channel as _insecure_channel, secure_channel as _secure_channel, RpcError as _RpcError, \ StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials from flyteidl.service import admin_pb2_grpc as _admin_service from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration.platform import AUTH as _AUTH +from flytekit.configuration.creds import ( + CLIENT_ID as _CLIENT_ID, + CLIENT_CREDENTIALS_SCOPE as _SCOPE, +) +from flytekit.clis.sdk_in_container import basic_auth as _basic_auth +import logging as _logging import six as _six +from flytekit.configuration import creds as _creds_config, platform as _platform_config + +from flytekit.clis.auth import credentials as _credentials_access + + +def _refresh_credentials_standard(flyte_client): + """ + This function is used when the configuration value for AUTH_MODE is set to 'standard'. + This either fetches the existing access token or initiates the flow to request a valid access token and store it. + :param flyte_client: RawSynchronousFlyteClient + :return: + """ + + _credentials_access.get_client().refresh_access_token() + flyte_client.set_access_token(_credentials_access.get_client().credentials.access_token) + + +def _refresh_credentials_basic(flyte_client): + """ + This function is used by the _handle_rpc_error decorator, depending on the AUTH_MODE config object. This handler + is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, + like when waiting for another workflow to complete from within a task). This function uses basic auth, which means + the credentials for basic auth must be present from wherever this code is running. + + :param flyte_client: RawSynchronousFlyteClient + :return: + """ + auth_endpoints = _credentials_access.get_authorization_endpoints() + token_endpoint = auth_endpoints.token_endpoint + client_secret = _basic_auth.get_secret() + _logging.debug('Basic authorization flow with client id {} scope {}', _CLIENT_ID.get(), _SCOPE.get()) + authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) + token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _SCOPE.get()) + _logging.info('Retrieved new token, expires in {}'.format(expires_in)) + flyte_client.set_access_token(token) + + +def _refresh_credentials_noop(flyte_client): + pass + + +def _get_refresh_handler(auth_mode): + if auth_mode == "standard": + return _refresh_credentials_standard + elif auth_mode == "basic": + return _refresh_credentials_basic + else: + raise ValueError( + "Invalid auth mode [{}] specified. Please update the creds config to use a valid value".format(auth_mode)) def _handle_rpc_error(fn): def handler(*args, **kwargs): + """ + Wraps rpc errors as Flyte exceptions and handles authentication the client. + :param args: + :param kwargs: + :return: + """ + retries = 2 try: - return fn(*args, **kwargs) + for i in range(retries): + try: + return fn(*args, **kwargs) + except _RpcError as e: + if e.code() == _GrpcStatusCode.UNAUTHENTICATED: + if i == (retries - 1): + # Exit the loop and wrap the authentication error. + raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e)) + refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) + refresh_handler_fn(args[0]) + else: + raise except _RpcError as e: if e.code() == _GrpcStatusCode.ALREADY_EXISTS: raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) @@ -35,10 +110,11 @@ def __init__(self, url, insecure=False, credentials=None, options=None): :param Text credentials: [Optional] If provided, a secure channel will be opened with the Flyte Admin Service. :param dict[Text, Text] options: [Optional] A dict of key-value string pairs for configuring the gRPC core runtime. + :param list[(Text,Text)] metadata: [Optional] metadata pairs to be transmitted to the + service-side of the RPC. """ self._channel = None - # TODO: Revert all the for loops below if insecure: self._channel = _insecure_channel(url, options=list((options or {}).items())) else: @@ -48,6 +124,16 @@ def __init__(self, url, insecure=False, credentials=None, options=None): options=list((options or {}).items()) ) self._stub = _admin_service.AdminServiceStub(self._channel) + self._metadata = None + if _AUTH.get(): + self.force_auth_flow() + + def set_access_token(self, access_token): + self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get(), "Bearer {}".format(access_token))] + + def force_auth_flow(self): + refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) + refresh_handler_fn(self) #################################################################################################################### # @@ -74,7 +160,7 @@ def create_task(self, task_create_request): task is already registered. :raises grpc.RpcError: """ - return self._stub.CreateTask(task_create_request) + return self._stub.CreateTask(task_create_request, metadata=self._metadata) @_handle_rpc_error def list_task_ids_paginated(self, identifier_list_request): @@ -100,7 +186,7 @@ def list_task_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ - return self._stub.ListTaskIds(identifier_list_request) + return self._stub.ListTaskIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_tasks_paginated(self, resource_list_request): @@ -122,7 +208,7 @@ def list_tasks_paginated(self, resource_list_request): :rtype: flyteidl.admin.task_pb2.TaskList :raises: TODO """ - return self._stub.ListTasks(resource_list_request) + return self._stub.ListTasks(resource_list_request, metadata=self._metadata) @_handle_rpc_error def get_task(self, get_object_request): @@ -133,7 +219,7 @@ def get_task(self, get_object_request): :rtype: flyteidl.admin.task_pb2.Task :raises: TODO """ - return self._stub.GetTask(get_object_request) + return self._stub.GetTask(get_object_request, metadata=self._metadata) #################################################################################################################### # @@ -160,7 +246,7 @@ def create_workflow(self, workflow_create_request): identical workflow is already registered. :raises grpc.RpcError: """ - return self._stub.CreateWorkflow(workflow_create_request) + return self._stub.CreateWorkflow(workflow_create_request, metadata=self._metadata) @_handle_rpc_error def list_workflow_ids_paginated(self, identifier_list_request): @@ -186,7 +272,7 @@ def list_workflow_ids_paginated(self, identifier_list_request): :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList :raises: TODO """ - return self._stub.ListWorkflowIds(identifier_list_request) + return self._stub.ListWorkflowIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_workflows_paginated(self, resource_list_request): @@ -208,7 +294,7 @@ def list_workflows_paginated(self, resource_list_request): :rtype: flyteidl.admin.workflow_pb2.WorkflowList :raises: TODO """ - return self._stub.ListWorkflows(resource_list_request) + return self._stub.ListWorkflows(resource_list_request, metadata=self._metadata) @_handle_rpc_error def get_workflow(self, get_object_request): @@ -219,7 +305,7 @@ def get_workflow(self, get_object_request): :rtype: flyteidl.admin.workflow_pb2.Workflow :raises: TODO """ - return self._stub.GetWorkflow(get_object_request) + return self._stub.GetWorkflow(get_object_request, metadata=self._metadata) #################################################################################################################### # @@ -247,7 +333,7 @@ def create_launch_plan(self, launch_plan_create_request): the identical launch plan is already registered. :raises grpc.RpcError: """ - return self._stub.CreateLaunchPlan(launch_plan_create_request) + return self._stub.CreateLaunchPlan(launch_plan_create_request, metadata=self._metadata) # TODO: List endpoints when they come in @@ -259,7 +345,7 @@ def get_launch_plan(self, object_get_request): :param flyteidl.admin.common_pb2.ObjectGetRequest object_get_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan """ - return self._stub.GetLaunchPlan(object_get_request) + return self._stub.GetLaunchPlan(object_get_request, metadata=self._metadata) @_handle_rpc_error def get_active_launch_plan(self, active_launch_plan_request): @@ -269,7 +355,7 @@ def get_active_launch_plan(self, active_launch_plan_request): :param flyteidl.admin.common_pb2.ActiveLaunchPlanRequest active_launch_plan_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan """ - return self._stub.GetActiveLaunchPlan(active_launch_plan_request) + return self._stub.GetActiveLaunchPlan(active_launch_plan_request, metadata=self._metadata) @_handle_rpc_error def update_launch_plan(self, update_request): @@ -280,7 +366,7 @@ def update_launch_plan(self, update_request): :param flyteidl.admin.launch_plan_pb2.LaunchPlanUpdateRequest update_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanUpdateResponse """ - return self._stub.UpdateLaunchPlan(update_request) + return self._stub.UpdateLaunchPlan(update_request, metadata=self._metadata) @_handle_rpc_error def list_launch_plan_ids_paginated(self, identifier_list_request): @@ -290,7 +376,7 @@ def list_launch_plan_ids_paginated(self, identifier_list_request): :param: flyteidl.admin.common_pb2.NamedEntityIdentifierListRequest identifier_list_request: :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifierList """ - return self._stub.ListLaunchPlanIds(identifier_list_request) + return self._stub.ListLaunchPlanIds(identifier_list_request, metadata=self._metadata) @_handle_rpc_error def list_launch_plans_paginated(self, resource_list_request): @@ -300,7 +386,7 @@ def list_launch_plans_paginated(self, resource_list_request): :param: flyteidl.admin.common_pb2.ResourceListRequest resource_list_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanList """ - return self._stub.ListLaunchPlans(resource_list_request) + return self._stub.ListLaunchPlans(resource_list_request, metadata=self._metadata) @_handle_rpc_error def list_active_launch_plans_paginated(self, active_launch_plan_list_request): @@ -310,7 +396,7 @@ def list_active_launch_plans_paginated(self, active_launch_plan_list_request): :param: flyteidl.admin.common_pb2.ActiveLaunchPlanListRequest active_launch_plan_list_request: :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanList """ - return self._stub.ListActiveLaunchPlans(active_launch_plan_list_request) + return self._stub.ListActiveLaunchPlans(active_launch_plan_list_request, metadata=self._metadata) #################################################################################################################### # @@ -325,7 +411,7 @@ def create_execution(self, create_execution_request): :param flyteidl.admin.execution_pb2.ExecutionCreateRequest create_execution_request: :rtype: flyteidl.admin.execution_pb2.ExecutionCreateResponse """ - return self._stub.CreateExecution(create_execution_request) + return self._stub.CreateExecution(create_execution_request, metadata=self._metadata) @_handle_rpc_error def get_execution(self, get_object_request): @@ -335,7 +421,7 @@ def get_execution(self, get_object_request): :param flyteidl.admin.execution_pb2.WorkflowExecutionGetRequest get_object_request: :rtype: flyteidl.admin.execution_pb2.Execution """ - return self._stub.GetExecution(get_object_request) + return self._stub.GetExecution(get_object_request, metadata=self._metadata) @_handle_rpc_error def get_execution_data(self, get_execution_data_request): @@ -355,7 +441,7 @@ def list_executions_paginated(self, resource_list_request): :param flyteidl.admin.common_pb2.ResourceListRequest resource_list_request: :rtype: flyteidl.admin.execution_pb2.ExecutionList """ - return self._stub.ListExecutions(resource_list_request) + return self._stub.ListExecutions(resource_list_request, metadata=self._metadata) @_handle_rpc_error def terminate_execution(self, terminate_execution_request): @@ -363,7 +449,7 @@ def terminate_execution(self, terminate_execution_request): :param flyteidl.admin.execution_pb2.TerminateExecutionRequest terminate_execution_request: :rtype: flyteidl.admin.execution_pb2.TerminateExecutionResponse """ - return self._stub.TerminateExecution(terminate_execution_request) + return self._stub.TerminateExecution(terminate_execution_request, metadata=self._metadata) @_handle_rpc_error def relaunch_execution(self, relaunch_execution_request): @@ -385,7 +471,7 @@ def get_node_execution(self, node_execution_request): :param flyteidl.admin.node_execution_pb2.NodeExecutionGetRequest node_execution_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecution """ - return self._stub.GetNodeExecution(node_execution_request) + return self._stub.GetNodeExecution(node_execution_request, metadata=self._metadata) @_handle_rpc_error def get_node_execution_data(self, get_node_execution_data_request): @@ -403,7 +489,7 @@ def list_node_executions_paginated(self, node_execution_list_request): :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_list_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionList """ - return self._stub.ListNodeExecutions(node_execution_list_request) + return self._stub.ListNodeExecutions(node_execution_list_request, metadata=self._metadata) @_handle_rpc_error def list_node_executions_for_task_paginated(self, node_execution_for_task_list_request): @@ -411,7 +497,7 @@ def list_node_executions_for_task_paginated(self, node_execution_for_task_list_r :param flyteidl.admin.node_execution_pb2.NodeExecutionListRequest node_execution_for_task_list_request: :rtype: flyteidl.admin.node_execution_pb2.NodeExecutionList """ - return self._stub.ListNodeExecutionsForTask(node_execution_for_task_list_request) + return self._stub.ListNodeExecutionsForTask(node_execution_for_task_list_request, metadata=self._metadata) #################################################################################################################### # @@ -425,7 +511,7 @@ def get_task_execution(self, task_execution_request): :param flyteidl.admin.task_execution_pb2.TaskExecutionGetRequest task_execution_request: :rtype: flyteidl.admin.task_execution_pb2.TaskExecution """ - return self._stub.GetTaskExecution(task_execution_request) + return self._stub.GetTaskExecution(task_execution_request, metadata=self._metadata) @_handle_rpc_error def get_task_execution_data(self, get_task_execution_data_request): @@ -443,7 +529,7 @@ def list_task_executions_paginated(self, task_execution_list_request): :param flyteidl.admin.task_execution_pb2.TaskExecutionListRequest task_execution_list_request: :rtype: flyteidl.admin.task_execution_pb2.TaskExecutionList """ - return self._stub.ListTaskExecutions(task_execution_list_request) + return self._stub.ListTaskExecutions(task_execution_list_request, metadata=self._metadata) #################################################################################################################### # @@ -458,7 +544,7 @@ def list_projects(self, project_list_request): :param flyteidl.admin.project_pb2.ProjectListRequest project_list_request: :rtype: flyteidl.admin.project_pb2.Projects """ - return self._stub.ListProjects(project_list_request) + return self._stub.ListProjects(project_list_request, metadata=self._metadata) @_handle_rpc_error def register_project(self, project_register_request): @@ -467,7 +553,7 @@ def register_project(self, project_register_request): :param flyteidl.admin.project_pb2.ProjectRegisterRequest project_register_request: :rtype: flyteidl.admin.project_pb2.ProjectRegisterResponse """ - return self._stub.RegisterProject(project_register_request) + return self._stub.RegisterProject(project_register_request, metadata=self._metadata) #################################################################################################################### # diff --git a/flytekit/clis/auth/__init__.py b/flytekit/clis/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py new file mode 100644 index 0000000000..33f201bac1 --- /dev/null +++ b/flytekit/clis/auth/auth.py @@ -0,0 +1,280 @@ +import base64 as _base64 +import hashlib as _hashlib +import keyring as _keyring +import os as _os +import re as _re +import requests as _requests +import webbrowser as _webbrowser + +from multiprocessing import Process as _Process, Queue as _Queue + +try: # Python 3.5+ + from http import HTTPStatus as _StatusCodes +except ImportError: + try: # Python 3 + from http import client as _StatusCodes + except ImportError: # Python 2 + import httplib as _StatusCodes +try: # Python 3 + import http.server as _BaseHTTPServer +except ImportError: # Python 2 + import BaseHTTPServer as _BaseHTTPServer + +try: # Python 3 + import urllib.parse as _urlparse + from urllib.parse import urlencode as _urlencode +except ImportError: # Python 2 + import urlparse as _urlparse + from urllib import urlencode as _urlencode + +_code_verifier_length = 64 +_random_seed_length = 40 +_utf_8 = 'utf-8' + + +# Identifies the service used for storing passwords in keyring +_keyring_service_name = "flyteauth" +# Identifies the key used for storing and fetching from keyring. In our case, instead of a username as the keyring docs +# suggest, we are storing a user's oidc. +_keyring_access_token_storage_key = "access_token" +_keyring_refresh_token_storage_key = "refresh_token" + + +def _generate_code_verifier(): + """ + Generates a 'code_verifier' as described in https://tools.ietf.org/html/rfc7636#section-4.1 + Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. + :return str: + """ + code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8) + # Eliminate invalid characters. + code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier) + if len(code_verifier) < 43: + raise ValueError("Verifier too short. number of bytes must be > 30.") + elif len(code_verifier) > 128: + raise ValueError("Verifier too long. number of bytes must be < 97.") + return code_verifier + + +def _generate_state_parameter(): + state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length)).decode(_utf_8) + # Eliminate invalid characters. + code_verifier = _re.sub('[^a-zA-Z0-9-_.,]+', '', state) + return code_verifier + + +def _create_code_challenge(code_verifier): + """ + Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py. + :param str code_verifier: represents a code verifier generated by generate_code_verifier() + :return str: urlsafe base64-encoded sha256 hash digest + """ + code_challenge = _hashlib.sha256(code_verifier.encode(_utf_8)).digest() + code_challenge = _base64.urlsafe_b64encode(code_challenge).decode(_utf_8) + # Eliminate invalid characters + code_challenge = code_challenge.replace('=', '') + return code_challenge + + +class AuthorizationCode(object): + def __init__(self, code, state): + self._code = code + self._state = state + + @property + def code(self): + return self._code + + @property + def state(self): + return self._state + + +class OAuthCallbackHandler(_BaseHTTPServer.BaseHTTPRequestHandler): + """ + A simple wrapper around BaseHTTPServer.BaseHTTPRequestHandler that handles a callback URL that accepts an + authorization token. + """ + + def do_GET(self): + url = _urlparse.urlparse(self.path) + if url.path == self.server.redirect_path: + self.send_response(_StatusCodes.OK) + self.end_headers() + self.handle_login(dict(_urlparse.parse_qsl(url.query))) + else: + self.send_response(_StatusCodes.NOT_FOUND) + + def handle_login(self, data): + self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) + + +class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): + """ + A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling + authorization code callbacks. + """ + def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, + redirect_path=None, queue=None): + _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) + self._redirect_path = redirect_path + self._auth_code = None + self._queue = queue + + @property + def redirect_path(self): + return self._redirect_path + + def handle_authorization_code(self, auth_code): + self._queue.put(auth_code) + + +class Credentials(object): + def __init__(self, access_token=None): + self._access_token = access_token + + @property + def access_token(self): + return self._access_token + + +class AuthorizationClient(object): + def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redirect_uri=None): + self._auth_endpoint = auth_endpoint + self._token_endpoint = token_endpoint + self._client_id = client_id + self._redirect_uri = redirect_uri + self._code_verifier = _generate_code_verifier() + code_challenge = _create_code_challenge(self._code_verifier) + self._code_challenge = code_challenge + state = _generate_state_parameter() + self._state = state + self._credentials = None + self._refresh_token = None + self._headers = {'content-type': "application/x-www-form-urlencoded"} + self._expired = False + + self._params = { + "client_id": client_id, # This must match the Client ID of the OAuth application. + "response_type": "code", # Indicates the authorization code grant + "scope": "openid offline_access", # ensures that the /token endpoint returns an ID and refresh token + # callback location where the user-agent will be directed to. + "redirect_uri": self._redirect_uri, + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + # Prefer to use already-fetched token values when they've been set globally. + self._refresh_token = _keyring.get_password(_keyring_service_name, _keyring_refresh_token_storage_key) + access_token = _keyring.get_password(_keyring_service_name, _keyring_access_token_storage_key) + if access_token: + self._credentials = Credentials(access_token=access_token) + return + + # In the absence of globally-set token values, initiate the token request flow + q = _Queue() + # First prepare the callback server in the background + server = self._create_callback_server(q) + server_process = _Process(target=server.handle_request) + server_process.start() + + # Send the call to request the authorization code + self._request_authorization_code() + + # Request the access token once the auth code has been received. + auth_code = q.get() + server_process.terminate() + self.request_access_token(auth_code) + + def _create_callback_server(self, q): + server_url = _urlparse.urlparse(self._redirect_uri) + server_address = (server_url.hostname, server_url.port) + return OAuthHTTPServer(server_address, OAuthCallbackHandler, redirect_path=server_url.path, queue=q) + + def _request_authorization_code(self): + scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) + query = _urlencode(self._params) + endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) + _webbrowser.open_new_tab(endpoint) + + def _initialize_credentials(self, auth_token_resp): + + """ + The auth_token_resp body is of the form: + { + "access_token": "foo", + "refresh_token": "bar", + "id_token": "baz", + "token_type": "Bearer" + } + """ + response_body = auth_token_resp.json() + if "access_token" not in response_body: + raise ValueError('Expected "access_token" in response from oauth server') + if "refresh_token" in response_body: + self._refresh_token = response_body["refresh_token"] + + access_token = response_body["access_token"] + refresh_token = response_body["refresh_token"] + + _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) + _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) + self._credentials = Credentials(access_token=access_token) + + def request_access_token(self, auth_code): + if self._state != auth_code.state: + raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state)) + self._params.update({ + "code": auth_code.code, + "code_verifier": self._code_verifier, + "grant_type": "authorization_code", + }) + resp = _requests.post( + url=self._token_endpoint, + data=self._params, + headers=self._headers, + allow_redirects=False + ) + if resp.status_code != _StatusCodes.OK: + # TODO: handle expected (?) error cases: + # https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses + raise Exception('Failed to request access token with response: [{}] {}'.format( + resp.status_code, resp.content)) + self._initialize_credentials(resp) + + def refresh_access_token(self): + if self._refresh_token is None: + raise ValueError("no refresh token available with which to refresh authorization credentials") + + resp = _requests.post( + url=self._token_endpoint, + data={'grant_type': 'refresh_token', + 'client_id': self._client_id, + 'refresh_token': self._refresh_token}, + headers=self._headers, + allow_redirects=False + ) + if resp.status_code != _StatusCodes.OK: + self._expired = True + # In the absence of a successful response, assume the refresh token is expired. This should indicate + # to the caller that the AuthorizationClient is defunct and a new one needs to be re-initialized. + + _keyring.delete_password(_keyring_service_name, _keyring_access_token_storage_key) + _keyring.delete_password(_keyring_service_name, _keyring_refresh_token_storage_key) + return + self._initialize_credentials(resp) + + @property + def credentials(self): + """ + :return flytekit.clis.auth.auth.Credentials: + """ + return self._credentials + + @property + def expired(self): + """ + :return bool: + """ + return self._expired diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py new file mode 100644 index 0000000000..8be3d91c77 --- /dev/null +++ b/flytekit/clis/auth/credentials.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import +from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient +from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient + +from flytekit.configuration.creds import ( + REDIRECT_URI as _REDIRECT_URI, + CLIENT_ID as _CLIENT_ID +) +from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE + +try: # Python 3 + import urllib.parse as _urlparse +except ImportError: # Python 2 + import urlparse as _urlparse + +# Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. +discovery_endpoint_path = ".well-known/oauth-authorization-server" + + +def _get_discovery_endpoint(): + if _INSECURE.get(): + return _urlparse.urljoin('http://{}/'.format(_URL.get()), discovery_endpoint_path) + return _urlparse.urljoin('https://{}/'.format(_URL.get()), discovery_endpoint_path) + + +# Lazy initialized authorization client singleton +_authorization_client = None + + +def get_client(): + global _authorization_client + if _authorization_client is not None and not _authorization_client.expired: + return _authorization_client + authorization_endpoints = get_authorization_endpoints() + + _authorization_client =\ + _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(), + auth_endpoint=authorization_endpoints.auth_endpoint, + token_endpoint=authorization_endpoints.token_endpoint) + return _authorization_client + + +def get_authorization_endpoints(): + discovery_endpoint = _get_discovery_endpoint() + discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) + return discovery_client.get_authorization_endpoints() diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py new file mode 100644 index 0000000000..fce6988da2 --- /dev/null +++ b/flytekit/clis/auth/discovery.py @@ -0,0 +1,68 @@ +import requests as _requests +import logging + +try: # Python 3.5+ + from http import HTTPStatus as _StatusCodes +except ImportError: + try: # Python 3 + from http import client as _StatusCodes + except ImportError: # Python 2 + import httplib as _StatusCodes + +# These response keys are defined in https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. +_authorization_endpoint_key = "authorization_endpoint" +_token_endpoint_key = "token_endpoint" + + +class AuthorizationEndpoints(object): + """ + A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. + """ + def __init__(self, auth_endpoint=None, token_endpoint=None): + self._auth_endpoint = auth_endpoint + self._token_endpoint = token_endpoint + + @property + def auth_endpoint(self): + return self._auth_endpoint + + @property + def token_endpoint(self): + return self._token_endpoint + + +class DiscoveryClient(object): + """ + Discovers well known OpenID configuration and parses out authorization endpoints required for initiating the PKCE + auth flow. + """ + + def __init__(self, discovery_url=None): + logging.debug("Initializing discovery client with {}".format(discovery_url)) + self._discovery_url = discovery_url + self._authorization_endpoints = None + + @property + def authorization_endpoints(self): + """ + :rtype: flytekit.clis.auth.discovery.AuthorizationEndpoints: + """ + return self._authorization_endpoints + + def get_authorization_endpoints(self): + if self.authorization_endpoints is not None: + return self.authorization_endpoints + resp = _requests.get( + url=self._discovery_url, + ) + + response_body = resp.json() + if response_body[_authorization_endpoint_key] is None: + raise ValueError('Unable to discover authorization endpoint') + + if response_body[_token_endpoint_key] is None: + raise ValueError('Unable to discover token endpoint') + + self._authorization_endpoints = AuthorizationEndpoints(auth_endpoint=response_body[_authorization_endpoint_key], + token_endpoint=response_body[_token_endpoint_key]) + return self.authorization_endpoints diff --git a/flytekit/clis/flyte_cli/example.config b/flytekit/clis/flyte_cli/example.config new file mode 100644 index 0000000000..cd0bc07223 --- /dev/null +++ b/flytekit/clis/flyte_cli/example.config @@ -0,0 +1,11 @@ +# This is an example of what a config file may look like. +# Place this in ~/.flyte/config and flyte-cli will pick it up automatically + +[platform] +url=flyte.company.com +auth=true + +[credentials] +discovery_endpoint=http://corp.idp.com/.well-known/oauth-authorization-server +client_id=123abc123 +redirect_uri=http://localhost:53593/callback diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index cbe796f0a4..2fbc9bf660 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -7,6 +7,7 @@ import click as _click import six as _six + from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit import __version__ @@ -19,6 +20,7 @@ from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file from flytekit.configuration import platform as _platform_config +from flytekit.configuration import set_flyte_config_file from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import common as _common_models, filters as _filters, launch_plan as _launch_plan, literals as \ _literals @@ -30,11 +32,30 @@ _tt = _six.text_type +# Similar to how kubectl has a config file in the users home directory, this Flyte CLI will also look for one. +# The format of this config file is the same as a workflow's config file, except that the relevant fields are different. +# Please see the example.config file +_default_config_file_path = ".flyte/config" + def _welcome_message(): _click.secho("Welcome to Flyte CLI! Version: {}".format(_tt(__version__)), bold=True) +def _detect_default_config_file(): + home = _os.path.expanduser("~") + config_file = _os.path.join(home, _default_config_file_path) + if home and _os.path.exists(config_file): + _click.secho("Using default config file at {}".format(_tt(config_file)), fg='blue') + set_flyte_config_file(config_file_path=config_file) + else: + _click.secho("Config file not found at default location, relying on environment variables instead", fg='blue') + + +# Run this as the module is loading to pick up settings that click can then use when constructing the commands +_detect_default_config_file() + + def _get_io_string(literal_map, verbose=False): """ :param flytekit.models.literals.LiteralMap literal_map: @@ -408,7 +429,6 @@ class _FlyteSubCommand(_click.Command): 'project': _PROJECT_FLAGS[0], 'domain': _DOMAIN_FLAGS[0], 'name': _NAME_FLAGS[0], - 'host': _HOST_FLAGS[0] } _PASSABLE_FLAGS = { @@ -423,6 +443,16 @@ def make_context(self, cmd_name, args, parent=None): parent.params[param.name] is not None: prefix_args.extend([type(self)._PASSABLE_ARGS[param.name], _six.text_type(parent.params[param.name])]) + # Special handling for the host option, because this option can be specified in the user's ~/.flyte/config + # file, and unlike other options, doesn't get picked up before click commands are run + if param.name == 'host': + if param.name in parent.params and parent.params[param.name] is not None: + prefix_args.extend( + [_HOST_FLAGS[0], _six.text_type(parent.params[param.name])]) + elif _platform_config.URL.get(): + prefix_args.extend( + [_HOST_FLAGS[0], _six.text_type(_platform_config.URL.get())]) + # For flags, we don't append the value of the flag, otherwise click will fail to parse if param.name in type(self)._PASSABLE_FLAGS and \ param.name in parent.params and \ diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py new file mode 100644 index 0000000000..05671752e0 --- /dev/null +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import + +import base64 as _base64 +import logging as _logging + +import requests as _requests + +from flytekit.common.exceptions.user import FlyteAuthenticationException as _FlyteAuthenticationException +from flytekit.configuration.creds import ( + CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, +) + +_utf_8 = 'utf-8' + + +def get_secret(): + """ + This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION + config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. + :rtype: Text + """ + secret = _CREDENTIALS_SECRET.get() + if secret: + return secret + raise _FlyteAuthenticationException('No secret could be found') + + +def get_basic_authorization_header(client_id, client_secret): + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. + :param Text client_id: + :param Text client_secret: + :rtype: Text + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8)) + + +def get_token(token_endpoint, authorization_header, scope): + """ + :param Text token_endpoint: + :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') + :param Text scope: + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + 'Authorization': authorization_header, + 'Cache-Control': 'no-cache', + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded' + } + body = { + 'grant_type': 'client_credentials', + } + if scope is not None: + body['scope'] = scope + response = _requests.post(token_endpoint, data=body, headers=headers) + if response.status_code != 200: + _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) + raise _FlyteAuthenticationException('Non-200 received from IDP') + + response = response.json() + return response['access_token'], response['expires_in'] diff --git a/flytekit/common/exceptions/user.py b/flytekit/common/exceptions/user.py index d050c05667..d68cb1985f 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/common/exceptions/user.py @@ -72,3 +72,7 @@ class FlyteTimeout(FlyteAssertion): class FlyteRecoverableException(FlyteUserException, _Recoverable): _ERROR_CODE = "USER:Recoverable" + + +class FlyteAuthenticationException(FlyteAssertion): + _ERROR_CODE = "USER:AuthenticationError" diff --git a/flytekit/configuration/common.py b/flytekit/configuration/common.py index 8ef86286d1..54d9ca7ee8 100644 --- a/flytekit/configuration/common.py +++ b/flytekit/configuration/common.py @@ -118,6 +118,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): del _os.environ[self._config.env_var] +def _get_file_contents(location): + """ + This reads an input file, and returns the string contents, and should be used for reading credentials. + This function will also strip newlines. + + :param Text location: The file path holding the client id or secret + :rtype: Text + """ + if _os.path.isfile(location): + with open(location, 'r') as f: + return f.read().replace('\n', '') + return None + + class _FlyteConfigurationEntry(_six.with_metaclass(_abc.ABCMeta, object)): def __init__(self, section, key, default=None, validator=None, fallback=None): @@ -138,6 +152,37 @@ def env_var(self): def _getter(self): pass + def retrieve_value(self): + """ + The logic in this function changes the lookup behavior for all configuration objects before hitting the + configuration file. + + For a given configuration object ('mysection', 'mysetting'), it will now look at this waterfall: + + i.) The environment variable named 'FLYTE_MYSECTION_MYSETTING' + + ii.) The value of the environment variable that is named the value of the environment variable named + 'FLYTE_MYSECTION_MYSETTING'. That is if os.environ['FLYTE_MYSECTION_MYSETTING'] = 'AAA' and + os.environ['AA'] = 'abc', then 'abc' will be the final value. + + iii.) The contents of the file pointed to by the environment variable named 'FLYTE_MYSECTION_MYSETTING', + assuming the value is a file. + + While it is helpful, this pattern does interrupt the manually specified fallback logic, by effective injecting + two more fallbacks behind the scenes. Just keep this in mind as you are using/creating configuration objects. + :rtype: Text + """ + val = _os.environ.get(self.env_var, None) + if val is None: + referenced_env_var = _os.environ.get("{}_FROM_ENV_VAR".format(self.env_var), None) + if referenced_env_var is not None: + val = _os.environ.get(referenced_env_var, None) + if val is None: + referenced_file = _os.environ.get("{}_FROM_FILE".format(self.env_var), None) + if referenced_file is not None: + val = _get_file_contents(referenced_file) + return val + def get(self): val = self._getter() if val is None and self._fallback is not None: @@ -178,7 +223,7 @@ def _validate_not_null(self, val): class FlyteStringConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) return val @@ -186,7 +231,7 @@ def _getter(self): class FlyteIntegerConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) if val is not None: @@ -196,7 +241,7 @@ def _getter(self): class FlyteBoolConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: return CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) @@ -209,7 +254,7 @@ def _getter(self): class FlyteStringListConfigurationEntry(_FlyteConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: @@ -219,7 +264,7 @@ def _getter(self): class FlyteRequiredStringConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key, default=self._default) return val @@ -227,7 +272,7 @@ def _getter(self): class FlyteRequiredIntegerConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_int(self._section, self._key, default=self._default) return int(val) @@ -235,7 +280,7 @@ def _getter(self): class FlyteRequiredBoolConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) return bool(val) @@ -243,7 +288,7 @@ def _getter(self): class FlyteRequiredStringListConfigurationEntry(_FlyteRequiredConfigurationEntry): def _getter(self): - val = _os.environ.get(self.env_var, None) + val = self.retrieve_value() if val is None: val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py new file mode 100644 index 0000000000..2bc4bd7e74 --- /dev/null +++ b/flytekit/configuration/creds.py @@ -0,0 +1,53 @@ +from __future__ import absolute_import + +from flytekit.configuration import common as _config_common + +CLIENT_ID = _config_common.FlyteStringConfigurationEntry('credentials', 'client_id', default=None) +""" +This is the public identifier for the app which handles authorization for a Flyte deployment. +More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. +""" + +REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', + default="http://localhost:12345/callback") +""" +This is the callback uri registered with the app which handles authorization for a Flyte deployment. +Please note the hardcoded port number. Ideally we would not do this, but some IDPs do not allow wildcards for +the URL, which means we have to use the same port every time. This is the only reason this is a configuration option, +otherwise, we'd just hardcode the callback path as a constant. +FYI, to see if a given port is already in use, run `sudo lsof -i :` if on a Linux system. +More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. +""" + +AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry('credentials', 'authorization_metadata_key', + default="authorization") +""" +The authorization metadata key used for passing access tokens in gRPC requests. +Traditionally this value is 'authorization' however it is made configurable. +""" + + +CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry('credentials', 'client_secret', default=None) +""" +Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the +password directly from the environment variable. Note that this is less secure! Please only use this if mounting the +secret as a file is impossible. +""" + + +CLIENT_CREDENTIALS_SCOPE = _config_common.FlyteStringConfigurationEntry('credentials', 'scope', default=None) +""" +Used for basic auth, which is automatically called during pyflyte. This is the scope that will be requested. Because +there is no user explicitly in this auth flow, certain IDPs require a custom scope for basic auth in the configuration +of the authorization server. +""" + +AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") +""" +The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: +- 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials + access. +- 'basic' This uses cert-based auth in which the end user enters his/her username and password and public key encryption + is used to facilitate authentication. +- None: No auth will be attempted. +""" diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index 4d63e97b0a..5ea0bd7185 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -5,6 +5,14 @@ URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url') INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False) + CLOUD_PROVIDER = _config_common.FlyteStringConfigurationEntry( 'platform', 'cloud_provider', default=_constants.CloudProvider.AWS ) + +AUTH = _config_common.FlyteBoolConfigurationEntry('platform', 'auth', default=False) +""" +This config setting should not normally be filled in. Whether or not an admin server requires authentication should be +something published by the admin server itself (typically by returning a 401). However, to help with migration, this +config object is here to force the SDK to attempt the auth flow even without prompting by Admin. +""" diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 53bf34e708..bfe76bf24a 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -6,11 +6,12 @@ from datetime import datetime as _datetime import six as _six +from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions, iterate_task_executions as \ - _iterate_task_executions from flytekit import __version__ as _api_version from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient +from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions, iterate_task_executions as \ + _iterate_task_executions from flytekit.common import utils as _common_utils, constants as _constants from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes from flytekit.configuration import platform as _platform_config, internal as _internal_config, sdk as _sdk_config @@ -21,7 +22,6 @@ literals as _literals, common as _common_models from flytekit.models.admin import workflow as _workflow_model from flytekit.models.core import errors as _error_models, identifier as _identifier -from flyteidl.core import literals_pb2 as _literals_pb2 class _FlyteClientManager(object): @@ -32,7 +32,8 @@ def __init__(self, *args, **kwargs): # TODO: React to changing configs. For now this is frozen for the lifetime of the process, which covers most # TODO: use cases. if type(self)._CLIENT is None: - type(self)._CLIENT = _SynchronousFlyteClient(*args, **kwargs) + c = _SynchronousFlyteClient(*args, **kwargs) + type(self)._CLIENT = c @property def client(self): diff --git a/requirements.txt b/requirements.txt index 1e4d65baa2..8fc7f72648 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ pytest==4.6.6 mock==3.0.5 -six==1.12.0 +six==1.12.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 07f9a12275..27a8b6797c 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,9 @@ "protobuf>=3.6.1,<4", "pytimeparse>=1.1.8,<2.0.0", "pytz>=2017.2,<2018.5", + "keyring>=18.0.1", "requests>=2.18.4,<3.0.0", + "responses>=0.10.7", "six>=1.9.0,<2.0.0", "sortedcontainers>=1.5.9,<2.0.0", "statsd>=3.0.0,<4.0.0", diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py new file mode 100644 index 0000000000..757e6f4797 --- /dev/null +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import + +import re + +from flytekit.clis.auth import auth as _auth + +from multiprocessing import Queue as _Queue +try: # Python 3 + import http.server as _BaseHTTPServer +except ImportError: # Python 2 + import BaseHTTPServer as _BaseHTTPServer + + +def test_generate_code_verifier(): + verifier = _auth._generate_code_verifier() + assert verifier is not None + assert 43 < len(verifier) < 128 + assert not re.search(r'[^a-zA-Z0-9_\-.~]+', verifier) + + +def test_generate_state_parameter(): + param = _auth._generate_state_parameter() + assert not re.search(r'[^a-zA-Z0-9-_.,]+', param) + + +def test_create_code_challenge(): + test_code_verifier = "test_code_verifier" + assert _auth._create_code_challenge(test_code_verifier) == "Qq1fGD0HhxwbmeMrqaebgn1qhvKeguQPXqLdpmixaM4" + + +def test_oauth_http_server(): + queue = _Queue() + server = _auth.OAuthHTTPServer(("localhost", 9000), _BaseHTTPServer.BaseHTTPRequestHandler, queue=queue) + test_auth_code = "auth_code" + server.handle_authorization_code(test_auth_code) + auth_code = queue.get() + assert test_auth_code == auth_code + diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py new file mode 100644 index 0000000000..4055d18b35 --- /dev/null +++ b/tests/flytekit/unit/cli/auth/test_discovery.py @@ -0,0 +1,39 @@ +import pytest +import responses + +from flytekit.clis.auth import discovery as _discovery + + +@responses.activate +def test_get_authorization_endpoints(): + discovery_url = "http://flyte-admin.com/discovery" + + auth_endpoint = "http://flyte-admin.com/authorization" + token_endpoint = "http://flyte-admin.com/token" + responses.add(responses.GET, discovery_url, + json={'authorization_endpoint': auth_endpoint, + 'token_endpoint': token_endpoint}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + assert discovery_client.get_authorization_endpoints().auth_endpoint == auth_endpoint + assert discovery_client.get_authorization_endpoints().token_endpoint == token_endpoint + + +@responses.activate +def test_get_authorization_endpoints_missing_authorization_endpoint(): + discovery_url = "http://flyte-admin.com/discovery" + responses.add(responses.GET, discovery_url, json={'token_endpoint': "http://flyte-admin.com/token"}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + with pytest.raises(Exception): + discovery_client.get_authorization_endpoints() + + +@responses.activate +def test_get_authorization_endpoints_missing_token_endpoint(): + discovery_url = "http://flyte-admin.com/discovery" + responses.add(responses.GET, discovery_url, json={'authorization_endpoint': "http://flyte-admin.com/authorization"}) + + discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) + with pytest.raises(Exception): + discovery_client.get_authorization_endpoints() diff --git a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py new file mode 100644 index 0000000000..7c4436113b --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import +import json + +from mock import MagicMock, patch, PropertyMock +from flytekit.clis.flyte_cli.main import _welcome_message +from flytekit.clis.sdk_in_container import basic_auth +from flytekit.configuration.creds import ( + CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, +) + +_welcome_message() + + +def test_get_secret(): + import os + os.environ[_CREDENTIALS_SECRET.env_var] = "abc" + assert basic_auth.get_secret() == "abc" + + +def test_get_basic_authorization_header(): + header = basic_auth.get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + +@patch('flytekit.clis.sdk_in_container.basic_auth._requests') +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = basic_auth.get_token("https://corp.idp.net", "abc123", "my_scope") + assert access == "abc" + assert expiration == 60 diff --git a/tests/flytekit/unit/clients/__init__.py b/tests/flytekit/unit/clients/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py new file mode 100644 index 0000000000..5d37b5f590 --- /dev/null +++ b/tests/flytekit/unit/clients/test_raw.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from flytekit.clients.raw import ( + RawSynchronousFlyteClient as _RawSynchronousFlyteClient, + _refresh_credentials_basic +) +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET +from flytekit.clis.auth.discovery import AuthorizationEndpoints as _AuthorizationEndpoints +import mock +import os +import json + + +@mock.patch('flytekit.clients.raw._admin_service') +@mock.patch('flytekit.clients.raw._insecure_channel') +def test_client_set_token(mock_channel, mock_admin): + mock_channel.return_value = True + mock_admin.AdminServiceStub.return_value = True + client = _RawSynchronousFlyteClient(url='a.b.com', insecure=True) + client.set_access_token('abc') + assert client._metadata[0][1] == 'Bearer abc' + + +@mock.patch('flytekit.clis.sdk_in_container.basic_auth._requests') +@mock.patch('flytekit.clients.raw._credentials_access') +def test_refresh_credentials_basic(mock_credentials_access, mock_requests): + mock_credentials_access.get_authorization_endpoints.return_value = _AuthorizationEndpoints('auth', 'token') + response = mock.MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + os.environ[_CREDENTIALS_SECRET.env_var] = "asdf12345" + + mock_client = mock.MagicMock() + _refresh_credentials_basic(mock_client) + mock_client.set_access_token.assert_called_with('abc') diff --git a/tests/flytekit/unit/configuration/conftest.py b/tests/flytekit/unit/configuration/conftest.py index 0d7d1b314b..446fb12fac 100644 --- a/tests/flytekit/unit/configuration/conftest.py +++ b/tests/flytekit/unit/configuration/conftest.py @@ -1,10 +1,12 @@ from __future__ import absolute_import from flytekit.configuration import set_flyte_config_file as _set_config import pytest as _pytest - +import os as _os @_pytest.fixture(scope="function", autouse=True) def clear_configs(): _set_config(None) + environment_variables = _os.environ.copy() yield + _os.environ = environment_variables _set_config(None) diff --git a/tests/flytekit/unit/configuration/test_waterfall.py b/tests/flytekit/unit/configuration/test_waterfall.py new file mode 100644 index 0000000000..0de9317e08 --- /dev/null +++ b/tests/flytekit/unit/configuration/test_waterfall.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file, \ + common as _common, \ + TemporaryConfiguration as _TemporaryConfiguration + +from flytekit.common.utils import AutoDeletingTempDir as _AutoDeletingTempDir +import os as _os + + +def test_lookup_waterfall_raw_env_var(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + _os.environ['FLYTE_TEST_SETTING'] = 'lorem' + assert x.get() == 'lorem' + + +def test_lookup_waterfall_referenced_env_var(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + if 'TEMP_PLACEHOLDER' in _os.environ: + del _os.environ['TEMP_PLACEHOLDER'] + _os.environ['TEMP_PLACEHOLDER'] = 'lorem' + _os.environ['FLYTE_TEST_SETTING_FROM_ENV_VAR'] = 'TEMP_PLACEHOLDER' + assert x.get() == 'lorem' + + +def test_lookup_waterfall_referenced_file(): + x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + + if 'FLYTE_TEST_SETTING' in _os.environ: + del _os.environ['FLYTE_TEST_SETTING'] + assert x.get() is None + + with _AutoDeletingTempDir("config_testing") as tmp_dir: + with open(tmp_dir.get_named_tempfile('name'), 'w') as fh: + fh.write('secret_password') + + _os.environ['FLYTE_TEST_SETTING_FROM_FILE'] = tmp_dir.get_named_tempfile('name') + assert x.get() == 'secret_password'