diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index bc3dcecacfbbe..c79ba6a4d1b32 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -28,7 +28,7 @@
 from airflow.exceptions import AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
+from airflow.secrets.local_filesystem import load_connections_dict
 from airflow.utils import cli as cli_utils, yaml
 from airflow.utils.cli import suppress_logs_and_warning
 from airflow.utils.session import create_session
@@ -238,7 +238,7 @@ def connections_delete(args):
 
 @cli_utils.action_logging
 def connections_import(args):
-    """Imports connections from a given file"""
+    """Imports connections from a file"""
     if os.path.exists(args.file):
         _import_helper(args.file)
     else:
@@ -246,31 +246,14 @@
 
 
 def _import_helper(file_path):
-    """Helps import connections from a file"""
+    """Load connections from a file and save them to the DB. On collision, skip."""
     connections_dict = load_connections_dict(file_path)
     with create_session() as session:
-        for conn_id, conn_values in connections_dict.items():
+        for conn_id, conn in connections_dict.items():
             if session.query(Connection).filter(Connection.conn_id == conn_id).first():
                 print(f'Could not import connection {conn_id}: connection already exists.')
                 continue
-            allowed_fields = [
-                'extra',
-                'description',
-                'conn_id',
-                'login',
-                'conn_type',
-                'host',
-                'password',
-                'schema',
-                'port',
-                'uri',
-                'extra_dejson',
-            ]
-            filtered_connection_values = {
-                key: value for key, value in conn_values.items() if key in allowed_fields
-            }
-            connection = _create_connection(conn_id, filtered_connection_values)
-            session.add(connection)
+            session.add(conn)
             session.commit()
             print(f'Imported connection {conn_id}')
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index da0a6ce5b5bc2..d3000876bc781 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -19,7 +19,7 @@
 import json
 import warnings
 from json import JSONDecodeError
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse
 
 from sqlalchemy import Boolean, Column, Integer, String, Text
@@ -117,12 +117,14 @@ def __init__(  # pylint: disable=too-many-arguments
         password: Optional[str] = None,
         schema: Optional[str] = None,
         port: Optional[int] = None,
-        extra: Optional[str] = None,
+        extra: Optional[Union[str, dict]] = None,
         uri: Optional[str] = None,
     ):
         super().__init__()
         self.conn_id = conn_id
         self.description = description
+        if extra and not isinstance(extra, str):
+            extra = json.dumps(extra)
         if uri and (  # pylint: disable=too-many-boolean-expressions
             conn_type or host or login or password or schema or port or extra
         ):
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 136811dbe3e44..5339083155eee 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -758,9 +758,9 @@ def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
     ):
         connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))
 
-    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
     @mock.patch('os.path.exists')
-    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
+    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file):
         mock_exists.return_value = True
 
         # Sample connections to import
@@ -769,26 +769,26 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
             "new0": {
                 "conn_type": "postgres",
                 "description": "new0 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 5432,
                 "schema": "airflow",
+                "extra": "test",
             },
             "new1": {
                 "conn_type": "mysql",
                 "description": "new1 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 3306,
                 "schema": "airflow",
+                "extra": "test",
             },
         }
 
-        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
-        mock_load_connections_dict.return_value = expected_connections
+        # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
+        mock_parse_secret_file.return_value = expected_connections
 
         connection_command.connections_import(
             self.parser.parse_args(["connections", "import", 'sample.json'])
@@ -799,14 +799,15 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
         current_conns = session.query(Connection).all()
 
         comparable_attrs = [
+            "conn_id",
             "conn_type",
             "description",
             "host",
-            "is_encrypted",
-            "is_extra_encrypted",
             "login",
+            "password",
             "port",
             "schema",
+            "extra",
         ]
 
         current_conns_as_dicts = {
@@ -816,80 +817,81 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
         assert expected_connections == current_conns_as_dicts
 
     @provide_session
-    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
+    @mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
     @mock.patch('os.path.exists')
     def test_cli_connections_import_should_not_overwrite_existing_connections(
-        self, mock_exists, mock_load_connections_dict, session=None
+        self, mock_exists, mock_parse_secret_file, session=None
    ):
         mock_exists.return_value = True
 
-        # Add a pre-existing connection "new1"
+        # Add a pre-existing connection "new3"
         merge_conn(
             Connection(
-                conn_id="new1",
+                conn_id="new3",
                 conn_type="mysql",
-                description="mysql description",
+                description="original description",
                 host="mysql",
                 login="root",
-                password="",
+                password="password",
                 schema="airflow",
             ),
             session=session,
         )
 
-        # Sample connections to import, including a collision with "new1"
+        # Sample connections to import, including a collision with "new3"
         expected_connections = {
-            "new0": {
+            "new2": {
                 "conn_type": "postgres",
-                "description": "new0 description",
+                "description": "new2 description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "password",
                 "port": 5432,
                 "schema": "airflow",
+                "extra": "test",
             },
-            "new1": {
+            "new3": {
                 "conn_type": "mysql",
-                "description": "new1 description",
+                "description": "updated description",
                 "host": "host",
-                "is_encrypted": False,
-                "is_extra_encrypted": False,
                 "login": "airflow",
+                "password": "new password",
                 "port": 3306,
                 "schema": "airflow",
+                "extra": "test",
             },
         }
 
-        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
-        mock_load_connections_dict.return_value = expected_connections
+        # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
+        mock_parse_secret_file.return_value = expected_connections
 
         with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_import(
                 self.parser.parse_args(["connections", "import", 'sample.json'])
             )
 
-            assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()
+            assert 'Could not import connection new3: connection already exists.' in stdout.getvalue()
 
         # Verify that the imported connections match the expected, sample connections
         current_conns = session.query(Connection).all()
 
         comparable_attrs = [
+            "conn_id",
             "conn_type",
             "description",
             "host",
-            "is_encrypted",
-            "is_extra_encrypted",
             "login",
+            "password",
             "port",
             "schema",
+            "extra",
         ]
 
         current_conns_as_dicts = {
             current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
             for current_conn in current_conns
         }
 
-        assert current_conns_as_dicts['new0'] == expected_connections['new0']
+        assert current_conns_as_dicts['new2'] == expected_connections['new2']
         # The existing connection's description should not have changed
-        assert current_conns_as_dicts['new1']['description'] == 'new1 description'
+        assert current_conns_as_dicts['new3']['description'] == 'original description'