Skip to content

Commit

Permalink
Pass cert and verify settings to auth client
Browse files Browse the repository at this point in the history
Signed-off-by: Ankit Goyal <[email protected]>
  • Loading branch information
goyalankit committed Feb 13, 2023
1 parent ecded3e commit e10cc95
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
7 changes: 7 additions & 0 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,19 @@ def _refresh_credentials_standard(self):
"Check your Admin server's .well-known endpoints to make sure they're working as expected."
)

verify = None
if self._cfg.insecure_skip_verify:
verify = False
elif self._cfg.ca_cert_file_path:
verify = self._cfg.ca_cert_file_path

client = _credentials_access.get_client(
redirect_endpoint=self.public_client_config.redirect_uri,
client_id=self.public_client_config.client_id,
scopes=self.public_client_config.scopes,
auth_endpoint=self.oauth2_metadata.authorization_endpoint,
token_endpoint=self.oauth2_metadata.token_endpoint,
verify=verify,
)

if client.has_valid_credentials and not self.check_access_token(client.credentials.access_token):
Expand Down
10 changes: 4 additions & 6 deletions flytekit/clis/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,7 @@ def access_token(self):

class AuthorizationClient(object):
def __init__(
self,
auth_endpoint=None,
token_endpoint=None,
scopes=None,
client_id=None,
redirect_uri=None,
self, auth_endpoint=None, token_endpoint=None, scopes=None, client_id=None, redirect_uri=None, verify=None
):
self._auth_endpoint = auth_endpoint
self._token_endpoint = token_endpoint
Expand All @@ -160,6 +155,7 @@ def __init__(
self._refresh_token = None
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._expired = False
self._verify = verify

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
Expand Down Expand Up @@ -262,6 +258,7 @@ def request_access_token(self, auth_code):
data=self._params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
# TODO: handle expected (?) error cases:
Expand All @@ -280,6 +277,7 @@ def refresh_access_token(self):
data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token},
headers=self._headers,
allow_redirects=False,
verify=self._verify,
)
if resp.status_code != _StatusCodes.OK:
self._expired = True
Expand Down
3 changes: 2 additions & 1 deletion flytekit/clis/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def get_client(
redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str
redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str, verify: bool
) -> AuthorizationClient:
global _authorization_client
if _authorization_client is not None and not _authorization_client.expired:
Expand All @@ -23,6 +23,7 @@ def get_client(
scopes=scopes,
auth_endpoint=auth_endpoint,
token_endpoint=token_endpoint,
verify=verify,
)

auth_logger.debug(f"Created oauth client with redirect {_authorization_client}")
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ class PlatformConfig(object):
endpoint: str = "localhost:30080"
insecure: bool = False
insecure_skip_verify: bool = False
ca_cert_file_path: typing.Optional[str] = None
console_endpoint: typing.Optional[str] = None
command: typing.Optional[typing.List[str]] = None
client_id: typing.Optional[str] = None
Expand All @@ -334,6 +335,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(
kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file)
)
kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file))
kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file))
kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file))
kwargs = set_if_exists(
Expand All @@ -355,6 +357,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))

return PlatformConfig(**kwargs)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Platform(object):
LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool)
)
CONSOLE_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "console_endpoint"), YamlConfigEntry("console.endpoint"))
CA_CERT_FILE_PATH = ConfigEntry(
LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
)


class LocalSDK(object):
Expand Down

0 comments on commit e10cc95

Please sign in to comment.