Skip to content

Commit

Permalink
Pass verify flag to all authenticators (#1641)
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu authored and eapolinario committed Jun 29, 2023
1 parent 61caf13 commit c380e96
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 24 deletions.
21 changes: 15 additions & 6 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def __init__(
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
self._http_proxy_url = http_proxy_url
self._verify = verify

def get_credentials(self) -> Credentials:
return self._creds
Expand Down Expand Up @@ -94,10 +96,9 @@ def __init__(
"""
Initialize with default creds from KeyStore using the endpoint name
"""
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint))
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
self._cfg_store = cfg_store
self._auth_client = None
self._verify = verify

def _initialize_auth_client(self):
if not self._auth_client:
Expand Down Expand Up @@ -170,6 +171,7 @@ def __init__(
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -179,7 +181,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url)
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
"""
Expand All @@ -195,8 +197,9 @@ def refresh_credentials(self):
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)

token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)
Expand All @@ -218,6 +221,7 @@ def __init__(
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -234,11 +238,12 @@ def __init__(
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
verify=verify,
)

def refresh_credentials(self):
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
)
print(
f"""
Expand All @@ -249,7 +254,11 @@ def refresh_credentials(self):
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
resp,
self._token_endpoint,
client_id=self._client_id,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
Expand Down
13 changes: 10 additions & 3 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_token(
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -99,7 +100,7 @@ def get_token(
body["scope"] = ",".join(scopes)

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies)
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)
if not response.ok:
j = response.json()
if "error" in j:
Expand All @@ -119,6 +120,7 @@ def get_device_code(
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
Expand All @@ -127,14 +129,18 @@ def get_device_code(
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies)
resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
resp: DeviceCodeResponse,
token_endpoint: str,
client_id: str,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
Expand All @@ -147,6 +153,7 @@ def poll_token_endpoint(
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url,
verify=verify,
)
print("Authentication successful!")
return access_token, expires_in
Expand Down
18 changes: 12 additions & 6 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
logging.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard")
cfg_auth = AuthType.STANDARD

verify = None
if cfg.insecure_skip_verify:
verify = False
elif cfg.ca_cert_file_path:
verify = cfg.ca_cert_file_path

if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE:
verify = None
if cfg.insecure_skip_verify:
verify = False
elif cfg.ca_cert_file_path:
verify = cfg.ca_cert_file_path
return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify)
elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET:
return ClientCredentialsAuthenticator(
Expand All @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -84,7 +86,11 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
endpoint=cfg.endpoint,
cfg_store=cfg_store,
audience=cfg.audience,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
else:
raise ValueError(
Expand Down
11 changes: 5 additions & 6 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ def test_client_creds_authenticator(mock_requests):
@patch("flytekit.clients.auth.token_client.poll_token_endpoint")
def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock):
with pytest.raises(AuthenticationError):
DeviceCodeAuthenticator(
ENDPOINT,
static_cfg_store,
audience="x",
)
DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True)

cfg_store = StaticClientConfigStore(
ClientConfig(
Expand All @@ -107,7 +103,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")
authn = DeviceCodeAuthenticator(
ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000", verify=False
)

device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
Expand All @@ -124,6 +122,7 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests):
client_secret="secret",
cfg_store=static_cfg_store,
scopes=expected_scopes,
verify=True,
)
response = MagicMock()
response.status_code = 200
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_get_token(mock_requests):
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000", verify=True
)
assert access == "abc"
assert expiration == 60
Expand Down Expand Up @@ -66,13 +66,13 @@ def test_poll_token_endpoint(mock_requests):

r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True)

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True)
assert t
assert e

0 comments on commit c380e96

Please sign in to comment.