From 5f5616c40700bf6f3139658f80f51473956bad12 Mon Sep 17 00:00:00 2001
From: linchun
Date: Mon, 24 Jun 2024 17:56:49 +0800
Subject: [PATCH] Add node config to TaskInstance Context (#1044)

Add the node's attributes (config, tags, etc.) to the TaskInstance context so
that callback functions in Airflow can retrieve them; the attributes are
merged into the task's context when the task executes. Since
[this PR](https://github.com/astronomer/astronomer-cosmos/pull/700/files)
appears to have been closed and I have a use case for this feature, this PR
recreates it.

We leverage the `context_merge` utility function from Airflow to merge the
extra context into the `Context` object of the `TaskInstance`.
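To illustrate the intended use (this sketch is not part of the patch itself;
the callback name and node values are invented), a callback can read the
merged attributes straight from the context. The snippet reproduces what
`execute` now does via `context_merge`:

```python
from airflow.utils.context import Context, context_merge


def on_failure_callback(context) -> None:
    # "dbt_node_config" is the key under which this patch stores the node's attributes.
    node_config = context.get("dbt_node_config", {})
    print(f"dbt node {node_config.get('unique_id')} failed (tags: {node_config.get('tags')})")


# Equivalent to what AbstractDbtBaseOperator.execute() does before running dbt:
context = Context(ds="2024-06-24")
context_merge(context, {"dbt_node_config": {"unique_id": "model.jaffle_shop.customers", "tags": ["nightly"]}})

on_failure_callback(context)
# -> dbt node model.jaffle_shop.customers failed (tags: ['nightly'])
```

Because the extra values are merged into the context rather than replacing it,
existing keys are preserved and a colliding key takes the new value (see the
`overwrite` test case below).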
Closes #698
---
 cosmos/airflow/graph.py       |  8 ++++
 cosmos/core/airflow.py        |  1 +
 cosmos/core/graph/entities.py |  1 +
 cosmos/dbt/graph.py           | 18 +++++++
 cosmos/operators/base.py      |  8 +++-
 tests/airflow/test_graph.py   | 90 ++++++++++++++++++++++++++++++++---
 tests/dbt/test_graph.py       | 41 ++++++++++++++++
 tests/operators/test_base.py  | 79 ++++++++++++++++++++++++++++++
 8 files changed, 238 insertions(+), 8 deletions(-)

diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index 905c845a6..ebae7f32f 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -93,6 +93,8 @@ def create_test_task_metadata(
     """
     task_args = dict(task_args)
     task_args["on_warning_callback"] = on_warning_callback
+    extra_context = {}
+
     if test_indirect_selection != TestIndirectSelection.EAGER:
         task_args["indirect_selection"] = test_indirect_selection.value
     if node is not None:
@@ -102,6 +104,9 @@ def create_test_task_metadata(
             task_args["select"] = f"source:{node.resource_name}"
         else:  # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
             task_args["select"] = node.resource_name
+
+        extra_context = {"dbt_node_config": node.context_dict}
+
     elif render_config is not None:  # TestBehavior.AFTER_ALL
         task_args["select"] = render_config.select
         task_args["selector"] = render_config.selector
@@ -114,6 +119,7 @@ def create_test_task_metadata(
             dbt_class="DbtTest",
         ),
         arguments=task_args,
+        extra_context=extra_context,
     )


@@ -140,6 +146,7 @@ def create_task_metadata(
         args = {**args, **{"models": node.resource_name}}

     if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
+        extra_context = {"dbt_node_config": node.context_dict}
         if node.resource_type == DbtResourceType.MODEL:
             task_id = f"{node.name}_run"
             if use_task_group is True:
@@ -155,6 +162,7 @@ def create_task_metadata(
                 execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type]
             ),
             arguments=args,
+            extra_context=extra_context,
         )
         return task_metadata
     else:
diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py
index f6f7464d8..d4a962483 100644
--- a/cosmos/core/airflow.py
+++ b/cosmos/core/airflow.py
@@ -29,6 +29,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
         task_id=task.id,
         dag=dag,
         task_group=task_group,
+        extra_context=task.extra_context,
         **task.arguments,
     )

diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py
index f88c3d6b2..3c3ee58d0 100644
--- a/cosmos/core/graph/entities.py
+++ b/cosmos/core/graph/entities.py
@@ -59,3 +59,4 @@ class Task(CosmosEntity):

     operator_class: str = "airflow.operators.empty.EmptyOperator"
     arguments: Dict[str, Any] = field(default_factory=dict)
+    extra_context: Dict[str, Any] = field(default_factory=dict)
diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py
index b38469b08..bd3181a20 100644
--- a/cosmos/dbt/graph.py
+++ b/cosmos/dbt/graph.py
@@ -70,6 +70,24 @@ def name(self) -> str:
         """
         return self.resource_name.replace(".", "_")

+    @property
+    def context_dict(self) -> dict[str, Any]:
+        """
+        Returns a dictionary containing all the attributes of the DbtNode object,
+        ensuring that the output is JSON serializable so it can be stored in Airflow's db
+        """
+        return {
+            "unique_id": self.unique_id,
+            "resource_type": self.resource_type.value,  # convert enum to value
+            "depends_on": self.depends_on,
+            "file_path": str(self.file_path),  # convert path to string
+            "tags": self.tags,
+            "config": self.config,
+            "has_test": self.has_test,
+            "resource_name": self.resource_name,
+            "name": self.name,
+        }
+

 def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
     """Run a command in a subprocess, returning the stdout."""
diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py
index e22703fb5..d0cbdd282 100644
--- a/cosmos/operators/base.py
+++ b/cosmos/operators/base.py
@@ -7,7 +7,7 @@

 import yaml
 from airflow.models.baseoperator import BaseOperator
-from airflow.utils.context import Context
+from airflow.utils.context import Context, context_merge
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.strings import to_boolean

@@ -63,6 +63,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
     :param dbt_cmd_flags: List of flags to pass to dbt command
     :param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
     :param cache_dir: Directory used to cache Cosmos/dbt artifacts in Airflow worker nodes
+    :param extra_context: A dictionary of values to add to the TaskInstance's Context
     """

     template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models")
@@ -111,6 +112,7 @@ def __init__(
         dbt_cmd_flags: list[str] | None = None,
         dbt_cmd_global_flags: list[str] | None = None,
         cache_dir: Path | None = None,
+        extra_context: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         self.project_dir = project_dir
@@ -139,6 +141,7 @@ def __init__(
         self.dbt_cmd_flags = dbt_cmd_flags
         self.dbt_cmd_global_flags = dbt_cmd_global_flags or []
         self.cache_dir = cache_dir
+        self.extra_context = extra_context or {}
         super().__init__(**kwargs)

     def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
@@ -261,6 +264,9 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
         """Override this method for the operator to execute the dbt command"""

     def execute(self, context: Context) -> Any | None:  # type: ignore
+        if self.extra_context:
+            context_merge(context, self.extra_context)
+
         self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())

diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py
index 4ef7d112c..a238475c2 100644
--- a/tests/airflow/test_graph.py
+++ b/tests/airflow/test_graph.py
@@ -277,19 +277,95 @@ def test_create_task_metadata_unsupported(caplog):
     assert caplog.messages[0] == expected_msg


-def test_create_task_metadata_model(caplog):
+@pytest.mark.parametrize(
+    "unique_id, resource_type, expected_id, expected_operator_class, expected_arguments, expected_extra_context",
+    [
+        (
+            f"{DbtResourceType.MODEL.value}.my_folder.my_model",
+            DbtResourceType.MODEL,
+            "my_model_run",
+            "cosmos.operators.local.DbtRunLocalOperator",
+            {"models": "my_model"},
+            {
+                "dbt_node_config": {
+                    "unique_id": "model.my_folder.my_model",
+                    "resource_type": "model",
+                    "depends_on": [],
"file_path": ".", + "tags": [], + "config": {}, + "has_test": False, + "resource_name": "my_model", + "name": "my_model", + } + }, + ), + ( + f"{DbtResourceType.SOURCE.value}.my_folder.my_source", + DbtResourceType.SOURCE, + "my_source_run", + "cosmos.operators.local.DbtRunLocalOperator", + {"models": "my_source"}, + { + "dbt_node_config": { + "unique_id": "model.my_folder.my_source", + "resource_type": "source", + "depends_on": [], + "file_path": ".", + "tags": [], + "config": {}, + "has_test": False, + "resource_name": "my_source", + "name": "my_source", + } + }, + ), + ( + f"{DbtResourceType.SNAPSHOT.value}.my_folder.my_snapshot", + DbtResourceType.SNAPSHOT, + "my_snapshot_snapshot", + "cosmos.operators.local.DbtSnapshotLocalOperator", + {"models": "my_snapshot"}, + { + "dbt_node_config": { + "unique_id": "snapshot.my_folder.my_snapshot", + "resource_type": "snapshot", + "depends_on": [], + "file_path": ".", + "tags": [], + "config": {}, + "has_test": False, + "resource_name": "my_snapshot", + "name": "my_snapshot", + }, + }, + ), + ], +) +def test_create_task_metadata_model( + unique_id, + resource_type, + expected_id, + expected_operator_class, + expected_arguments, + expected_extra_context, + caplog, +): child_node = DbtNode( - unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model", - resource_type=DbtResourceType.MODEL, + unique_id=unique_id, + resource_type=resource_type, depends_on=[], - file_path="", + file_path=Path(""), tags=[], config={}, ) + metadata = create_task_metadata(child_node, execution_mode=ExecutionMode.LOCAL, args={}) - assert metadata.id == "my_model_run" - assert metadata.operator_class == "cosmos.operators.local.DbtRunLocalOperator" - assert metadata.arguments == {"models": "my_model"} + if metadata: + assert metadata.id == expected_id + assert metadata.operator_class == expected_operator_class + assert metadata.arguments == expected_arguments + assert metadata.extra_context == expected_extra_context def test_create_task_metadata_model_with_versions(caplog): diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 0166dd89f..652a81482 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -71,6 +71,47 @@ def test_dbt_node_name_and_select(unique_id, expected_name, expected_select): assert node.resource_name == expected_select +@pytest.mark.parametrize( + "unique_id,expected_dict", + [ + ( + "model.my_project.customers", + { + "unique_id": "model.my_project.customers", + "resource_type": "model", + "depends_on": [], + "file_path": "", + "tags": [], + "config": {}, + "has_test": False, + "resource_name": "customers", + "name": "customers", + }, + ), + ( + "model.my_project.customers.v1", + { + "unique_id": "model.my_project.customers.v1", + "resource_type": "model", + "depends_on": [], + "file_path": "", + "tags": [], + "config": {}, + "has_test": False, + "resource_name": "customers.v1", + "name": "customers_v1", + }, + ), + ], +) +def test_dbt_node_context_dict( + unique_id, + expected_dict, +): + node = DbtNode(unique_id=unique_id, resource_type=DbtResourceType.MODEL, depends_on=[], file_path="") + assert node.context_dict == expected_dict + + @pytest.mark.parametrize( "project_name,manifest_filepath,model_filepath", [(DBT_PROJECT_NAME, SAMPLE_MANIFEST, "customers.sql"), ("jaffle_shop_python", SAMPLE_MANIFEST_PY, "customers.py")], diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index 3d39d43a7..6f4425282 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -1,7 +1,9 @@ import 
 import sys
+from datetime import datetime
 from unittest.mock import patch

 import pytest
+from airflow.utils.context import Context

 from cosmos.operators.base import (
     AbstractDbtBaseOperator,
@@ -55,6 +57,83 @@ def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatc
     mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags)


+@patch("cosmos.operators.base.context_merge")
+def test_dbt_base_operator_context_merge_called(mock_context_merge):
+    """Tests that the base operator execute method calls the context_merge method with the expected arguments."""
+    base_operator = AbstractDbtBaseOperator(
+        task_id="fake_task",
+        project_dir="fake_dir",
+        extra_context={"extra": "extra"},
+    )
+
+    base_operator.execute(context={})
+    mock_context_merge.assert_called_once_with({}, {"extra": "extra"})
+
+
+@pytest.mark.parametrize(
+    "context, extra_context, expected_context",
+    [
+        (
+            Context(
+                start_date=datetime(2021, 1, 1),
+            ),
+            {
+                "extra": "extra",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                extra="extra",
+            ),
+        ),
+        (
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+            ),
+            {
+                "extra": "extra",
+                "extra_2": "extra_2",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+                extra="extra",
+                extra_2="extra_2",
+            ),
+        ),
+        (
+            Context(
+                overwrite="to_overwrite",
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+            ),
+            {
+                "overwrite": "overwritten",
+            },
+            Context(
+                start_date=datetime(2021, 1, 1),
+                end_date=datetime(2023, 1, 1),
+                overwrite="overwritten",
+            ),
+        ),
+    ],
+)
+def test_dbt_base_operator_context_merge(
+    context,
+    extra_context,
+    expected_context,
+):
+    """Tests that the base operator execute method merges extra_context into the task context."""
+    base_operator = AbstractDbtBaseOperator(
+        task_id="fake_task",
+        project_dir="fake_dir",
+        extra_context=extra_context,
+    )
+
+    base_operator.execute(context=context)
+    assert context == expected_context
+
+
 @pytest.mark.parametrize(
     "dbt_command, dbt_operator_class",
     [