Add test_connection method to GoogleBaseHook #24682
        """Test the Google cloud connectivity from UI"""
        status, message = False, ''
        try:
            if self.project_id:
Does creating the credentials object make a call to GCP to confirm them, or does it just deserialize the key JSON file locally? If it only works locally, this check may not always be right, for example when the key file contains a wrong project ID.
Based on the provided key JSON, it connects to the GCP OAuth module and derives the project_id:
    def _get_credentials_using_keyfile_dict(self):
        self._log_debug('Getting connection using JSON Dict')
        # Depending on how the JSON was formatted, it may contain
        # escaped newlines. Convert those to actual newlines.
        self.keyfile_dict['private_key'] = self.keyfile_dict['private_key'].replace('\\n', '\n')
        credentials = google.oauth2.service_account.Credentials.from_service_account_info(
            self.keyfile_dict, scopes=self.scopes
        )
        project_id = credentials.project_id
        return credentials, project_id
The implementation of project_id may change (see: #23719). We should not rely on it to verify that the connection works.
I think it will be safer if we add an explicit access key check. To do it, you should send the request to the following URL:
https://www.googleapis.com/oauth2/v1/tokeninfo?access_token=accessToken
Source: https://stackoverflow.com/questions/359472/how-can-i-verify-a-google-authentication-api-access-token
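The explicit check suggested above could look roughly like the sketch below. It is an illustration only, not the code from this PR: `check_access_token`, `_default_fetch`, and the `fetch` parameter are hypothetical names, and the HTTP call is injectable so the logic can be exercised without network access. It assumes the tokeninfo endpoint answers HTTP 200 for a valid token, as the linked answer describes.

```python
import urllib.error
import urllib.parse
import urllib.request

TOKENINFO_URL = "https://www.googleapis.com/oauth2/v1/tokeninfo"


def _default_fetch(url):
    """Return the HTTP status code of a GET request to ``url``."""
    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            return response.status
    except urllib.error.HTTPError as err:
        # urlopen raises on 4xx/5xx; the status code is all we need here.
        return err.code


def check_access_token(token, fetch=_default_fetch):
    """Return True if the tokeninfo endpoint accepts ``token`` (assumed: HTTP 200)."""
    query = urllib.parse.urlencode({"access_token": token})
    return fetch(f"{TOKENINFO_URL}?{query}") == 200
```

In a unit test, `fetch` can be replaced with a stub that returns a fixed status code.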
LGTM
@rbiegacz WDYT?
        # credentials.token is None
        # Need to refresh credentials to populate the token
How did this work before the change if it was returning None?
The method _get_access_token() wasn't being used anywhere earlier. A sample dump of the Credentials object just after calling self._get_credentials() is:
{'token': None, 'expiry': None, '_quota_project_id': None, '_scopes': ('https://www.googleapis.com/auth/cloud-platform',), '_default_scopes': None, '_signer': <google.auth.crypt._cryptography_rsa.RSASigner object at 0x40359f3af0>, '_service_account_email': '[email protected]', '_subject': None, '_project_id': 'xxxxxx-providers', '_token_uri': 'https://oauth2.googleapis.com/token', '_always_use_jwt_access': False, '_jwt_credentials': None, '_additional_claims': {}}
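The dump shows `token` and `expiry` both unset until a refresh happens. A minimal sketch of the refresh-before-use pattern follows. `get_access_token` and `_FakeCredentials` are hypothetical names used only for illustration; with the google-auth library, `request` would typically be a `google.auth.transport.requests.Request()` instance and `credentials.refresh(request)` would perform the token exchange.

```python
def get_access_token(credentials, request):
    """Return a bearer token, refreshing the credentials first if none is set."""
    if credentials.token is None:
        # A freshly built credentials object has no token yet;
        # refresh() is what populates it.
        credentials.refresh(request)
    return credentials.token


class _FakeCredentials:
    """Stand-in that mimics only the attributes used above (illustration only)."""

    def __init__(self):
        self.token = None
        self.refresh_calls = 0

    def refresh(self, request):
        self.refresh_calls += 1
        self.token = "TOKEN"


creds = _FakeCredentials()
first = get_access_token(creds, request=None)   # triggers one refresh
second = get_access_token(creds, request=None)  # reuses the existing token
```

The stand-in only demonstrates the control flow; it says nothing about expiry handling, which real credentials objects also track.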
Thx.
Left some comments.
@@ -580,3 +586,19 @@ def download_content_from_request(file_handle, request: dict, chunk_size: int) -
        while done is False:
            _, done = downloader.next_chunk()
        file_handle.flush()

    def test_connection(self):
        """Test the Google cloud connectivity from UI"""
Nit: "Google Cloud"
        status, message = False, ''
        try:
            token = self._get_access_token()
            url = f"https://www.googleapis.com/oauth2/v3/tokeninfo?access_token={token}"
This doesn't look safe, as there is no URL encoding here.
If the token contains "?", "/", "&", etc., this URL will be broken.
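To illustrate the concern, the sketch below compares plain string interpolation with `urllib.parse.urlencode`. The token value is made up and deliberately contains reserved characters; `read_token` is a hypothetical helper that parses the query string the way a server would.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

BASE = "https://www.googleapis.com/oauth2/v3/tokeninfo"
token = "ya29.a/b?c&d=e"  # made-up value with reserved characters

# Interpolating the raw token lets "&" start a new query parameter.
naive_url = f"{BASE}?access_token={token}"

# urlencode percent-escapes reserved characters in the value.
safe_url = f"{BASE}?{urlencode({'access_token': token})}"


def read_token(url):
    """Return the access_token query parameter as a receiver would parse it."""
    return parse_qs(urlsplit(url).query).get("access_token", [None])[0]
```

Parsing `safe_url` yields the original token, while parsing `naive_url` loses everything after the "&".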
            if response.status_code == 200:
                status = True
                message = 'Connection successfully tested'
        except Exception as e:
AFAIK catching all types of exceptions is bad practice.
Which specific exception types do we want to catch here? Probably the exceptions raised by the POST request; there has to be a contract for which exceptions it raises.
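A sketch of the narrower handling being asked for, using only the standard library so it stays self-contained. `check_connection` and `fetch_status` are hypothetical names, and the choice of exception types is an assumption: with the `requests` library the analogous base class would be `requests.exceptions.RequestException`.

```python
import urllib.error


def check_connection(fetch_status):
    """Return ``(status, message)``; only transport errors are turned into a failure."""
    try:
        code = fetch_status()
    except (urllib.error.URLError, TimeoutError) as err:
        # Expected failure modes of the HTTP call itself.
        return False, f"Request failed: {err}"
    if code == 200:
        return True, "Connection successfully tested"
    return False, f"Unexpected HTTP status {code}"
```

Anything outside the listed types, such as a programming error, propagates to the caller instead of being reported as a failed connection.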
@@ -341,6 +341,24 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a
        )
        assert ('CREDENTIALS', 'PROJECT_ID') == result

    @mock.patch('requests.post')
Good practice is to mock with autospec=True to catch incorrect parameters being passed to the method.
Unless there are real objections to autospec, I would use it, as it is sometimes (and here specifically) the only way to spot a mismatch between the contract of the mocked function and how it is used.
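The difference can be shown with a small standalone sketch. The `post` function below is a hypothetical stand-in for the patched callable; `mock.create_autospec` applies the same signature checking that `mock.patch(..., autospec=True)` does.

```python
from unittest import mock


def post(url, data=None):
    """Stand-in for the function being patched (hypothetical)."""
    raise RuntimeError("real network call")


# A plain MagicMock accepts any arguments, including a misspelled keyword.
loose = mock.MagicMock()
loose("https://example.invalid", payload={})

# An autospecced mock enforces the signature of the original function.
strict = mock.create_autospec(post)
try:
    strict("https://example.invalid", payload={})
    signature_enforced = False
except TypeError:
    signature_enforced = True

# A call that matches the real signature goes through and is recorded.
strict("https://example.invalid", data={})
```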
@@ -341,6 +341,24 @@ def test_get_credentials_and_project_id_with_default_auth(self, mock_get_creds_a
        )
        assert ('CREDENTIALS', 'PROJECT_ID') == result

    @mock.patch('requests.post')
    @mock.patch(MODULE_NAME + '.get_credentials_and_project_id')
    def test_connection_success(self, mock_get_creds_and_proj_id, requests_post):
For consistency's sake I would rename "requests_post" -> "mock_requests_post".
    def test_connection_success(self, mock_get_creds_and_proj_id, requests_post):
        requests_post.return_value.status_code = 200
        credentials = mock.MagicMock()
        type(credentials).token = mock.PropertyMock(return_value="TOKEN")
This should work instead:
credentials = mock.MagicMock(token="TOKEN")
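For reading a plain attribute the two forms behave the same, as this standalone sketch shows. The names `verbose` and `concise` are illustrative only.

```python
from unittest import mock

# Form used in the PR: attach a PropertyMock to the mock's type.
verbose = mock.MagicMock()
type(verbose).token = mock.PropertyMock(return_value="TOKEN")

# Suggested form: configure the attribute through the constructor.
concise = mock.MagicMock(token="TOKEN")

tokens_match = verbose.token == concise.token == "TOKEN"
```

`PropertyMock` is still the tool to reach for when the test needs to assert that the attribute was accessed, since it records each read.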
This PR adds test connection functionality to the Google Cloud connection type in the Airflow UI.
cc @kaxil @potiuk
Test connection Pass
Test connection Failed