Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework provider manager to treat Airflow core hooks like other provider hooks #33051

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _valid_uri(uri: str) -> bool:
@cache
def _get_connection_types() -> list[str]:
"""Returns connection types available."""
_connection_types = ["fs", "mesos_framework-id", "email", "generic"]
_connection_types = []
providers_manager = ProvidersManager()
for connection_type, provider_info in providers_manager.hooks.items():
if provider_info:
Expand Down
42 changes: 40 additions & 2 deletions airflow/hooks/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations

from pathlib import Path
from typing import Any

from airflow.hooks.base import BaseHook


Expand All @@ -33,9 +36,32 @@ class FSHook(BaseHook):
Extra: {"path": "/tmp"}
"""

def __init__(self, conn_id: str = "fs_default"):
conn_name_attr = "fs_conn_id"
default_conn_name = "fs_default"
conn_type = "fs"
hook_name = "File (path)"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField

return {"path": StringField(lazy_gettext("Path"), widget=BS3TextFieldWidget())}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["host", "schema", "port", "login", "password", "extra"],
"relabeling": {},
"placeholders": {},
}

def __init__(self, fs_conn_id: str = default_conn_name):
super().__init__()
conn = self.get_connection(conn_id)
conn = self.get_connection(fs_conn_id)
self.basepath = conn.extra_dejson.get("path", "")
self.conn = conn

Expand All @@ -49,3 +75,15 @@ def get_path(self) -> str:
:return: the path.
"""
return self.basepath

def test_connection(self):
"""Test File connection."""
try:
p = self.get_path()
if not p:
return False, "File Path is undefined."
if not Path(p).exists():
return False, f"Path {p} does not exist."
return True, f"Path {p} is existing."
except Exception as e:
return False, str(e)
32 changes: 32 additions & 0 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from packaging.utils import canonicalize_name

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.hooks.filesystem import FSHook
from airflow.typing_compat import Literal
from airflow.utils import yaml
from airflow.utils.entry_points import entry_points_with_dist
Expand Down Expand Up @@ -431,6 +432,37 @@ def __init__(self):
)
# Set of plugins contained in providers
self._plugins_set: set[PluginInfo] = set()
self._init_airflow_core_hooks()

def _init_airflow_core_hooks(self):
"""Initializes the hooks dict with default hooks from Airflow core."""
core_dummy_hooks = {
"generic": "Generic",
"email": "Email",
"mesos_framework-id": "Mesos Framework ID",
}
for key, display in core_dummy_hooks.items():
self._hooks_lazy_dict[key] = HookInfo(
hook_class_name=None,
connection_id_attribute_name=None,
package_name=None,
hook_name=display,
connection_type=None,
connection_testable=False,
)
for cls in [FSHook]:
package_name = cls.__module__
hook_class_name = f"{cls.__module__}.{cls.__name__}"
hook_info = self._import_hook(
connection_type=None,
provider_info=None,
hook_class_name=hook_class_name,
package_name=package_name,
)
self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
hook_class_name=hook_class_name, package_name=package_name
)
self._hooks_lazy_dict[hook_info.connection_type] = hook_info

@provider_info_cache("list")
def initialize_providers_list(self):
Expand Down
4 changes: 0 additions & 4 deletions airflow/www/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,6 @@ def create_connection_form_class() -> type[DynamicForm]:

def _iter_connection_types() -> Iterator[tuple[str, str]]:
"""List available connection types."""
yield ("email", "Email")
yield ("fs", "File (path)")
yield ("generic", "Generic")
yield ("mesos_framework-id", "Mesos Framework ID")
for connection_type, provider_info in providers_manager.hooks.items():
if provider_info:
yield (connection_type, provider_info.hook_name)
Expand Down
6 changes: 3 additions & 3 deletions tests/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,14 +753,14 @@ def test_connection_test_success(self):
@mock.patch.dict(
"os.environ",
{
"AIRFLOW_CONN_TEST_URI_NO_HOOK": "fs://",
"AIRFLOW_CONN_TEST_URI_NO_HOOK": "unknown://",
},
)
def test_connection_test_no_hook(self):
conn = Connection(conn_id="test_uri_no_hook", conn_type="fs")
conn = Connection(conn_id="test_uri_no_hook", conn_type="unknown")
res = conn.test_connection()
assert res[0] is False
assert res[1] == 'Unknown hook type "fs"'
assert res[1] == 'Unknown hook type "unknown"'

@mock.patch.dict(
"os.environ",
Expand Down