diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 6c8f54e9ce..1049fa162b 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -201,12 +201,19 @@ def _refresh_credentials_standard(self): "Check your Admin server's .well-known endpoints to make sure they're working as expected." ) + verify = None + if self._cfg.insecure_skip_verify: + verify = False + elif self._cfg.ca_cert_file_path: + verify = self._cfg.ca_cert_file_path + client = _credentials_access.get_client( redirect_endpoint=self.public_client_config.redirect_uri, client_id=self.public_client_config.client_id, scopes=self.public_client_config.scopes, auth_endpoint=self.oauth2_metadata.authorization_endpoint, token_endpoint=self.oauth2_metadata.token_endpoint, + verify=verify, ) if client.has_valid_credentials and not self.check_access_token(client.credentials.access_token): diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index f54379485a..9589f8392d 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -139,12 +139,7 @@ def access_token(self): class AuthorizationClient(object): def __init__( - self, - auth_endpoint=None, - token_endpoint=None, - scopes=None, - client_id=None, - redirect_uri=None, + self, auth_endpoint=None, token_endpoint=None, scopes=None, client_id=None, redirect_uri=None, verify=None ): self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint @@ -160,6 +155,7 @@ def __init__( self._refresh_token = None self._headers = {"content-type": "application/x-www-form-urlencoded"} self._expired = False + self._verify = verify self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. @@ -262,6 +258,7 @@ def request_access_token(self, auth_code): data=self._params, headers=self._headers, allow_redirects=False, + verify=self._verify, ) if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) error cases: @@ -280,6 +277,7 @@ def refresh_access_token(self): data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token}, headers=self._headers, allow_redirects=False, + verify=self._verify, ) if resp.status_code != _StatusCodes.OK: self._expired = True diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index a8475c8dfc..efdc1b9870 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -11,7 +11,7 @@ def get_client( - redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str + redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str, verify: Union[bool, str] ) -> AuthorizationClient: global _authorization_client if _authorization_client is not None and not _authorization_client.expired: @@ -23,6 +23,7 @@ def get_client( scopes=scopes, auth_endpoint=auth_endpoint, token_endpoint=token_endpoint, + verify=verify, ) auth_logger.debug(f"Created oauth client with redirect {_authorization_client}") diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 220f9209ea..93e430477c 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -314,6 +314,7 @@ class PlatformConfig(object): endpoint: str = "localhost:30080" insecure: bool = False insecure_skip_verify: bool = False + ca_cert_file_path: typing.Optional[str] = None console_endpoint: typing.Optional[str] = None command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None @@ -334,6 +335,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists( kwargs, "insecure_skip_verify", _internal.Platform.INSECURE_SKIP_VERIFY.read(config_file) ) + kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file)) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( @@ -355,6 +357,7 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) + return PlatformConfig(**kwargs) @classmethod diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 5c29045db5..5c3729e63b 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -108,6 +108,9 @@ class Platform(object): LegacyConfigEntry(SECTION, "insecure_skip_verify", bool), YamlConfigEntry("admin.insecureSkipVerify", bool) ) CONSOLE_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "console_endpoint"), YamlConfigEntry("console.endpoint")) + CA_CERT_FILE_PATH = ConfigEntry( + LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath") + ) class LocalSDK(object):