Add node config to TaskInstance Context
Add a node's config into the TaskInstance context so that we can retrieve it and use it for Airflow callbacks
linchun3 committed Jun 18, 2024
1 parent e4f9ece commit d1c6896
Showing 6 changed files with 143 additions and 8 deletions.
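Since Airflow passes the TaskInstance context to callbacks, merging the node's config into that context is what makes it reachable from a callback. A minimal sketch of the intended use (the `alerts` key and the print-based notification are illustrative assumptions, not part of this commit):

```python
def on_failure_callback(context):
    """Airflow failure callback that reads the dbt node config merged in by Cosmos."""
    # After this commit, Cosmos operators merge the node's config into the
    # TaskInstance context under "<resource_type>_config", e.g. "model_config".
    model_config = context.get("model_config") or {}
    # "alerts" is a hypothetical key a dbt project might set in a model's config.
    channel = model_config.get("alerts", "#data-alerts")
    print(f"Task {context['task_instance'].task_id} failed; alerting {channel}")
```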
8 changes: 8 additions & 0 deletions cosmos/airflow/graph.py
@@ -93,6 +93,8 @@ def create_test_task_metadata(
     """
     task_args = dict(task_args)
     task_args["on_warning_callback"] = on_warning_callback
+    extra_context = {}
+
     if test_indirect_selection != TestIndirectSelection.EAGER:
         task_args["indirect_selection"] = test_indirect_selection.value
     if node is not None:
@@ -102,6 +104,9 @@
             task_args["select"] = f"source:{node.resource_name}"
         else:  # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
             task_args["select"] = node.resource_name
+
+        extra_context = {f"{node.resource_type.value}_config": node.config}
+
     elif render_config is not None:  # TestBehavior.AFTER_ALL
         task_args["select"] = render_config.select
         task_args["selector"] = render_config.selector
@@ -114,6 +119,7 @@
             dbt_class="DbtTest",
         ),
         arguments=task_args,
+        extra_context=extra_context,
     )


@@ -140,6 +146,7 @@ def create_task_metadata(
     args = {**args, **{"models": node.resource_name}}
 
     if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
+        extra_context = {f"{node.resource_type.value}_config": node.config}
         if node.resource_type == DbtResourceType.MODEL:
             task_id = f"{node.name}_run"
             if use_task_group is True:
@@ -155,6 +162,7 @@
                 execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type]
             ),
             arguments=args,
+            extra_context=extra_context,
         )
         return task_metadata
     else:
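The key is derived from the node's resource type, so models, sources, seeds, and snapshots each land under their own context key. A quick illustration of the naming, with assumed values:

```python
# Illustration only: mirrors the f-string in the hunks above with assumed values.
resource_type_value = "model"  # i.e. DbtResourceType.MODEL.value
node_config = {"materialized": "table", "tags": ["nightly"]}

extra_context = {f"{resource_type_value}_config": node_config}
assert extra_context == {"model_config": {"materialized": "table", "tags": ["nightly"]}}
```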
1 change: 1 addition & 0 deletions cosmos/core/airflow.py
@@ -29,6 +29,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
         task_id=task.id,
         dag=dag,
         task_group=task_group,
+        extra_context=task.extra_context,
         **task.arguments,
     )

1 change: 1 addition & 0 deletions cosmos/core/graph/entities.py
@@ -59,3 +59,4 @@ class Task(CosmosEntity):
 
     operator_class: str = "airflow.operators.empty.EmptyOperator"
     arguments: Dict[str, Any] = field(default_factory=dict)
+    extra_context: Dict[str, Any] = field(default_factory=dict)
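With this field in place, the intermediate `Task` entity carries the values from graph construction through `get_airflow_task` (see the hunk above), which forwards them verbatim as the operator's `extra_context` keyword argument. A minimal sketch of a `Task` built this way, with assumed values:

```python
from cosmos.core.graph.entities import Task

# All values here are assumed for illustration.
task = Task(
    id="my_model_run",
    operator_class="cosmos.operators.local.DbtRunLocalOperator",
    arguments={"models": "my_model"},
    extra_context={"model_config": {"materialized": "view"}},
)
```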
8 changes: 7 additions & 1 deletion cosmos/operators/base.py
@@ -7,7 +7,7 @@
 
 import yaml
 from airflow.models.baseoperator import BaseOperator
-from airflow.utils.context import Context
+from airflow.utils.context import Context, context_merge
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.strings import to_boolean

@@ -63,6 +63,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
     :param dbt_cmd_flags: List of flags to pass to dbt command
     :param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
     :param cache_dir: Directory used to cache Cosmos/dbt artifacts in Airflow worker nodes
+    :param extra_context: A dictionary of values to add to the TaskInstance's Context
     """
 
     template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models")
@@ -111,6 +112,7 @@ def __init__(
         dbt_cmd_flags: list[str] | None = None,
         dbt_cmd_global_flags: list[str] | None = None,
         cache_dir: Path | None = None,
+        extra_context: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         self.project_dir = project_dir
@@ -139,6 +141,7 @@ def __init__(
         self.dbt_cmd_flags = dbt_cmd_flags
         self.dbt_cmd_global_flags = dbt_cmd_global_flags or []
         self.cache_dir = cache_dir
+        self.extra_context = extra_context or {}
         super().__init__(**kwargs)
 
     def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
@@ -261,6 +264,9 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
         """Override this method for the operator to execute the dbt command"""
 
     def execute(self, context: Context) -> Any | None:  # type: ignore
+        if self.extra_context:
+            context_merge(context, self.extra_context)
+
         self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
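`execute` relies on Airflow's `context_merge` helper, which updates the `Context` mapping in place, like `dict.update`: keys from `extra_context` are added, and same-named keys are overwritten. A small standalone sketch of that behavior, mirroring what the new tests below assert:

```python
from datetime import datetime

from airflow.utils.context import Context, context_merge

context = Context(start_date=datetime(2021, 1, 1), overwrite="to_overwrite")
context_merge(context, {"model_config": {"materialized": "table"}, "overwrite": "overwritten"})

# The merge mutates the context: existing keys stay, merged keys are added,
# and same-named keys take the merged value.
assert context["model_config"] == {"materialized": "table"}
assert context["overwrite"] == "overwritten"
```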
54 changes: 47 additions & 7 deletions tests/airflow/test_graph.py
@@ -277,19 +277,59 @@ def test_create_task_metadata_unsupported(caplog):
     assert caplog.messages[0] == expected_msg
 
 
-def test_create_task_metadata_model(caplog):
+@pytest.mark.parametrize(
+    "unique_id, resource_type, expected_id, expected_operator_class, expected_arguments, expected_extra_context",
+    [
+        (
+            f"{DbtResourceType.MODEL.value}.my_folder.my_model",
+            DbtResourceType.MODEL,
+            "my_model_run",
+            "cosmos.operators.local.DbtRunLocalOperator",
+            {"models": "my_model"},
+            {"model_config": {}},
+        ),
+        (
+            f"{DbtResourceType.SOURCE.value}.my_folder.my_source",
+            DbtResourceType.SOURCE,
+            "my_source_run",
+            "cosmos.operators.local.DbtRunLocalOperator",
+            {"models": "my_source"},
+            {"source_config": {}},
+        ),
+        (
+            f"{DbtResourceType.SNAPSHOT.value}.my_folder.my_snapshot",
+            DbtResourceType.SNAPSHOT,
+            "my_snapshot_snapshot",
+            "cosmos.operators.local.DbtSnapshotLocalOperator",
+            {"models": "my_snapshot"},
+            {"snapshot_config": {}},
+        ),
+    ],
+)
+def test_create_task_metadata_model(
+    unique_id,
+    resource_type,
+    expected_id,
+    expected_operator_class,
+    expected_arguments,
+    expected_extra_context,
+    caplog,
+):
     child_node = DbtNode(
-        unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model",
-        resource_type=DbtResourceType.MODEL,
+        unique_id=unique_id,
+        resource_type=resource_type,
         depends_on=[],
-        file_path="",
+        file_path=Path(""),
         tags=[],
         config={},
     )
 
     metadata = create_task_metadata(child_node, execution_mode=ExecutionMode.LOCAL, args={})
-    assert metadata.id == "my_model_run"
-    assert metadata.operator_class == "cosmos.operators.local.DbtRunLocalOperator"
-    assert metadata.arguments == {"models": "my_model"}
+    if metadata:
+        assert metadata.id == expected_id
+        assert metadata.operator_class == expected_operator_class
+        assert metadata.arguments == expected_arguments
+        assert metadata.extra_context == expected_extra_context
 
 
 def test_create_task_metadata_model_with_versions(caplog):
79 changes: 79 additions & 0 deletions tests/operators/test_base.py
@@ -1,7 +1,9 @@
 import sys
+from datetime import datetime
 from unittest.mock import patch
 
 import pytest
+from airflow.utils.context import Context
 
 from cosmos.operators.base import (
     AbstractDbtBaseOperator,
@@ -55,6 +57,83 @@ def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatc
     mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags)
 
 
+@patch("cosmos.operators.base.context_merge")
+def test_dbt_base_operator_context_merge_called(mock_context_merge):
+    """Tests that the base operator execute method calls the context_merge method with the expected arguments."""
+    base_operator = AbstractDbtBaseOperator(
+        task_id="fake_task",
+        project_dir="fake_dir",
+        extra_context={"extra": "extra"},
+    )
+
+    base_operator.execute(context={})
+    mock_context_merge.assert_called_once_with({}, {"extra": "extra"})
+
+
+@pytest.mark.parametrize(
+    "context, extra_context, expected_context",
+    [
+        (
+            Context(
+                start_date=datetime(2021, 1, 1),
+            ),
+            {
+                "extra": "extra",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                extra="extra",
+            ),
+        ),
+        (
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+            ),
+            {
+                "extra": "extra",
+                "extra_2": "extra_2",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+                extra="extra",
+                extra_2="extra_2",
+            ),
+        ),
+        (
+            Context(
+                overwrite="to_overwrite",
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+            ),
+            {
+                "overwrite": "overwritten",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+                overwrite="overwritten",
+            ),
+        ),
+    ],
+)
+def test_dbt_base_operator_context_merge(
+    context,
+    extra_context,
+    expected_context,
+):
+    """Tests that the base operator execute method merges extra_context into the context."""
+    base_operator = AbstractDbtBaseOperator(
+        task_id="fake_task",
+        project_dir="fake_dir",
+        extra_context=extra_context,
+    )
+
+    base_operator.execute(context=context)
+    assert context == expected_context
+
+
 @pytest.mark.parametrize(
     "dbt_command, dbt_operator_class",
     [
