Skip to content

Commit

Permalink
Add node config to TaskInstance Context
Browse files Browse the repository at this point in the history
Add node config into TaskInstance context for retrieval
  • Loading branch information
linchun3 committed Jun 14, 2024
1 parent e4f9ece commit 5525478
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
8 changes: 8 additions & 0 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def create_test_task_metadata(
"""
task_args = dict(task_args)
task_args["on_warning_callback"] = on_warning_callback
extra_context = {}

if test_indirect_selection != TestIndirectSelection.EAGER:
task_args["indirect_selection"] = test_indirect_selection.value
if node is not None:
Expand All @@ -102,6 +104,9 @@ def create_test_task_metadata(
task_args["select"] = f"source:{node.resource_name}"
else: # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
task_args["select"] = node.resource_name

extra_context = {f"{node.resource_type.value}_config": node.config}

elif render_config is not None: # TestBehavior.AFTER_ALL
task_args["select"] = render_config.select
task_args["selector"] = render_config.selector
Expand All @@ -114,6 +119,7 @@ def create_test_task_metadata(
dbt_class="DbtTest",
),
arguments=task_args,
extra_context=extra_context,
)


Expand All @@ -140,6 +146,7 @@ def create_task_metadata(
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {f"{node.resource_type.value}_config": node.config}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
Expand All @@ -155,6 +162,7 @@ def create_task_metadata(
execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type]
),
arguments=args,
extra_context=extra_context,
)
return task_metadata
else:
Expand Down
1 change: 1 addition & 0 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
task_id=task.id,
dag=dag,
task_group=task_group,
extra_context=task.extra_context,
**task.arguments,
)

Expand Down
1 change: 1 addition & 0 deletions cosmos/core/graph/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ class Task(CosmosEntity):

operator_class: str = "airflow.operators.empty.EmptyOperator"
arguments: Dict[str, Any] = field(default_factory=dict)
extra_context: Dict[str, Any] = field(default_factory=dict)
9 changes: 9 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
:param dbt_cmd_flags: List of flags to pass to dbt command
:param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
:param cache_dir: Directory used to cache Cosmos/dbt artifacts in Airflow worker nodes
:param extra_context: A dictionary of values to add to the TaskInstance's Context
"""

template_fields: Sequence[str] = ("env", "select", "exclude", "selector", "vars", "models")
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
dbt_cmd_flags: list[str] | None = None,
dbt_cmd_global_flags: list[str] | None = None,
cache_dir: Path | None = None,
extra_context: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.project_dir = project_dir
Expand Down Expand Up @@ -139,6 +141,7 @@ def __init__(
self.dbt_cmd_flags = dbt_cmd_flags
self.dbt_cmd_global_flags = dbt_cmd_global_flags or []
self.cache_dir = cache_dir
self.extra_context = extra_context or {}
super().__init__(**kwargs)

def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
Expand Down Expand Up @@ -260,6 +263,12 @@ def build_cmd(
def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
"""Override this method for the operator to execute the dbt command"""

def pre_execute(self, context: Any) -> Any | None:  # type: ignore
    """Merge ``self.extra_context`` into the task context, then run the parent hook.

    The merge is done in place so that the very mapping object handed in by
    the caller carries the extra keys afterwards. On a key collision the
    value from ``self.extra_context`` wins.

    :param context: Mutable mapping holding the task's runtime context.
    :return: Whatever the parent class's ``pre_execute`` returns.
    """
    if self.extra_context:
        logger.info("Extra context injected into TaskInstance...")
        # Update the caller's mapping instead of ``context | extra``: the union
        # operator builds a *new* object and only rebinds the local name, so
        # anyone still holding the original context would never see the extra
        # keys. ``update`` also works for any MutableMapping, whereas ``|`` is
        # only available on types that implement it (e.g. plain dicts).
        context.update(self.extra_context)
    return super().pre_execute(context)

def execute(self, context: Context) -> Any | None:  # type: ignore
    """Run this operator's dbt command.

    Collects the command flags from ``add_cmd_flags`` and passes them,
    together with ``context``, to ``build_and_run_cmd``. The result of that
    call is discarded.

    :param context: Task context forwarded unchanged to ``build_and_run_cmd``.
    :return: Always ``None``.
    """
    run_command = self.build_and_run_cmd
    run_command(context=context, cmd_flags=self.add_cmd_flags())
    return None

Expand Down

0 comments on commit 5525478

Please sign in to comment.