diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index bb89cce0a5c69..bd78fb7ded877 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -188,6 +188,7 @@ class ConnectionFormWidgetInfo(NamedTuple): hook_class_name: str package_name: str field: Any + field_name: str T = TypeVar("T", bound=Callable) @@ -797,24 +798,22 @@ def _import_hook( ) def _add_widgets(self, package_name: str, hook_class: type, widgets: Dict[str, Any]): - for field_name, field in widgets.items(): - if not field_name.startswith("extra__"): - log.warning( - "The field %s from class %s does not start with 'extra__'. Ignoring it.", - field_name, - hook_class.__name__, - ) - continue - if field_name in self._connection_form_widgets: + conn_type = hook_class.conn_type # type: ignore + for field_identifier, field in widgets.items(): + if field_identifier.startswith('extra__'): + prefixed_field_name = field_identifier + else: + prefixed_field_name = f"extra__{conn_type}__{field_identifier}" + if prefixed_field_name in self._connection_form_widgets: log.warning( "The field %s from class %s has already been added by another provider. Ignoring it.", - field_name, + field_identifier, hook_class.__name__, ) # In case of inherited hooks this might be happening several times continue - self._connection_form_widgets[field_name] = ConnectionFormWidgetInfo( - hook_class.__name__, package_name, field + self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo( + hook_class.__name__, package_name, field, field_identifier ) def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: Dict): diff --git a/airflow/www/static/js/connection_form.js b/airflow/www/static/js/connection_form.js index e814c0b7dfb18..aabe547a88f59 100644 --- a/airflow/www/static/js/connection_form.js +++ b/airflow/www/static/js/connection_form.js @@ -263,7 +263,11 @@ $(document).ready(() => { }); // Check if field is a custom form field. } else if (this.name.startsWith('extra__')) { + // prior to Airflow 2.3 custom fields were stored in the extra dict with prefix + // post-2.3 we allow to use with no prefix + // here we don't know which we are configured to use, so we populate both extrasObj[this.name] = this.value; + extrasObj[this.name.replace(/extra__.+?__/, '')] = this.value; } else { outObj[this.name] = this.value; } diff --git a/airflow/www/views.py b/airflow/www/views.py index 147132f93b34c..06751ec68ced4 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3820,6 +3820,7 @@ def _get_connection_types() -> List[Tuple[str, str]]: ) for key, value in ProvidersManager().connection_form_widgets.items(): setattr(ConnectionForm, key, value.field) + ConnectionModelView.extra_field_name_mapping[key] = value.field_name ConnectionModelView.add_columns.append(key) ConnectionModelView.edit_columns.append(key) ConnectionModelView.extra_fields.append(key) @@ -3909,6 +3910,8 @@ class ConnectionModelView(AirflowModelView): base_order = ('conn_id', 'asc') + extra_field_name_mapping: Dict[str, str] = {} + @action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False) @auth.has_access( [ @@ -3994,7 +3997,6 @@ def action_mulduplicate(self, connections, session=None): def process_form(self, form, is_created): """Process form data.""" - conn_type = form.data['conn_type'] conn_id = form.data["conn_id"] # The extra value is the combination of custom fields for this conn_type and the Extra field. @@ -4002,11 +4004,11 @@ def process_form(self, form, is_created): # so we start with those values, and override them with anything in the custom fields. extra = {} - extra_field = form.data.get("extra") + extra_json = form.data.get("extra") - if extra_field: + if extra_json: try: - extra.update(json.loads(extra_field)) + extra.update(json.loads(extra_json)) except (JSONDecodeError, TypeError): flash( Markup( @@ -4015,18 +4017,19 @@ def process_form(self, form, is_created): "

If connection parameters need to be added to Extra, " "please make sure they are in the form of a single, valid JSON object.


" "The following Extra parameters were not added to the connection:
" - f"{extra_field}", + f"{extra_json}", ), category="error", ) + del extra_json - custom_fields = { - key: form.data[key] - for key in self.extra_fields - if key in form.data and key.startswith(f"extra__{conn_type}__") - } + for key in self.extra_fields: + if key in form.data and key.startswith("extra__"): + value = form.data[key] - extra.update(custom_fields) + if value: + field_name = self.extra_field_name_mapping[key] + extra[field_name] = value if extra.keys(): form.extra.data = json.dumps(extra) @@ -4046,10 +4049,16 @@ def prefill_form(self, form, pk): logging.warning('extra field for %s is not a dictionary', form.data.get('conn_id', '')) return - for field in self.extra_fields: - value = extra_dictionary.get(field, '') + for field_key in self.extra_fields: + field_name = self.extra_field_name_mapping[field_key] + value = extra_dictionary.get(field_name, '') + + if not value: + # check if connection `extra` json is using old prefixed field name style + value = extra_dictionary.get(field_key, '') + if value: - field = getattr(form, field) + field = getattr(form, field_key) field.data = value diff --git a/docs/apache-airflow/howto/connection.rst b/docs/apache-airflow/howto/connection.rst index 88e26e18d1234..6171561baa7bf 100644 --- a/docs/apache-airflow/howto/connection.rst +++ b/docs/apache-airflow/howto/connection.rst @@ -229,6 +229,63 @@ deprecated ``hook-class-names``) in the provider meta-data, you can customize Ai You can read more about details how to add custom provider packages in the :doc:`apache-airflow-providers:index` +Custom connection fields +------------------------ + +It is possible to add custom form fields in the connection add / edit views in the Airflow webserver. +Custom fields are stored in the ``Connection.extra`` field as JSON. To add a custom field, implement +method :meth:`~BaseHook.get_connection_form_widgets`. This method should return a dictionary. The keys +should be the string name of the field as it should be stored in the ``extra`` dict. The values should +be inheritors of :class:`wtforms.fields.core.Field`. + +Here's an example: + +.. code-block:: python + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "workspace": StringField( + lazy_gettext("Workspace"), widget=BS3TextFieldWidget() + ), + "project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()), + } + +.. note:: Custom fields no longer need the ``extra____`` prefix + + Prior to Airflow 2.3, if you wanted a custom field in the UI, you had to prefix it with ``extra____``, + and this is how its value would be stored in the ``extra`` dict. From 2.3 onward, you no longer need to do this. + +Method :meth:`~BaseHook.get_ui_field_behaviour` lets you customize behavior of both . For example you can +hide or relabel a field (e.g. if it's unused or re-purposed) and you can add placeholder text. + +An example: + +.. code-block:: python + + @staticmethod + def get_ui_field_behaviour() -> Dict[str, Any]: + """Returns custom field behaviour""" + return { + "hidden_fields": ["port", "host", "login", "schema"], + "relabeling": {}, + "placeholders": { + "password": "Asana personal access token", + "extra__my_conn_type__workspace": "My workspace gid", + "extra__my_conn_type__project": "My project gid", + }, + } + +Note here that *here* (in contrast with ``get_connection_form_widgets``) we must add the prefix ``extra____`` when referencing a custom field. This is this is because it's possible to create a custom field whose name overlaps with a built-in field and we need to be able to reference it unambiguously. + +Take a look at providers for examples of what you can do, for example :py:class:`~airflow.providers.jdbc.hooks.jdbc.JdbcHook` +and :py:class:`~airflow.providers.asana.hooks.jdbc.AsanaHook` both make use of this feature. + .. note:: Deprecated ``hook-class-names`` Prior to Airflow 2.2.0, the connections in providers have been exposed via ``hook-class-names`` array diff --git a/newsfragments/22607.significant.rst b/newsfragments/22607.significant.rst new file mode 100644 index 0000000000000..4a3b5a689335c --- /dev/null +++ b/newsfragments/22607.significant.rst @@ -0,0 +1,5 @@ +Remove requirement that custom connection UI fields be prefixed + +Hooks can define custom connection fields for their connection type by implementing method ``get_connection_form_widgets``. These custom fields are presented in the web UI as additional connection attributes, but internally they are stored in the connection ``extra`` dict. For technical reasons, previously custom field when stored in ``extra`` had to be named with a prefix ``extra____``. This had the consequence of making it more cumbersome to define connections outside of the UI, since the prefix is tougher to read and work with. With #22607, we make it so that you can now define custom fields such that they can be read from and stored in ``extra`` without the prefix. + +To enable this, update the dict returned by the ``get_connection_form_widgets`` method to remove the prefix from the keys. Internally, the providers manager will still use a prefix to ensure each custom field is globally unique, but the absence of a prefix in the returned widget dict will signal to the Web UI to read and store custom fields without the prefix. Note that this is only a change to the Web UI behavior; when updating your hook in this way, you must make sure that when your *hook* reads the ``extra`` field, it will also check for the prefixed value for backward compatibility. diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py index 174527c8fa136..4d97eb2fa04dd 100644 --- a/tests/core/test_providers_manager.py +++ b/tests/core/test_providers_manager.py @@ -17,16 +17,19 @@ # under the License. import logging import re -import unittest +from typing import Dict from unittest.mock import patch import pytest +from flask_appbuilder.fieldwidgets import BS3TextFieldWidget +from flask_babel import lazy_gettext +from wtforms import BooleanField, Field, StringField from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.providers_manager import HookClassProvider, ProviderInfo, ProvidersManager -class TestProviderManager(unittest.TestCase): +class TestProviderManager: @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog @@ -142,6 +145,49 @@ def test_connection_form_widgets(self): connections_form_widgets = list(provider_manager.connection_form_widgets.keys()) assert len(connections_form_widgets) > 29 + @pytest.mark.parametrize( + 'scenario', + [ + 'prefix', + 'no_prefix', + 'both_1', + 'both_2', + ], + ) + def test_connection_form__add_widgets_prefix_backcompat(self, scenario): + """ + When the field name is prefixed, it should be used as is. + When not prefixed, we should add the prefix + When there's a collision, the one that appears first in the list will be used. + """ + + class MyHook: + conn_type = 'test' + + provider_manager = ProvidersManager() + widget_field = StringField(lazy_gettext('My Param'), widget=BS3TextFieldWidget()) + dummy_field = BooleanField(label=lazy_gettext('Dummy param'), description="dummy") + widgets: Dict[str, Field] = {} + if scenario == 'prefix': + widgets['extra__test__my_param'] = widget_field + elif scenario == 'no_prefix': + widgets['my_param'] = widget_field + elif scenario == 'both_1': + widgets['my_param'] = widget_field + widgets['extra__test__my_param'] = dummy_field + elif scenario == 'both_2': + widgets['extra__test__my_param'] = widget_field + widgets['my_param'] = dummy_field + else: + raise Exception('unexpected') + + provider_manager._add_widgets( + package_name='abc', + hook_class=MyHook, + widgets=widgets, + ) + assert provider_manager.connection_form_widgets['extra__test__my_param'].field == widget_field + def test_field_behaviours(self): provider_manager = ProvidersManager() connections_with_field_behaviours = list(provider_manager.field_behaviours.keys()) diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index d0fb165be3f38..8de7959717703 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -17,8 +17,10 @@ # under the License. import json from unittest import mock +from unittest.mock import PropertyMock import pytest +from pytest import param from airflow.models import Connection from airflow.utils.session import create_session @@ -51,16 +53,54 @@ def test_create_connection(admin_client): def test_prefill_form_null_extra(): mock_form = mock.Mock() - mock_form.data = {"conn_id": "test", "extra": None} + mock_form.data = {"conn_id": "test", "extra": None, "conn_type": "test"} cmv = ConnectionModelView() cmv.prefill_form(form=mock_form, pk=1) -def test_process_form_extras(): +@pytest.mark.parametrize( + 'extras, expected', + [ + param({"extra__test__my_param": "this_val"}, "this_val", id='conn_not_upgraded'), + param({"my_param": "my_val"}, "my_val", id='conn_upgraded'), + param( + {"extra__test__my_param": "this_val", "my_param": "my_val"}, + "my_val", + id='conn_upgraded_old_val_present', + ), + ], +) +def test_prefill_form_backcompat(extras, expected): + """ + When populating custom fields in the connection form we should first check for the non-prefixed + value (since prefixes in extra are deprecated) and then fallback to the prefixed value. + + Either way, the field is known internally to the model view as the prefixed value. + """ + mock_form = mock.Mock() + mock_form.data = {"conn_id": "test", "extra": json.dumps(extras), "conn_type": "test"} + cmv = ConnectionModelView() + cmv.extra_fields = ['extra__test__my_param'] + + # this is set by `lazy_add_provider_discovered_options_to_connection_form` + cmv.extra_field_name_mapping['extra__test__my_param'] = 'my_param' + + cmv.prefill_form(form=mock_form, pk=1) + assert mock_form.extra__test__my_param.data == expected + + +@pytest.mark.parametrize('field_name', ['extra__test__custom_field', 'custom_field']) +@mock.patch('airflow.utils.module_loading.import_string') +@mock.patch('airflow.providers_manager.ProvidersManager.hooks', new_callable=PropertyMock) +def test_process_form_extras_both(mock_pm_hooks, mock_import_str, field_name): """ Test the handling of connection parameters set with the classic `Extra` field as well as custom fields. + The key used in the field definition returned by `get_connection_form_widgets` is stored in + attr `extra_field_name_mapping`. Whatever is defined there is what should end up in `extra` when + the form is processed. """ + mock_pm_hooks.get.return_value = True # ensure that hook appears registered # Testing parameters set in both `Extra` and custom fields. mock_form = mock.Mock() @@ -72,14 +112,27 @@ def test_process_form_extras(): } cmv = ConnectionModelView() + + # this is set by `lazy_add_provider_discovered_options_to_connection_form` + cmv.extra_field_name_mapping['extra__test__custom_field'] = field_name cmv.extra_fields = ["extra__test__custom_field"] # Custom field cmv.process_form(form=mock_form, is_created=True) assert json.loads(mock_form.extra.data) == { - "extra__test__custom_field": "custom_field_val", + field_name: "custom_field_val", "param1": "param1_val", } + +@mock.patch('airflow.utils.module_loading.import_string') +@mock.patch('airflow.providers_manager.ProvidersManager.hooks', new_callable=PropertyMock) +def test_process_form_extras_extra_only(mock_pm_hooks, mock_import_str): + """ + Test the handling of connection parameters set with the classic `Extra` field as well as custom fields. + The key used in the field definition returned by `get_connection_form_widgets` is stored in + attr `extra_field_name_mapping`. Whatever is defined there is what should end up in `extra` when + the form is processed. + """ # Testing parameters set in `Extra` field only. mock_form = mock.Mock() mock_form.data = { @@ -89,10 +142,23 @@ def test_process_form_extras(): } cmv = ConnectionModelView() + cmv.process_form(form=mock_form, is_created=True) assert json.loads(mock_form.extra.data) == {"param2": "param2_val"} + +@pytest.mark.parametrize('field_name', ['extra__test3__custom_field', 'custom_field']) +@mock.patch('airflow.utils.module_loading.import_string') +@mock.patch('airflow.providers_manager.ProvidersManager.hooks', new_callable=PropertyMock) +def test_process_form_extras_custom_only(mock_pm_hooks, mock_import_str, field_name): + """ + Test the handling of connection parameters set with the classic `Extra` field as well as custom fields. + The key used in the field definition returned by `get_connection_form_widgets` is stored in + attr `extra_field_name_mapping`. Whatever is defined there is what should end up in `extra` when + the form is processed. + """ + # Testing parameters set in custom fields only. mock_form = mock.Mock() mock_form.data = { @@ -103,11 +169,26 @@ def test_process_form_extras(): cmv = ConnectionModelView() cmv.extra_fields = ["extra__test3__custom_field"] # Custom field + + # this is set by `lazy_add_provider_discovered_options_to_connection_form` + cmv.extra_field_name_mapping['extra__test3__custom_field'] = field_name cmv.process_form(form=mock_form, is_created=True) - assert json.loads(mock_form.extra.data) == {"extra__test3__custom_field": "custom_field_val3"} + assert json.loads(mock_form.extra.data) == {field_name: "custom_field_val3"} - # Testing parameters set in both extra and custom fields (cunnection updates). + +@pytest.mark.parametrize('field_name', ['extra__test4__custom_field', 'custom_field']) +@mock.patch('airflow.utils.module_loading.import_string') +@mock.patch('airflow.providers_manager.ProvidersManager.hooks', new_callable=PropertyMock) +def test_process_form_extras_updates(mock_pm_hooks, mock_import_str, field_name): + """ + Test the handling of connection parameters set with the classic `Extra` field as well as custom fields. + The key used in the field definition returned by `get_connection_form_widgets` is stored in + attr `extra_field_name_mapping`. Whatever is defined there is what should end up in `extra` when + the form is processed. + """ + + # Testing parameters set in both extra and custom fields (connection updates). mock_form = mock.Mock() mock_form.data = { "conn_type": "test4", @@ -118,9 +199,19 @@ def test_process_form_extras(): cmv = ConnectionModelView() cmv.extra_fields = ["extra__test4__custom_field"] # Custom field + + # this is set by `lazy_add_provider_discovered_options_to_connection_form` + cmv.extra_field_name_mapping['extra__test4__custom_field'] = field_name + cmv.process_form(form=mock_form, is_created=True) - assert json.loads(mock_form.extra.data) == {"extra__test4__custom_field": "custom_field_val4"} + if field_name == 'custom_field': + assert json.loads(mock_form.extra.data) == { + "custom_field": "custom_field_val4", + "extra__test4__custom_field": "custom_field_val3", + } + else: + assert json.loads(mock_form.extra.data) == {"extra__test4__custom_field": "custom_field_val4"} def test_duplicate_connection(admin_client):