Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix global flags for lists #863

Merged
merged 12 commits into from
Apr 26, 2024
35 changes: 17 additions & 18 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
@@ -195,17 +195,28 @@ def add_global_flags(self) -> list[str]:

dbt_name = f"--{global_flag.replace('_', '-')}"
global_flag_value = self.__getattribute__(global_flag)
if global_flag_value is not None:
if isinstance(global_flag_value, dict):
yaml_string = yaml.dump(global_flag_value)
flags.extend([dbt_name, yaml_string])
else:
flags.extend([dbt_name, str(global_flag_value)])
flags.extend(self._process_global_flag(dbt_name, global_flag_value))

for global_boolean_flag in self.global_boolean_flags:
if self.__getattribute__(global_boolean_flag):
flags.append(f"--{global_boolean_flag.replace('_', '-')}")
return flags

@staticmethod
def _process_global_flag(flag_name: str, flag_value: Any) -> list[str]:
"""Helper method to process global flags and reduce complexity."""
if flag_value is None:
return []
elif isinstance(flag_value, dict):
yaml_string = yaml.dump(flag_value)
return [flag_name, yaml_string]
elif isinstance(flag_value, list) and flag_value:
return [flag_name, " ".join(flag_value)]
elif isinstance(flag_value, list):
return []
else:
return [flag_name, str(flag_value)]

def add_cmd_flags(self) -> list[str]:
"""Allows subclasses to override to add flags for their dbt command"""
return []
@@ -373,18 +384,6 @@ def __init__(
self.selector = selector
super().__init__(exclude=exclude, select=select, selector=selector, **kwargs) # type: ignore

def add_cmd_flags(self) -> list[str]:
flags = []
if self.exclude:
flags.extend(["--exclude", *self.exclude])

if self.select:
flags.extend(["--select", *self.select])

if self.selector:
flags.extend(["--selector", self.selector])
return flags


class DbtRunOperationMixin:
"""
46 changes: 37 additions & 9 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
@@ -88,8 +88,11 @@ def test_dbt_base_operator_add_global_flags() -> None:
"end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}",
},
no_version_check=True,
select=["my_first_model", "my_second_model"],
)
assert dbt_base_operator.add_global_flags() == [
"--select",
"my_first_model my_second_model",
"--vars",
"end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n"
"start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n",
@@ -564,37 +567,62 @@ def test_store_compiled_sql() -> None:
@pytest.mark.parametrize(
"operator_class,kwargs,expected_call_kwargs",
[
(DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(
DbtSeedLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]},
),
(
DbtBuildLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["build", "--full-refresh"]},
),
(
DbtRunLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]},
),
(
DbtTestLocalOperator,
{},
{"context": {}, "env": {}, "cmd_flags": ["test"]},
),
(
DbtTestLocalOperator,
{"select": []},
{"context": {}, "env": {}, "cmd_flags": ["test"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]},
{"context": {}, "cmd_flags": ["--exclude", "tag:disabled", "--select", "tag:daily"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "selector": "nightly_snowplow"},
{"context": {}, "cmd_flags": ["--selector", "nightly_snowplow"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]},
),
(
DbtRunOperationLocalOperator,
{"args": {"days": 7, "dry_run": True}, "macro_name": "bla"},
{"context": {}, "cmd_flags": ["--args", "days: 7\ndry_run: true\n"]},
{"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]},
),
],
)
@patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd")
def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs):
@patch("cosmos.operators.local.DbtLocalBaseOperator.run_command")
def test_operator_execute_with_flags(mock_run_cmd, operator_class, kwargs, expected_call_kwargs):
ms32035 marked this conversation as resolved.
Show resolved Hide resolved
task = operator_class(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
invocation_mode=InvocationMode.DBT_RUNNER,
**kwargs,
)
task.get_env = MagicMock(return_value={})
task.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs)
mock_run_cmd.assert_called_once_with(
cmd=[task.dbt_executable_path, *expected_call_kwargs.pop("cmd_flags")], **expected_call_kwargs
)


@pytest.mark.parametrize(