Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CLI connections import and migrate logic from secrets to Connection model #15425

Merged
merged 4 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
from airflow.secrets.local_filesystem import load_connections_dict
natanweinberger marked this conversation as resolved.
Show resolved Hide resolved
from airflow.utils import cli as cli_utils, yaml
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.session import create_session
Expand Down Expand Up @@ -238,39 +238,22 @@ def connections_delete(args):

@cli_utils.action_logging
def connections_import(args):
"""Imports connections from a given file"""
"""Imports connections from a file"""
if os.path.exists(args.file):
_import_helper(args.file)
else:
raise SystemExit("Missing connections file.")


def _import_helper(file_path):
"""Helps import connections from a file"""
"""Load connections from a file and save them to the DB. On collision, skip."""
connections_dict = load_connections_dict(file_path)
with create_session() as session:
for conn_id, conn_values in connections_dict.items():
for conn_id, conn in connections_dict.items():
if session.query(Connection).filter(Connection.conn_id == conn_id).first():
print(f'Could not import connection {conn_id}: connection already exists.')
continue

allowed_fields = [
'extra',
'description',
'conn_id',
'login',
'conn_type',
'host',
'password',
'schema',
'port',
'uri',
'extra_dejson',
]
filtered_connection_values = {
key: value for key, value in conn_values.items() if key in allowed_fields
}
connection = _create_connection(conn_id, filtered_connection_values)
session.add(connection)
session.add(conn)
session.commit()
print(f'Imported connection {conn_id}')
6 changes: 4 additions & 2 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import warnings
from json import JSONDecodeError
from typing import Dict, Optional
from typing import Dict, Optional, Union
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse

from sqlalchemy import Boolean, Column, Integer, String, Text
Expand Down Expand Up @@ -117,12 +117,14 @@ def __init__( # pylint: disable=too-many-arguments
password: Optional[str] = None,
schema: Optional[str] = None,
port: Optional[int] = None,
extra: Optional[str] = None,
extra: Optional[Union[str, dict]] = None,
uri: Optional[str] = None,
):
super().__init__()
self.conn_id = conn_id
self.description = description
if extra and not isinstance(extra, str):
extra = json.dumps(extra)
if uri and ( # pylint: disable=too-many-boolean-expressions
conn_type or host or login or password or schema or port or extra
):
Expand Down
66 changes: 34 additions & 32 deletions tests/cli/commands/test_connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,9 +758,9 @@ def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
):
connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))

@mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file):
mock_exists.return_value = True

# Sample connections to import
Expand All @@ -769,26 +769,26 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
"conn_type": "postgres",
"description": "new0 description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "password",
"port": 5432,
"schema": "airflow",
"extra": "test",
},
"new1": {
"conn_type": "mysql",
"description": "new1 description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "password",
"port": 3306,
"schema": "airflow",
"extra": "test",
},
}

# We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
mock_load_connections_dict.return_value = expected_connections
# We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
mock_parse_secret_file.return_value = expected_connections

connection_command.connections_import(
self.parser.parse_args(["connections", "import", 'sample.json'])
Expand All @@ -799,14 +799,15 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
current_conns = session.query(Connection).all()

comparable_attrs = [
"conn_id",
"conn_type",
"description",
"host",
"is_encrypted",
"is_extra_encrypted",
"login",
"password",
"port",
"schema",
"extra",
]

current_conns_as_dicts = {
Expand All @@ -816,80 +817,81 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
assert expected_connections == current_conns_as_dicts

@provide_session
@mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_not_overwrite_existing_connections(
self, mock_exists, mock_load_connections_dict, session=None
self, mock_exists, mock_parse_secret_file, session=None
):
mock_exists.return_value = True

# Add a pre-existing connection "new1"
# Add a pre-existing connection "new3"
merge_conn(
Connection(
conn_id="new1",
conn_id="new3",
conn_type="mysql",
description="mysql description",
description="original description",
host="mysql",
login="root",
password="",
password="password",
schema="airflow",
),
session=session,
)

# Sample connections to import, including a collision with "new1"
# Sample connections to import, including a collision with "new3"
expected_connections = {
"new0": {
"new2": {
"conn_type": "postgres",
"description": "new0 description",
"description": "new2 description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "password",
"port": 5432,
"schema": "airflow",
"extra": "test",
},
"new1": {
"new3": {
"conn_type": "mysql",
"description": "new1 description",
"description": "updated description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "new password",
"port": 3306,
"schema": "airflow",
"extra": "test",
},
}

# We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
mock_load_connections_dict.return_value = expected_connections
# We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
mock_parse_secret_file.return_value = expected_connections

with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_import(
self.parser.parse_args(["connections", "import", 'sample.json'])
)

assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()
assert 'Could not import connection new3: connection already exists.' in stdout.getvalue()

# Verify that the imported connections match the expected, sample connections
current_conns = session.query(Connection).all()

comparable_attrs = [
"conn_id",
"conn_type",
"description",
"host",
"is_encrypted",
"is_extra_encrypted",
"login",
"password",
"port",
"schema",
"extra",
]

current_conns_as_dicts = {
current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
for current_conn in current_conns
}
assert current_conns_as_dicts['new0'] == expected_connections['new0']
assert current_conns_as_dicts['new2'] == expected_connections['new2']

# The existing connection's description should not have changed
assert current_conns_as_dicts['new1']['description'] == 'new1 description'
assert current_conns_as_dicts['new3']['description'] == 'original description'