diff --git a/airflow/providers/github/hooks/github.py b/airflow/providers/github/hooks/github.py index 07a8566a7f575..9a71ef5b3897a 100644 --- a/airflow/providers/github/hooks/github.py +++ b/airflow/providers/github/hooks/github.py @@ -16,17 +16,18 @@ # specific language governing permissions and limitations # under the License. -"""This module allows to connect to a Github.""" -from typing import Dict, Optional +"""This module allows you to connect to GitHub.""" +from typing import Dict, Optional, Tuple from github import Github as GithubClient +from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook class GithubHook(BaseHook): """ - Interact with Github. + Interact with GitHub. Performs a connection to GitHub and retrieves client. @@ -36,7 +37,7 @@ class GithubHook(BaseHook): conn_name_attr = 'github_conn_id' default_conn_name = 'github_default' conn_type = 'github' - hook_name = 'Github' + hook_name = 'GitHub' def __init__(self, github_conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -45,10 +46,7 @@ def __init__(self, github_conn_id: str = default_conn_name, *args, **kwargs) -> self.get_conn() def get_conn(self) -> GithubClient: - """ - Function that initiates a new GitHub connection - with token and hostname ( for GitHub Enterprise ) - """ + """Function that initiates a new GitHub connection with token and hostname (for GitHub Enterprise).""" if self.client is not None: return self.client @@ -56,6 +54,12 @@ def get_conn(self) -> GithubClient: access_token = conn.password host = conn.host + # Currently the only method of authenticating to GitHub in Airflow is via a token. This is not the + # only means available, but raising an exception to enforce this method for now. + # TODO: When/If other auth methods are implemented this exception should be removed/modified. + if not access_token: + raise AirflowException("An access token is required to authenticate to GitHub.") + if not host: self.client = GithubClient(login_or_token=access_token) else: @@ -68,12 +72,15 @@ def get_ui_field_behaviour() -> Dict: """Returns custom field behaviour""" return { "hidden_fields": ['schema', 'port', 'login', 'extra'], - "relabeling": { - 'host': 'GitHub Enterprise Url (Optional)', - 'password': 'GitHub Access Token', - }, - "placeholders": { - 'host': 'https://{hostname}/api/v3 (for GitHub Enterprise Connection)', - 'password': 'token credentials auth', - }, + "relabeling": {'host': 'GitHub Enterprise URL (Optional)', 'password': 'GitHub Access Token'}, + "placeholders": {'host': 'https://{hostname}/api/v3 (for GitHub Enterprise)'}, } + + def test_connection(self) -> Tuple[bool, str]: + """Test GitHub connection.""" + try: + assert self.client # For mypy union-attr check of Optional[GithubClient]. + self.client.get_user().id + return True, "Successfully connected to GitHub." + except Exception as e: + return False, str(e) diff --git a/tests/providers/github/hooks/test_github.py b/tests/providers/github/hooks/test_github.py index b4feab3183756..4bad1d8e8c7e7 100644 --- a/tests/providers/github/hooks/test_github.py +++ b/tests/providers/github/hooks/test_github.py @@ -17,9 +17,10 @@ # under the License. # -import unittest from unittest.mock import Mock, patch +from github import BadCredentialsException, Github, NamedUser + from airflow.models import Connection from airflow.providers.github.hooks.github import GithubHook from airflow.utils import db @@ -27,15 +28,14 @@ github_client_mock = Mock(name="github_client_for_test") -class TestGithubHook(unittest.TestCase): - def setUp(self): +class TestGithubHook: + def setup_class(self): db.merge_conn( Connection( - conn_id='github_default', + conn_id="github_default", conn_type='github', - host='https://localhost/github/', - port=443, - extra='{"verify": "False", "project": "AIRFLOW"}', + password='my-access-token', + host='https://mygithub.com/api/v3', ) ) @@ -48,3 +48,27 @@ def test_github_client_connection(self, github_mock): assert github_mock.called assert isinstance(github_hook.client, Mock) assert github_hook.client.name == github_mock.return_value.name + + def test_connection_success(self): + hook = GithubHook() + hook.client = Mock(spec=Github) + hook.client.get_user.return_value = NamedUser.NamedUser + + status, msg = hook.test_connection() + + assert status is True + assert msg == "Successfully connected to GitHub." + + def test_connection_failure(self): + hook = GithubHook() + hook.client.get_user = Mock( + side_effect=BadCredentialsException( + status=401, + data={"message": "Bad credentials"}, + headers={}, + ) + ) + status, msg = hook.test_connection() + + assert status is False + assert msg == '401 {"message": "Bad credentials"}' diff --git a/tests/providers/github/operators/test_github.py b/tests/providers/github/operators/test_github.py index 23461cbbdf006..8b3b6263212e9 100644 --- a/tests/providers/github/operators/test_github.py +++ b/tests/providers/github/operators/test_github.py @@ -17,7 +17,6 @@ # under the License. # -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -29,8 +28,8 @@ github_client_mock = Mock(name="github_client_for_test") -class TestGithubOperator(unittest.TestCase): - def setUp(self): +class TestGithubOperator: + def setup_class(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) self.dag = dag @@ -38,9 +37,8 @@ def setUp(self): Connection( conn_id='github_default', conn_type='github', - host='https://localhost/github/', - port=443, - extra='{"verify": "False", "project": "AIRFLOW"}', + password='my-access-token', + host='https://mygithub.com/api/v3', ) ) diff --git a/tests/providers/github/sensors/test_github.py b/tests/providers/github/sensors/test_github.py index 71cb0a75cacda..14d168fa4669c 100644 --- a/tests/providers/github/sensors/test_github.py +++ b/tests/providers/github/sensors/test_github.py @@ -17,7 +17,6 @@ # under the License. # -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -29,8 +28,8 @@ github_client_mock = Mock(name="github_client_for_test") -class TestGithubSensor(unittest.TestCase): - def setUp(self): +class TestGithubSensor: + def setup_class(self): args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} dag = DAG('test_dag_id', default_args=args) self.dag = dag @@ -38,9 +37,8 @@ def setUp(self): Connection( conn_id='github_default', conn_type='github', - host='https://localhost/github/', - port=443, - extra='{"verify": "False", "project": "AIRFLOW"}', + password='my-access-token', + host='https://mygithub.com/api/v3', ) )