From 121aa09232a0eee38794b7f2a2c3759474aea1b6 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 20 Nov 2023 16:39:15 -0500 Subject: [PATCH 1/5] initial-commit --- cosmos/airflow/graph.py | 13 ++++++++++--- cosmos/core/airflow.py | 7 +++++-- cosmos/operators/base.py | 7 +++++++ cosmos/operators/local.py | 4 +--- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index af854d4f5..32719ff11 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -99,7 +99,10 @@ def create_test_task_metadata( def create_task_metadata( - node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_task_group: bool = False + node: DbtNode, + execution_mode: ExecutionMode, + args: dict[str, Any], + use_task_group: bool = False, ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -156,6 +159,7 @@ def generate_task_or_group( test_behavior: TestBehavior, test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None, + node_config: dict[str, Any], **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None @@ -176,7 +180,7 @@ def generate_task_or_group( if task_meta and node.resource_type != DbtResourceType.TEST: if use_task_group: with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: - task = create_airflow_task(task_meta, dag, task_group=model_task_group) + task = create_airflow_task(task_meta, dag, task_group=model_task_group, extra_context=node_config) test_meta = create_test_task_metadata( "test", execution_mode, @@ -184,12 +188,14 @@ def generate_task_or_group( task_args=task_args, node=node, on_warning_callback=on_warning_callback, + node_config=node_config, ) - test_task = create_airflow_task(test_meta, dag, task_group=model_task_group) + test_task = create_airflow_task(test_meta, dag, task_group=model_task_group, extra_context=node_config) task >> test_task task_or_group = model_task_group else: task_or_group = create_airflow_task(task_meta, dag, task_group=task_group) + return task_or_group @@ -251,6 +257,7 @@ def build_airflow_graph( test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, node=node, + node_config=node.config, ) if task_or_group is not None: logger.debug(f"Conversion of <{node.unique_id}> was successful!") diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 7c5dee328..5ac15dcef 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -6,12 +6,14 @@ from cosmos.core.graph.entities import Task from cosmos.log import get_logger - +from typing import Any logger = get_logger(__name__) -def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator: +def get_airflow_task( + task: Task, dag: DAG, task_group: "TaskGroup | None" = None, extra_context: dict[str, Any] = {} +) -> BaseOperator: """ Get the Airflow Operator class for a Task. @@ -30,6 +32,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None task_id=task.id, dag=dag, task_group=task_group, + extra_context=extra_context, **task.arguments, ) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 6d276013d..2072cf22d 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -57,6 +57,7 @@ class DbtBaseOperator(BaseOperator): (i.e. /home/astro/.pyenv/versions/dbt_venv/bin/dbt) :param dbt_cmd_flags: List of flags to pass to dbt command :param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command + :param extra_context: A dictionary of values to add to the Airflow Task context """ template_fields: Sequence[str] = ("env", "vars") @@ -105,6 +106,7 @@ def __init__( dbt_executable_path: str = get_system_dbt(), dbt_cmd_flags: list[str] | None = None, dbt_cmd_global_flags: list[str] | None = None, + extra_context: dict[str, Any] | None = None, **kwargs: Any, ) -> None: self.project_dir = project_dir @@ -132,6 +134,7 @@ def __init__( self.dbt_executable_path = dbt_executable_path self.dbt_cmd_flags = dbt_cmd_flags self.dbt_cmd_global_flags = dbt_cmd_global_flags or [] + self.extra_context = extra_context or {} super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: @@ -231,3 +234,7 @@ def build_cmd( env = self.get_env(context) return dbt_cmd, env + + def pre_execute(self, context: Any): + context["model_config"] = self.extra_context + return super().pre_execute(context) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 6eea764ad..0e7d3684c 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -236,9 +236,7 @@ def run_command( ) if is_openlineage_available: self.calculate_openlineage_events_completes(env, Path(tmp_project_dir)) - context[ - "task_instance" - ].openlineage_events_completes = self.openlineage_events_completes # type: ignore + context["task_instance"].openlineage_events_completes = self.openlineage_events_completes # type: ignore if self.emit_datasets: inlets = self.get_datasets("inputs") From 20663cb747682507beeb899bd6bde19234ef8bf8 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 20 Nov 2023 16:53:30 -0500 Subject: [PATCH 2/5] fixup --- cosmos/airflow/graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 32719ff11..716feaf05 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -188,7 +188,6 @@ def generate_task_or_group( task_args=task_args, node=node, on_warning_callback=on_warning_callback, - node_config=node_config, ) test_task = create_airflow_task(test_meta, dag, task_group=model_task_group, extra_context=node_config) task >> test_task From be5c79c7d43de85c69e380af75a50e215057d2f9 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 20 Nov 2023 18:47:49 -0500 Subject: [PATCH 3/5] add log --- cosmos/operators/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 2072cf22d..62d830116 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -236,5 +236,7 @@ def build_cmd( return dbt_cmd, env def pre_execute(self, context: Any): - context["model_config"] = self.extra_context + if self.extra_context: + logger.info("Extra context passed to operator, injecting into TaskInstance") + context["model_config"] = self.extra_context return super().pre_execute(context) From 9ef16f9bb5e3a177a8279816f568b96dade82bc0 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 20 Nov 2023 19:07:51 -0500 Subject: [PATCH 4/5] type hint the pre_execute function --- cosmos/operators/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 62d830116..3aa5dae08 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -235,7 +235,7 @@ def build_cmd( return dbt_cmd, env - def pre_execute(self, context: Any): + def pre_execute(self, context: Any) -> None: if self.extra_context: logger.info("Extra context passed to operator, injecting into TaskInstance") context["model_config"] = self.extra_context From 472287bc83666689fd78868a9d09b6642ba44494 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 21 Nov 2023 08:39:54 -0500 Subject: [PATCH 5/5] fixup type error --- cosmos/airflow/graph.py | 2 +- cosmos/core/airflow.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 716feaf05..4286a7a08 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -159,7 +159,7 @@ def generate_task_or_group( test_behavior: TestBehavior, test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None, - node_config: dict[str, Any], + node_config: dict[str, Any] | None = None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: task_or_group: BaseOperator | TaskGroup | None = None diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 5ac15dcef..ffb6137a1 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -12,7 +12,7 @@ def get_airflow_task( - task: Task, dag: DAG, task_group: "TaskGroup | None" = None, extra_context: dict[str, Any] = {} + task: Task, dag: DAG, task_group: "TaskGroup | None" = None, extra_context: dict[str, Any] | None = None ) -> BaseOperator: """ Get the Airflow Operator class for a Task.