Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more template fields to DbtBaseOperator #786

Merged
merged 22 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
23ed0eb
add more template fields
dwreeves Jan 7, 2024
9561911
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2024
3decda7
Merge branch 'main' into add-more-template-fields
dwreeves Feb 1, 2024
5c6210d
fix inplace mutation
dwreeves Feb 9, 2024
4395f15
Merge branch 'main' of github.com-dwreeves:dwreeves/astronomer-cosmos…
dwreeves Feb 25, 2024
206de9a
updates
dwreeves Feb 25, 2024
dcc1b71
Merge branch 'main' of github.com-dwreeves:astronomer/astronomer-cosm…
dwreeves Feb 25, 2024
4cab6e6
fix typing
dwreeves Feb 25, 2024
bce8f22
fix
dwreeves Feb 25, 2024
1729082
Merge branch 'main' into add-more-template-fields
tatiana Feb 26, 2024
9704e7f
update tests
dwreeves Feb 28, 2024
0efc4bd
Merge branch 'add-more-template-fields' of github.com-dwreeves:dwreev…
dwreeves Feb 28, 2024
2acb445
Merge branch 'main' into add-more-template-fields
dwreeves Feb 28, 2024
4bba5b3
fix something that leaked from another PR somehow
dwreeves Feb 28, 2024
6b4d930
add test for template fields
dwreeves Feb 28, 2024
a075b3c
Merge branch 'main' into add-more-template-fields
tatiana Feb 29, 2024
eb6f274
Improve test coverage
tatiana Feb 29, 2024
173eee1
Update docs/configuration/operator-args.rst
tatiana Feb 29, 2024
790ceec
Update docs/configuration/operator-args.rst
tatiana Feb 29, 2024
adaaae2
Update docs/configuration/operator-args.rst
tatiana Feb 29, 2024
5c9b975
Update docs/configuration/operator-args.rst
tatiana Feb 29, 2024
d44424b
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.strings import to_boolean

from cosmos.dbt.executable import get_system_dbt
from cosmos.log import get_logger
Expand Down Expand Up @@ -61,7 +62,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
:param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
"""

template_fields: Sequence[str] = ("env", "vars")
template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models")
global_flags = (
"project_dir",
"select",
Expand Down Expand Up @@ -253,6 +254,26 @@ class DbtBuildMixin:
base_cmd = ["build"]
ui_color = "#8194E0"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []

if isinstance(self.full_refresh, str):
# Handle template fields when render_template_as_native_obj=False
full_refresh = to_boolean(self.full_refresh)
else:
full_refresh = self.full_refresh

if full_refresh is True:
flags.append("--full-refresh")

return flags


class DbtLSMixin:
"""
Expand All @@ -275,13 +296,20 @@ class DbtSeedMixin:

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:

if isinstance(self.full_refresh, str):
# Handle template fields when render_template_as_native_obj=False
full_refresh = to_boolean(self.full_refresh)
else:
full_refresh = self.full_refresh

if full_refresh is True:
flags.append("--full-refresh")

return flags
Expand All @@ -307,13 +335,20 @@ class DbtRunMixin:

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:

if isinstance(self.full_refresh, str):
# Handle template fields when render_template_as_native_obj=False
full_refresh = to_boolean(self.full_refresh)
else:
full_refresh = self.full_refresh

if full_refresh is True:
flags.append("--full-refresh")

return flags
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class DbtBuildDockerOperator(DbtBuildMixin, DbtDockerBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator):
"""
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class DbtBuildKubernetesOperator(DbtBuildMixin, DbtKubernetesBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSKubernetesOperator(DbtLSMixin, DbtKubernetesBaseOperator):
"""
Expand Down
2 changes: 2 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtBuildMixin.template_fields # type: ignore[operator]


class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator):
"""
Expand Down
2 changes: 1 addition & 1 deletion cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def execute(self, context: Context) -> None:
logger.info(output)


class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator):
class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator): # type: ignore[misc]
"""
Executes a dbt core build command within a Python Virtual Environment, that is created before running the dbt command
and deleted just after.
Expand Down
32 changes: 32 additions & 0 deletions docs/configuration/operator-args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dbt-related
- ``quiet``: run ``dbt`` in silent mode, only displaying its error logs.
- ``vars``: (Deprecated since Cosmos 1.3 use ``ProjectConfig.dbt_vars`` instead) Supply variables to the project. This argument overrides variables defined in the ``dbt_project.yml``.
- ``warn_error``: convert ``dbt`` warnings into errors.
- ``full_refresh``: If True, then full refresh the node. This only applies to model and seed nodes.

Airflow-related
...............
Expand Down Expand Up @@ -88,3 +89,34 @@ Sample usage
"skip_exit_code": 1,
}
)


Template fields
---------------

Some of the operator args are `template fields <https://airflow.apache.org/docs/apache-airflow/stable/howto/custom-operator.html#templating>`_ for your convenience.

These template fields can be useful for hooking into Airflow `Params <https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/params.html>`_, or for more advanced customization with `XComs <https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/xcoms.html>`_.

The following operator args support templating, and are accessible both through the ``DbtDag`` and ``DbtTaskGroup`` constructors in addition to being accessible standalone:

- ``env``
- ``vars``
- ``full_refresh`` (for the ``build``, ``seed``, and ``run`` operators.)
tatiana marked this conversation as resolved.
Show resolved Hide resolved

.. note::
Using Jinja templating for ``env`` and ``vars`` may cause problems when using ``LoadMode.DBT_LS`` to render your DAG.

The following template fields are only selectable when using the operators in a standalone context:
tatiana marked this conversation as resolved.
Show resolved Hide resolved

- ``select``
- ``exclude``
- ``selector``
- ``models``

The aforementioned args are not available to be templated via ``DbtDag`` and ``DbtTaskGroup`` because they need to select dbt nodes to render the DAG's tasks.
tatiana marked this conversation as resolved.
Show resolved Hide resolved
Since template fields are rendered on each ``DagRun``,
tatiana marked this conversation as resolved.
Show resolved Hide resolved

Additionally, the SQL for compiled dbt models is stored in the template fields, which is viewable in the Airflow UI for each task run.
This is provided for telemetry on task execution, and is not an operator arg.
For more information about this, see the `Compiled SQL <compiled-sql.html>`_ docs.
6 changes: 4 additions & 2 deletions tests/operators/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class):
assert [dbt_command] == dbt_operator_class.base_cmd


@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin])
@pytest.mark.parametrize("full_refresh, expected_flags", [(True, ["--full-refresh"]), (False, [])])
@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin, DbtBuildMixin])
@pytest.mark.parametrize(
"full_refresh, expected_flags", [("True", ["--full-refresh"]), (True, ["--full-refresh"]), (False, [])]
)
def test_dbt_mixin_add_cmd_flags_full_refresh(full_refresh, expected_flags, dbt_operator_class):
dbt_mixin = dbt_operator_class(full_refresh=full_refresh)
flags = dbt_mixin.add_cmd_flags()
Expand Down
15 changes: 13 additions & 2 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def test_store_compiled_sql() -> None:
"operator_class,kwargs,expected_call_kwargs",
[
(DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(
DbtTestLocalOperator,
Expand Down Expand Up @@ -650,8 +651,18 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo
@pytest.mark.parametrize(
"operator_class,expected_template",
[
(DbtSeedLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")),
(DbtRunLocalOperator, ("env", "vars", "compiled_sql", "full_refresh")),
(
DbtSeedLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
(
DbtRunLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
(
DbtBuildLocalOperator,
("env", "select", "exclude", "selector", "vars", "models", "compiled_sql", "full_refresh"),
),
],
)
def test_dbt_base_operator_template_fields(operator_class, expected_template):
Expand Down
Loading