Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable use of custom conn extra fields without prefix #22607

Merged
merged 13 commits into from
Apr 25, 2022
23 changes: 11 additions & 12 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class ConnectionFormWidgetInfo(NamedTuple):
hook_class_name: str
package_name: str
field: Any
field_name: str


T = TypeVar("T", bound=Callable)
Expand Down Expand Up @@ -797,24 +798,22 @@ def _import_hook(
)

def _add_widgets(self, package_name: str, hook_class: type, widgets: Dict[str, Any]):
for field_name, field in widgets.items():
if not field_name.startswith("extra__"):
log.warning(
"The field %s from class %s does not start with 'extra__'. Ignoring it.",
field_name,
hook_class.__name__,
)
continue
if field_name in self._connection_form_widgets:
conn_type = hook_class.conn_type # type: ignore
for field_identifier, field in widgets.items():
if field_identifier.startswith('extra__'):
prefixed_field_name = field_identifier
else:
prefixed_field_name = f"extra__{conn_type}__{field_identifier}"
if prefixed_field_name in self._connection_form_widgets:
log.warning(
"The field %s from class %s has already been added by another provider. Ignoring it.",
field_name,
field_identifier,
hook_class.__name__,
)
# In case of inherited hooks this might be happening several times
continue
self._connection_form_widgets[field_name] = ConnectionFormWidgetInfo(
hook_class.__name__, package_name, field
self._connection_form_widgets[prefixed_field_name] = ConnectionFormWidgetInfo(
hook_class.__name__, package_name, field, field_identifier
)

def _add_customized_fields(self, package_name: str, hook_class: type, customized_fields: Dict):
Expand Down
4 changes: 4 additions & 0 deletions airflow/www/static/js/connection_form.js
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,11 @@ $(document).ready(() => {
});
// Check if field is a custom form field.
} else if (this.name.startsWith('extra__')) {
// prior to Airflow 2.3 custom fields were stored in the extra dict with prefix
// post-2.3 we allow to use with no prefix
// here we don't know which we are configured to use, so we populate both
extrasObj[this.name] = this.value;
extrasObj[this.name.replace(/extra__.+?__/, '')] = this.value;
} else {
outObj[this.name] = this.value;
}
Expand Down
37 changes: 23 additions & 14 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3820,6 +3820,7 @@ def _get_connection_types() -> List[Tuple[str, str]]:
)
for key, value in ProvidersManager().connection_form_widgets.items():
setattr(ConnectionForm, key, value.field)
ConnectionModelView.extra_field_name_mapping[key] = value.field_name
ConnectionModelView.add_columns.append(key)
ConnectionModelView.edit_columns.append(key)
ConnectionModelView.extra_fields.append(key)
Expand Down Expand Up @@ -3909,6 +3910,8 @@ class ConnectionModelView(AirflowModelView):

base_order = ('conn_id', 'asc')

extra_field_name_mapping: Dict[str, str] = {}

@action('muldelete', 'Delete', 'Are you sure you want to delete selected records?', single=False)
@auth.has_access(
[
Expand Down Expand Up @@ -3994,19 +3997,18 @@ def action_mulduplicate(self, connections, session=None):

def process_form(self, form, is_created):
"""Process form data."""
conn_type = form.data['conn_type']
conn_id = form.data["conn_id"]

# The extra value is the combination of custom fields for this conn_type and the Extra field.
# The extra form field with all extra values (including custom fields) is in the form being processed
# so we start with those values, and override them with anything in the custom fields.
extra = {}

extra_field = form.data.get("extra")
extra_json = form.data.get("extra")

if extra_field:
if extra_json:
try:
extra.update(json.loads(extra_field))
extra.update(json.loads(extra_json))
except (JSONDecodeError, TypeError):
flash(
Markup(
Expand All @@ -4015,18 +4017,19 @@ def process_form(self, form, is_created):
"<p>If connection parameters need to be added to <em>Extra</em>, "
"please make sure they are in the form of a single, valid JSON object.</p><br>"
"The following <em>Extra</em> parameters were <b>not</b> added to the connection:<br>"
f"{extra_field}",
f"{extra_json}",
),
category="error",
)
del extra_json

custom_fields = {
key: form.data[key]
for key in self.extra_fields
if key in form.data and key.startswith(f"extra__{conn_type}__")
}
for key in self.extra_fields:
if key in form.data and key.startswith("extra__"):
value = form.data[key]

extra.update(custom_fields)
if value:
field_name = self.extra_field_name_mapping[key]
extra[field_name] = value

if extra.keys():
form.extra.data = json.dumps(extra)
Expand All @@ -4046,10 +4049,16 @@ def prefill_form(self, form, pk):
logging.warning('extra field for %s is not a dictionary', form.data.get('conn_id', '<unknown>'))
return

for field in self.extra_fields:
value = extra_dictionary.get(field, '')
for field_key in self.extra_fields:
field_name = self.extra_field_name_mapping[field_key]
value = extra_dictionary.get(field_name, '')

if not value:
# check if connection `extra` json is using old prefixed field name style
value = extra_dictionary.get(field_key, '')

if value:
field = getattr(form, field)
field = getattr(form, field_key)
field.data = value


Expand Down
57 changes: 57 additions & 0 deletions docs/apache-airflow/howto/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,63 @@ deprecated ``hook-class-names``) in the provider meta-data, you can customize Ai

You can read more about details how to add custom provider packages in the :doc:`apache-airflow-providers:index`

Custom connection fields
------------------------

It is possible to add custom form fields in the connection add / edit views in the Airflow webserver.
Custom fields are stored in the ``Connection.extra`` field as JSON. To add a custom field, implement
method :meth:`~BaseHook.get_connection_form_widgets`. This method should return a dictionary. The keys
should be the string name of the field as it should be stored in the ``extra`` dict. The values should
be inheritors of :class:`wtforms.fields.core.Field`.

Here's an example:

.. code-block:: python

@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField

return {
"workspace": StringField(
lazy_gettext("Workspace"), widget=BS3TextFieldWidget()
),
"project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()),
}

.. note:: Custom fields no longer need the ``extra__<conn type>__`` prefix

Prior to Airflow 2.3, if you wanted a custom field in the UI, you had to prefix it with ``extra__<conn type>__``,
and this is how its value would be stored in the ``extra`` dict. From 2.3 onward, you no longer need to do this.

Method :meth:`~BaseHook.get_ui_field_behaviour` lets you customize behavior of both . For example you can
hide or relabel a field (e.g. if it's unused or re-purposed) and you can add placeholder text.

An example:

.. code-block:: python

@staticmethod
def get_ui_field_behaviour() -> Dict[str, Any]:
"""Returns custom field behaviour"""
return {
"hidden_fields": ["port", "host", "login", "schema"],
"relabeling": {},
"placeholders": {
"password": "Asana personal access token",
"extra__my_conn_type__workspace": "My workspace gid",
"extra__my_conn_type__project": "My project gid",
},
}

Note here that *here* (in contrast with ``get_connection_form_widgets``) we must add the prefix ``extra__<conn type>__`` when referencing a custom field. This is this is because it's possible to create a custom field whose name overlaps with a built-in field and we need to be able to reference it unambiguously.

Take a look at providers for examples of what you can do, for example :py:class:`~airflow.providers.jdbc.hooks.jdbc.JdbcHook`
and :py:class:`~airflow.providers.asana.hooks.jdbc.AsanaHook` both make use of this feature.

.. note:: Deprecated ``hook-class-names``

Prior to Airflow 2.2.0, the connections in providers have been exposed via ``hook-class-names`` array
Expand Down
5 changes: 5 additions & 0 deletions newsfragments/22607.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Remove requirement that custom connection UI fields be prefixed

Hooks can define custom connection fields for their connection type by implementing method ``get_connection_form_widgets``. These custom fields are presented in the web UI as additional connection attributes, but internally they are stored in the connection ``extra`` dict. For technical reasons, previously custom field when stored in ``extra`` had to be named with a prefix ``extra__<conn type>__<field name>``. This had the consequence of making it more cumbersome to define connections outside of the UI, since the prefix is tougher to read and work with. With #22607, we make it so that you can now define custom fields such that they can be read from and stored in ``extra`` without the prefix.

To enable this, update the dict returned by the ``get_connection_form_widgets`` method to remove the prefix from the keys. Internally, the providers manager will still use a prefix to ensure each custom field is globally unique, but the absence of a prefix in the returned widget dict will signal to the Web UI to read and store custom fields without the prefix. Note that this is only a change to the Web UI behavior; when updating your hook in this way, you must make sure that when your *hook* reads the ``extra`` field, it will also check for the prefixed value for backward compatibility.
50 changes: 48 additions & 2 deletions tests/core/test_providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@
# under the License.
import logging
import re
import unittest
from typing import Dict
from unittest.mock import patch

import pytest
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import BooleanField, Field, StringField

from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import HookClassProvider, ProviderInfo, ProvidersManager


class TestProviderManager(unittest.TestCase):
class TestProviderManager:
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
Expand Down Expand Up @@ -142,6 +145,49 @@ def test_connection_form_widgets(self):
connections_form_widgets = list(provider_manager.connection_form_widgets.keys())
assert len(connections_form_widgets) > 29

@pytest.mark.parametrize(
'scenario',
[
'prefix',
'no_prefix',
'both_1',
'both_2',
],
)
def test_connection_form__add_widgets_prefix_backcompat(self, scenario):
"""
When the field name is prefixed, it should be used as is.
When not prefixed, we should add the prefix
When there's a collision, the one that appears first in the list will be used.
"""

class MyHook:
conn_type = 'test'

provider_manager = ProvidersManager()
widget_field = StringField(lazy_gettext('My Param'), widget=BS3TextFieldWidget())
dummy_field = BooleanField(label=lazy_gettext('Dummy param'), description="dummy")
widgets: Dict[str, Field] = {}
if scenario == 'prefix':
widgets['extra__test__my_param'] = widget_field
elif scenario == 'no_prefix':
widgets['my_param'] = widget_field
elif scenario == 'both_1':
widgets['my_param'] = widget_field
widgets['extra__test__my_param'] = dummy_field
elif scenario == 'both_2':
widgets['extra__test__my_param'] = widget_field
widgets['my_param'] = dummy_field
else:
raise Exception('unexpected')

provider_manager._add_widgets(
package_name='abc',
hook_class=MyHook,
widgets=widgets,
)
assert provider_manager.connection_form_widgets['extra__test__my_param'].field == widget_field

def test_field_behaviours(self):
provider_manager = ProvidersManager()
connections_with_field_behaviours = list(provider_manager.field_behaviours.keys())
Expand Down
Loading