From 645f504b0f2f024ed89ebe74aace2d090a786976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Szyma=C5=84ski?= Date: Fri, 26 Apr 2024 16:46:44 +0100 Subject: [PATCH] Fix global flags for lists (#863) Correctly deals with global flags when they are a list. Note - in the module there's no distinguishing which are and which aren't. Co-authored-by: Tatiana Al-Chueyr --- cosmos/operators/base.py | 35 +++++++++++++------------- tests/operators/test_local.py | 46 ++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index b9f25758d..b6e1797d8 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -195,17 +195,28 @@ def add_global_flags(self) -> list[str]: dbt_name = f"--{global_flag.replace('_', '-')}" global_flag_value = self.__getattribute__(global_flag) - if global_flag_value is not None: - if isinstance(global_flag_value, dict): - yaml_string = yaml.dump(global_flag_value) - flags.extend([dbt_name, yaml_string]) - else: - flags.extend([dbt_name, str(global_flag_value)]) + flags.extend(self._process_global_flag(dbt_name, global_flag_value)) + for global_boolean_flag in self.global_boolean_flags: if self.__getattribute__(global_boolean_flag): flags.append(f"--{global_boolean_flag.replace('_', '-')}") return flags + @staticmethod + def _process_global_flag(flag_name: str, flag_value: Any) -> list[str]: + """Helper method to process global flags and reduce complexity.""" + if flag_value is None: + return [] + elif isinstance(flag_value, dict): + yaml_string = yaml.dump(flag_value) + return [flag_name, yaml_string] + elif isinstance(flag_value, list) and flag_value: + return [flag_name, " ".join(flag_value)] + elif isinstance(flag_value, list): + return [] + else: + return [flag_name, str(flag_value)] + def add_cmd_flags(self) -> list[str]: """Allows subclasses to override to add flags for their dbt command""" return [] @@ -373,18 +384,6 @@ def __init__( self.selector = selector super().__init__(exclude=exclude, select=select, selector=selector, **kwargs) # type: ignore - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.exclude: - flags.extend(["--exclude", *self.exclude]) - - if self.select: - flags.extend(["--select", *self.select]) - - if self.selector: - flags.extend(["--selector", self.selector]) - return flags - class DbtRunOperationMixin: """ diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 17b98ee56..250b044fa 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -88,8 +88,11 @@ def test_dbt_base_operator_add_global_flags() -> None: "end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}", }, no_version_check=True, + select=["my_first_model", "my_second_model"], ) assert dbt_base_operator.add_global_flags() == [ + "--select", + "my_first_model my_second_model", "--vars", "end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n" "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n", @@ -564,28 +567,50 @@ def test_store_compiled_sql() -> None: @pytest.mark.parametrize( "operator_class,kwargs,expected_call_kwargs", [ - (DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), - (DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), - (DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}), + ( + DbtSeedLocalOperator, + {"full_refresh": True}, + {"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]}, + ), + ( + DbtBuildLocalOperator, + {"full_refresh": True}, + {"context": {}, "env": {}, "cmd_flags": ["build", "--full-refresh"]}, + ), + ( + DbtRunLocalOperator, + {"full_refresh": True}, + {"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]}, + ), + ( + DbtTestLocalOperator, + {}, + {"context": {}, "env": {}, "cmd_flags": ["test"]}, + ), + ( + DbtTestLocalOperator, + {"select": []}, + {"context": {}, "env": {}, "cmd_flags": ["test"]}, + ), ( DbtTestLocalOperator, {"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]}, - {"context": {}, "cmd_flags": ["--exclude", "tag:disabled", "--select", "tag:daily"]}, + {"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]}, ), ( DbtTestLocalOperator, {"full_refresh": True, "selector": "nightly_snowplow"}, - {"context": {}, "cmd_flags": ["--selector", "nightly_snowplow"]}, + {"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]}, ), ( DbtRunOperationLocalOperator, {"args": {"days": 7, "dry_run": True}, "macro_name": "bla"}, - {"context": {}, "cmd_flags": ["--args", "days: 7\ndry_run: true\n"]}, + {"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]}, ), ], ) -@patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd") -def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs): +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_command") +def test_operator_execute_with_flags(mock_run_cmd, operator_class, kwargs, expected_call_kwargs): task = operator_class( profile_config=profile_config, task_id="my-task", @@ -593,8 +618,11 @@ def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwa invocation_mode=InvocationMode.DBT_RUNNER, **kwargs, ) + task.get_env = MagicMock(return_value={}) task.execute(context={}) - mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs) + mock_run_cmd.assert_called_once_with( + cmd=[task.dbt_executable_path, *expected_call_kwargs.pop("cmd_flags")], **expected_call_kwargs + ) @pytest.mark.parametrize(