Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass verify flag to all authenticators #1641

Merged
merged 5 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def __init__(
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
self._http_proxy_url = http_proxy_url
self._verify = verify

def get_credentials(self) -> Credentials:
return self._creds
Expand Down Expand Up @@ -94,10 +96,9 @@ def __init__(
"""
Initialize with default creds from KeyStore using the endpoint name
"""
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint))
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
self._cfg_store = cfg_store
self._auth_client = None
self._verify = verify

def _initialize_auth_client(self):
if not self._auth_client:
Expand Down Expand Up @@ -170,6 +171,7 @@ def __init__(
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -179,7 +181,7 @@ def __init__(
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url)
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
"""
Expand All @@ -195,8 +197,9 @@ def refresh_credentials(self):
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)

token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url, verify=self._verify
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)
Expand All @@ -218,6 +221,7 @@ def __init__(
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -234,11 +238,12 @@ def __init__(
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
verify=verify,
)

def refresh_credentials(self):
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
)
print(
f"""
Expand All @@ -249,7 +254,11 @@ def refresh_credentials(self):
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
resp,
self._token_endpoint,
client_id=self._client_id,
http_proxy_url=self._http_proxy_url,
verify=self._verify,
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
Expand Down
13 changes: 10 additions & 3 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def get_token(
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -99,7 +100,7 @@ def get_token(
body["scope"] = ",".join(scopes)

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies)
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)
if not response.ok:
j = response.json()
if "error" in j:
Expand All @@ -119,6 +120,7 @@ def get_device_code(
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
Expand All @@ -127,14 +129,18 @@ def get_device_code(
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies)
resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
resp: DeviceCodeResponse,
token_endpoint: str,
client_id: str,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
Expand All @@ -147,6 +153,7 @@ def poll_token_endpoint(
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url,
verify=verify,
)
print("Authentication successful!")
return access_token, expires_in
Expand Down
18 changes: 12 additions & 6 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
logging.warning(f"Authentication type {cfg_auth} does not exist, defaulting to standard")
cfg_auth = AuthType.STANDARD

verify = None
if cfg.insecure_skip_verify:
verify = False
elif cfg.ca_cert_file_path:
verify = cfg.ca_cert_file_path
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved

if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE:
verify = None
if cfg.insecure_skip_verify:
verify = False
elif cfg.ca_cert_file_path:
verify = cfg.ca_cert_file_path
return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify)
elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET:
return ClientCredentialsAuthenticator(
Expand All @@ -73,6 +74,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -84,7 +86,11 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
endpoint=cfg.endpoint,
cfg_store=cfg_store,
audience=cfg.audience,
http_proxy_url=cfg.http_proxy_url,
verify=verify,
)
else:
raise ValueError(
Expand Down
11 changes: 5 additions & 6 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ def test_client_creds_authenticator(mock_requests):
@patch("flytekit.clients.auth.token_client.poll_token_endpoint")
def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock):
with pytest.raises(AuthenticationError):
DeviceCodeAuthenticator(
ENDPOINT,
static_cfg_store,
audience="x",
)
DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True)

cfg_store = StaticClientConfigStore(
ClientConfig(
Expand All @@ -107,7 +103,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")
authn = DeviceCodeAuthenticator(
ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000", verify=False
)

device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
Expand All @@ -124,6 +122,7 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests):
client_secret="secret",
cfg_store=static_cfg_store,
scopes=expected_scopes,
verify=True,
)
response = MagicMock()
response.status_code = 200
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_get_token(mock_requests):
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000", verify=True
)
assert access == "abc"
assert expiration == 60
Expand Down Expand Up @@ -66,13 +66,13 @@ def test_poll_token_endpoint(mock_requests):

r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True)

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True)
assert t
assert e