Skip to content

Commit

Permalink
Fix task owner fallback
Browse files Browse the repository at this point in the history
`dag.owner` is a computed property that joins owners of existing tasks.
We should rely on airflow's existing owner fallback in airflow.models.baseoperator.BaseOperator.

Fixes #1194
  • Loading branch information
jmaicher committed Sep 6, 2024
1 parent 1f1fc61 commit 05c0bcb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 5 deletions.
7 changes: 3 additions & 4 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,15 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
module = importlib.import_module(module_name)
Operator = getattr(module, class_name)

task_kwargs = {}
if task.owner != "":
task_owner = task.owner
else:
task_owner = dag.owner
task_kwargs["owner"] = task.owner

airflow_task = Operator(
task_id=task.id,
dag=dag,
task_group=task_group,
owner=task_owner,
**task_kwargs,
**({} if class_name == "EmptyOperator" else {"extra_context": task.extra_context}),
**task.arguments,
)
Expand Down
51 changes: 50 additions & 1 deletion tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Any
from unittest.mock import patch

import pytest
from airflow import __version__ as airflow_version
from airflow.models import DAG
from airflow.models.abstractoperator import DEFAULT_OWNER
from airflow.utils.task_group import TaskGroup
from packaging import version

Expand Down Expand Up @@ -130,8 +132,9 @@ def test_build_airflow_graph_with_after_each():
task_seed_parent_seed = dag.tasks[0]
task_parent_run = dag.tasks[1]

assert task_seed_parent_seed.owner == ""
assert task_seed_parent_seed.owner == DEFAULT_OWNER
assert task_parent_run.owner == "parent_node"
assert {d for d in dag.owner.split(", ")} == {DEFAULT_OWNER, "parent_node"}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -604,3 +607,49 @@ def test_airflow_kwargs_generation():
result = airflow_kwargs(**task_args)

assert "dag" in result


@pytest.mark.parametrize(
"dbt_extra_config,expected_owner",
[
({}, DEFAULT_OWNER),
({"meta": {}}, DEFAULT_OWNER),
({"meta": {"owner": ""}}, DEFAULT_OWNER),
({"meta": {"owner": "dbt-owner"}}, "dbt-owner"),
],
)
def test_owner(dbt_extra_config, expected_owner):
with DAG("test-task-group-after-each", start_date=datetime(2022, 1, 1)) as dag:
node = DbtNode(
unique_id=f"{DbtResourceType.MODEL.value}.my_folder.my_model",
resource_type=DbtResourceType.MODEL,
file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql",
tags=["has_child"],
config={"materialized": "view", **dbt_extra_config},
depends_on=[]
)

output: TaskGroup = generate_task_or_group(
dag=dag,
task_group=None,
node=node,
execution_mode=ExecutionMode.LOCAL,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args={
"project_dir": SAMPLE_PROJ_PATH,
"profile_config": ProfileConfig(
profile_name="default",
target_name="default",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="fake_conn",
profile_args={"schema": "public"},
),
),
},
test_behavior=TestBehavior.AFTER_EACH,
on_warning_callback=None,
source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR,
)

assert len(output.leaves) == 1
assert output.leaves[0].owner == expected_owner

0 comments on commit 05c0bcb

Please sign in to comment.