Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import connections from a file #15177

Merged
merged 4 commits into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def _check(value):
ARG_CONN_EXPORT_FORMAT = Arg(
('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env']
)
ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file")

# providers
ARG_PROVIDER_NAME = Arg(
Expand Down Expand Up @@ -1200,6 +1201,16 @@ class GroupCommand(NamedTuple):
ARG_CONN_EXPORT_FORMAT,
),
),
ActionCommand(
name='import',
help='Import connections from a file',
description=(
"Connections can be imported from the output of the export command.\n"
"The filetype must by json, yaml or env and will be automatically inferred."
),
func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'),
args=(ARG_CONN_IMPORT,),
),
)
PROVIDERS_COMMANDS = (
ActionCommand(
Expand Down
41 changes: 41 additions & 0 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.session import create_session
Expand Down Expand Up @@ -234,3 +235,43 @@ def connections_delete(args):
else:
session.delete(to_delete)
print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")


@cli_utils.action_logging
def connections_import(args):
    """Imports connections from a given file"""
    # Guard clause: bail out early when the path does not exist.
    if not os.path.exists(args.file):
        raise SystemExit("Missing connections file.")
    _import_helper(args.file)


def _import_helper(file_path):
    """Load connections from ``file_path`` and create the new ones in the DB.

    Connections whose ``conn_id`` already exists in the database are skipped
    (never overwritten). Each successfully created connection is committed
    individually, so earlier imports survive a failure later in the file.

    :param file_path: path to a file readable by
        ``airflow.secrets.local_filesystem.load_connections_dict``
    """
    connections_dict = load_connections_dict(file_path)
    # Only these keys from the parsed file are forwarded to the Connection
    # constructor; anything else is silently dropped. Hoisted out of the loop
    # (and made a set) since it is a constant used only for membership tests.
    allowed_fields = {
        'extra',
        'description',
        'conn_id',
        'login',
        'conn_type',
        'host',
        'password',
        'schema',
        'port',
        'uri',
        'extra_dejson',
    }
    with create_session() as session:
        for conn_id, conn_values in connections_dict.items():
            # Existing connections win over the file: skip, do not overwrite.
            if session.query(Connection).filter(Connection.conn_id == conn_id).first():
                print(f'Could not import connection {conn_id}: connection already exists.')
                continue

            filtered_connection_values = {
                key: value for key, value in conn_values.items() if key in allowed_fields
            }
            connection = _create_connection(conn_id, filtered_connection_values)
            session.add(connection)
            # Commit per connection so the success message below is accurate
            # and prior imports are preserved if a later one fails.
            session.commit()
            print(f'Imported connection {conn_id}')
173 changes: 173 additions & 0 deletions tests/cli/commands/test_connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from airflow.cli import cli_parser
from airflow.cli.commands import connection_command
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils.db import merge_conn
from airflow.utils.session import create_session, provide_session
Expand Down Expand Up @@ -716,3 +717,175 @@ def test_cli_delete_invalid_connection(self):
# Attempt to delete a non-existing connection
with pytest.raises(SystemExit, match=r"Did not find a connection with `conn_id`=fake"):
connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"]))


class TestCliImportConnections(unittest.TestCase):
    """Tests for the ``airflow connections import`` CLI command."""

    @classmethod
    def setUpClass(cls):
        cls.parser = cli_parser.get_parser()
        # Start from an empty connections table so imported rows are the only
        # ones the assertions can observe.
        clear_db_connections(add_default_connections_back=False)

    @classmethod
    def tearDownClass(cls):
        clear_db_connections()

    @mock.patch('os.path.exists')
    def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists):
        """A missing file must abort with a SystemExit and a clear message."""
        mock_exists.return_value = False
        filepath = '/does/not/exist.json'
        with pytest.raises(SystemExit, match=r"Missing connections file."):
            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))

    @parameterized.expand(
        [
            ("sample.jso",),
            ("sample.yml",),
            ("sample.environ",),
        ]
    )
    @mock.patch('os.path.exists')
    def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
        self, filepath, mock_exists
    ):
        """Unsupported file extensions must raise AirflowException."""
        mock_exists.return_value = True
        with pytest.raises(
            AirflowException,
            match=r"Unsupported file format. The file must have the extension .env or .json or .yaml",
        ):
            connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))

    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
    @mock.patch('os.path.exists')
    def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
        """Importing into an empty DB creates every connection from the file."""
        mock_exists.return_value = True

        # Sample connections to import
        expected_connections = {
            "new0": {
                "conn_type": "postgres",
                "description": "new0 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "airflow",
                "port": 5432,
                "schema": "airflow",
            },
            "new1": {
                "conn_type": "mysql",
                "description": "new1 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "airflow",
                "port": 3306,
                "schema": "airflow",
            },
        }

        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
        mock_load_connections_dict.return_value = expected_connections

        connection_command.connections_import(
            self.parser.parse_args(["connections", "import", 'sample.json'])
        )

        # Verify that the imported connections match the expected, sample connections
        with create_session() as session:
            current_conns = session.query(Connection).all()

            comparable_attrs = [
                "conn_type",
                "description",
                "host",
                "is_encrypted",
                "is_extra_encrypted",
                "login",
                "port",
                "schema",
            ]

            current_conns_as_dicts = {
                current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
                for current_conn in current_conns
            }
            assert expected_connections == current_conns_as_dicts

    @provide_session
    @mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
    @mock.patch('os.path.exists')
    def test_cli_connections_import_should_not_overwrite_existing_connections(
        self, mock_exists, mock_load_connections_dict, session=None
    ):
        """A conn_id already in the DB is skipped with a warning, not replaced."""
        mock_exists.return_value = True

        # Add a pre-existing connection "new1"
        merge_conn(
            Connection(
                conn_id="new1",
                conn_type="mysql",
                description="mysql description",
                host="mysql",
                login="root",
                password="",
                schema="airflow",
            ),
            session=session,
        )

        # Sample connections to import, including a collision with "new1"
        expected_connections = {
            "new0": {
                "conn_type": "postgres",
                "description": "new0 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "airflow",
                "port": 5432,
                "schema": "airflow",
            },
            "new1": {
                "conn_type": "mysql",
                "description": "new1 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "airflow",
                "port": 3306,
                "schema": "airflow",
            },
        }

        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
        mock_load_connections_dict.return_value = expected_connections

        with redirect_stdout(io.StringIO()) as stdout:
            connection_command.connections_import(
                self.parser.parse_args(["connections", "import", 'sample.json'])
            )

        assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()

        # Verify that the imported connections match the expected, sample connections
        current_conns = session.query(Connection).all()

        comparable_attrs = [
            "conn_type",
            "description",
            "host",
            "is_encrypted",
            "is_extra_encrypted",
            "login",
            "port",
            "schema",
        ]

        current_conns_as_dicts = {
            current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
            for current_conn in current_conns
        }
        # The non-colliding connection was imported as-is.
        assert current_conns_as_dicts['new0'] == expected_connections['new0']

        # BUGFIX: the pre-existing connection must be untouched, so its
        # description is still the one set by merge_conn above
        # ("mysql description"), NOT the value from the skipped import file
        # ("new1 description"), which the original assertion wrongly expected.
        assert current_conns_as_dicts['new1']['description'] == 'mysql description'