diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py index 8a009b5b824b0..b63e1e042b8cc 100644 --- a/airflow/cli/commands/connection_command.py +++ b/airflow/cli/commands/connection_command.py @@ -144,6 +144,7 @@ def connections_export(args): connections = session.query(Connection).order_by(Connection.conn_id).all() msg = _format_connections(connections, filetype) args.file.write(msg) + args.file.close() if _is_stdout(args.file): print("Connections successfully exported.", file=sys.stderr) diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py index 5339083155eee..5a8cac0debda7 100644 --- a/tests/cli/commands/test_connection_command.py +++ b/tests/cli/commands/test_connection_command.py @@ -18,12 +18,10 @@ import io import json import re -import unittest from contextlib import redirect_stdout from unittest import mock import pytest -from parameterized import parameterized from airflow.cli import cli_parser from airflow.cli.commands import connection_command @@ -34,13 +32,17 @@ from tests.test_utils.db import clear_db_connections -class TestCliGetConnection(unittest.TestCase): - def setUp(self): - self.parser = cli_parser.get_parser() - clear_db_connections() +@pytest.fixture(scope='module', autouse=True) +def clear_connections(): + yield + clear_db_connections(add_default_connections_back=False) - def tearDown(self): - clear_db_connections() + +class TestCliGetConnection: + parser = cli_parser.get_parser() + + def setup_method(self): + clear_db_connections(add_default_connections_back=True) def test_cli_connection_get(self): with redirect_stdout(io.StringIO()) as stdout: @@ -55,68 +57,26 @@ def test_cli_connection_get_invalid(self): connection_command.connections_get(self.parser.parse_args(["connections", "get", "INVALID"])) -class TestCliListConnections(unittest.TestCase): +class TestCliListConnections: + parser = cli_parser.get_parser() EXPECTED_CONS = [ - ( - 'airflow_db', - 'mysql', - ), - ( - 'google_cloud_default', - 'google_cloud_platform', - ), - ( - 'http_default', - 'http', - ), - ( - 'local_mysql', - 'mysql', - ), - ( - 'mongo_default', - 'mongo', - ), - ( - 'mssql_default', - 'mssql', - ), - ( - 'mysql_default', - 'mysql', - ), - ( - 'pinot_broker_default', - 'pinot', - ), - ( - 'postgres_default', - 'postgres', - ), - ( - 'presto_default', - 'presto', - ), - ( - 'sqlite_default', - 'sqlite', - ), - ( - 'trino_default', - 'trino', - ), - ( - 'vertica_default', - 'vertica', - ), + ('airflow_db', 'mysql'), + ('google_cloud_default', 'google_cloud_platform'), + ('http_default', 'http'), + ('local_mysql', 'mysql'), + ('mongo_default', 'mongo'), + ('mssql_default', 'mssql'), + ('mysql_default', 'mysql'), + ('pinot_broker_default', 'pinot'), + ('postgres_default', 'postgres'), + ('presto_default', 'presto'), + ('sqlite_default', 'sqlite'), + ('trino_default', 'trino'), + ('vertica_default', 'vertica'), ] - def setUp(self): - self.parser = cli_parser.get_parser() - clear_db_connections() - - def tearDown(self): - clear_db_connections() + def setup_method(self): + clear_db_connections(add_default_connections_back=True) def test_cli_connections_list_as_json(self): args = self.parser.parse_args(["connections", "list", "--output", "json"]) @@ -133,17 +93,16 @@ def test_cli_connections_filter_conn_id(self): args = self.parser.parse_args( ["connections", "list", "--output", "json", '--conn-id', 'http_default'] ) - with redirect_stdout(io.StringIO()) as stdout: connection_command.connections_list(args) stdout = stdout.getvalue() - assert "http_default" in stdout -class TestCliExportConnections(unittest.TestCase): - @provide_session - def setUp(self, session=None): +class TestCliExportConnections: + parser = cli_parser.get_parser() + + def setup_method(self): clear_db_connections(add_default_connections_back=False) merge_conn( Connection( @@ -155,7 +114,6 @@ def setUp(self, session=None): password="plainpassword", schema="airflow", ), - session, ) merge_conn( Connection( @@ -166,86 +124,41 @@ def setUp(self, session=None): port=8082, extra='{"endpoint": "druid/v2/sql"}', ), - session, ) - self.parser = cli_parser.get_parser() - - def tearDown(self): - clear_db_connections() - def test_cli_connections_export_should_return_error_for_invalid_command(self): with pytest.raises(SystemExit): - self.parser.parse_args( - [ - "connections", - "export", - ] - ) + self.parser.parse_args(["connections", "export"]) def test_cli_connections_export_should_return_error_for_invalid_format(self): with pytest.raises(SystemExit): self.parser.parse_args(["connections", "export", "--format", "invalid", "/path/to/file"]) - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_return_error_for_invalid_export_format( - self, mock_file_open, mock_splittext - ): - output_filepath = '/tmp/connections.invalid' - mock_splittext.return_value = (None, '.invalid') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) - with pytest.raises( - SystemExit, match=r"Unsupported file format. The file must have the extension .yaml, .json, .env" - ): + def test_cli_connections_export_should_return_error_for_invalid_export_format(self, tmp_path): + output_filepath = tmp_path / 'connections.invalid' + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) + with pytest.raises(SystemExit, match=r"Unsupported file format"): connection_command.connections_export(args) - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_not_called() - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_return_error_if_create_session_fails( - self, mock_session, mock_file_open, mock_splittext + def test_cli_connections_export_should_raise_error_if_create_session_fails( + self, mock_create_session, tmp_path ): - output_filepath = '/tmp/connections.json' + output_filepath = tmp_path / 'connections.json' def my_side_effect(): raise Exception("dummy exception") - mock_session.side_effect = my_side_effect - mock_splittext.return_value = (None, '.json') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + mock_create_session.side_effect = my_side_effect + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) with pytest.raises(Exception, match=r"dummy exception"): connection_command.connections_export(args) - mock_splittext.assert_not_called() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_not_called() - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_return_error_if_fetching_connections_fails( - self, mock_session, mock_file_open, mock_splittext + def test_cli_connections_export_should_raise_error_if_fetching_connections_fails( + self, mock_session, tmp_path ): - output_filepath = '/tmp/connections.json' + output_filepath = tmp_path / 'connections.json' def my_side_effect(_): raise Exception("dummy exception") @@ -253,61 +166,24 @@ def my_side_effect(_): mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = ( my_side_effect ) - mock_splittext.return_value = (None, '.json') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) with pytest.raises(Exception, match=r"dummy exception"): connection_command.connections_export(args) - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_not_called() - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) @mock.patch.object(connection_command, 'create_session') - def test_cli_connections_export_should_not_return_error_if_connections_is_empty( - self, mock_session, mock_file_open, mock_splittext + def test_cli_connections_export_should_not_raise_error_if_connections_is_empty( + self, mock_session, tmp_path ): - output_filepath = '/tmp/connections.json' - + output_filepath = tmp_path / 'connections.json' mock_session.return_value.__enter__.return_value.query.return_value.all.return_value = [] - mock_splittext.return_value = (None, '.json') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) connection_command.connections_export(args) + assert output_filepath.read_text() == '{}' - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with('{}') - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_export_as_json(self, mock_file_open, mock_splittext): - output_filepath = '/tmp/connections.json' - mock_splittext.return_value = (None, '.json') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + def test_cli_connections_export_should_export_as_json(self, tmp_path): + output_filepath = tmp_path / 'connections.json' + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) connection_command.connections_export(args) - expected_connections = json.dumps( { "airflow_db": { @@ -333,26 +209,12 @@ def test_cli_connections_export_should_export_as_json(self, mock_file_open, mock }, indent=2, ) + assert output_filepath.read_text() == expected_connections - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with(expected_connections) - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_export_as_yaml(self, mock_file_open, mock_splittext): - output_filepath = '/tmp/connections.yaml' - mock_splittext.return_value = (None, '.yaml') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + def test_cli_connections_export_should_export_as_yaml(self, tmp_path): + output_filepath = tmp_path / 'connections.yaml' + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) connection_command.connections_export(args) - expected_connections = ( "airflow_db:\n" " conn_type: mysql\n" @@ -373,84 +235,47 @@ def test_cli_connections_export_should_export_as_yaml(self, mock_file_open, mock " port: 8082\n" " schema: null\n" ) - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with(expected_connections) - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_export_as_env(self, mock_file_open, mock_splittext): - output_filepath = '/tmp/connections.env' - mock_splittext.return_value = (None, '.env') + assert output_filepath.read_text() == expected_connections + def test_cli_connections_export_should_export_as_env(self, tmp_path): + output_filepath = tmp_path / 'connections.env' args = self.parser.parse_args( [ "connections", "export", - output_filepath, + output_filepath.as_posix(), ] ) connection_command.connections_export(args) - expected_connections = [ - "airflow_db=mysql://root:plainpassword@mysql/airflow\n" - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n", - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n" - "airflow_db=mysql://root:plainpassword@mysql/airflow\n", + "airflow_db=mysql://root:plainpassword@mysql/airflow", + "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql", ] + assert output_filepath.read_text().splitlines() == expected_connections - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with(mock.ANY) - assert mock_file_open.return_value.write.call_args_list[0][0][0] in expected_connections - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension( - self, mock_file_open, mock_splittext - ): - output_filepath = '/tmp/connections.ENV' - mock_splittext.return_value = (None, '.ENV') - - args = self.parser.parse_args( - [ - "connections", - "export", - output_filepath, - ] - ) + def test_cli_connections_export_should_export_as_env_for_uppercase_file_extension(self, tmp_path): + output_filepath = tmp_path / 'connections.ENV' + args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) connection_command.connections_export(args) - expected_connections = [ - "airflow_db=mysql://root:plainpassword@mysql/airflow\n" - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n", - "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql\n" - "airflow_db=mysql://root:plainpassword@mysql/airflow\n", + "airflow_db=mysql://root:plainpassword@mysql/airflow", + "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql", ] - mock_splittext.assert_called_once() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with(mock.ANY) - assert mock_file_open.return_value.write.call_args_list[0][0][0] in expected_connections - - @mock.patch('os.path.splitext') - @mock.patch('builtins.open', new_callable=mock.mock_open()) - def test_cli_connections_export_should_force_export_as_specified_format( - self, mock_file_open, mock_splittext - ): - output_filepath = '/tmp/connections.yaml' + assert output_filepath.read_text().splitlines() == expected_connections + def test_cli_connections_export_should_force_export_as_specified_format(self, tmp_path): + output_filepath = tmp_path / 'connections.yaml' args = self.parser.parse_args( [ "connections", "export", - output_filepath, + output_filepath.as_posix(), "--format", "json", ] ) connection_command.connections_export(args) - expected_connections = json.dumps( { "airflow_db": { @@ -476,25 +301,20 @@ def test_cli_connections_export_should_force_export_as_specified_format( }, indent=2, ) - mock_splittext.assert_not_called() - mock_file_open.assert_called_once_with(output_filepath, 'w', -1, 'UTF-8', None) - mock_file_open.return_value.write.assert_called_once_with(expected_connections) + assert output_filepath.read_text() == expected_connections TEST_URL = "postgresql://airflow:airflow@host:5432/airflow" -class TestCliAddConnections(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli_parser.get_parser() - clear_db_connections() +class TestCliAddConnections: + parser = cli_parser.get_parser() - @classmethod - def tearDownClass(cls): - clear_db_connections() + def setup_method(self): + clear_db_connections(add_default_connections_back=False) - @parameterized.expand( + @pytest.mark.parametrize( + 'cmd, expected_output, expected_conn', [ ( [ @@ -629,7 +449,7 @@ def tearDownClass(cls): "schema": None, }, ), - ] + ], ) def test_cli_connection_add(self, cmd, expected_output, expected_conn): with redirect_stdout(io.StringIO()) as stdout: @@ -680,15 +500,11 @@ def test_cli_connections_add_invalid_uri(self): ) -class TestCliDeleteConnections(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli_parser.get_parser() - clear_db_connections() +class TestCliDeleteConnections: + parser = cli_parser.get_parser() - @classmethod - def tearDownClass(cls): - clear_db_connections() + def setup_method(self): + clear_db_connections(add_default_connections_back=False) @provide_session def test_cli_delete_connections(self, session=None): @@ -723,15 +539,11 @@ def test_cli_delete_invalid_connection(self): connection_command.connections_delete(self.parser.parse_args(["connections", "delete", "fake"])) -class TestCliImportConnections(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.parser = cli_parser.get_parser() - clear_db_connections(add_default_connections_back=False) +class TestCliImportConnections: + parser = cli_parser.get_parser() - @classmethod - def tearDownClass(cls): - clear_db_connections() + def setup_method(self): + clear_db_connections(add_default_connections_back=False) @mock.patch('os.path.exists') def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, mock_exists): @@ -740,16 +552,10 @@ def test_cli_connections_import_should_return_error_if_file_does_not_exist(self, with pytest.raises(SystemExit, match=r"Missing connections file."): connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath])) - @parameterized.expand( - [ - ("sample.jso",), - ("sample.yml",), - ("sample.environ",), - ] - ) + @pytest.mark.parametrize('filepath', ["sample.jso", "sample.yml", "sample.environ"]) @mock.patch('os.path.exists') def test_cli_connections_import_should_return_error_if_file_format_is_invalid( - self, filepath, mock_exists + self, mock_exists, filepath ): mock_exists.return_value = True with pytest.raises(