Skip to content

Commit

Permalink
Update feature to insert attributes of DbtNode into extra_context
Browse files Browse the repository at this point in the history
This would allow the user to interact with all attributes of a DbtNode
  • Loading branch information
linchun3 committed Jun 20, 2024
1 parent d1c6896 commit 1fa1e91
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_test_task_metadata(
else: # tested with node.resource_type == DbtResourceType.SEED or DbtResourceType.SNAPSHOT
task_args["select"] = node.resource_name

extra_context = {f"{node.resource_type.value}_config": node.config}
extra_context = {"dbt_node_config": node.context_dict}

elif render_config is not None: # TestBehavior.AFTER_ALL
task_args["select"] = render_config.select
Expand Down Expand Up @@ -146,7 +146,7 @@ def create_task_metadata(
args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {f"{node.resource_type.value}_config": node.config}
extra_context = {"dbt_node_config": node.context_dict}
if node.resource_type == DbtResourceType.MODEL:
task_id = f"{node.name}_run"
if use_task_group is True:
Expand Down
18 changes: 18 additions & 0 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ def name(self) -> str:
"""
return self.resource_name.replace(".", "_")

@property
def context_dict(self) -> dict[str, Any]:
"""
Returns a dictionary containing all the attributes of the DbtNode object,
ensuring that the output is JSON serializable so it can be stored in Airflow's db
"""
return {
"unique_id": self.unique_id,
"resource_type": self.resource_type.value, # convert enum to value
"depends_on": self.depends_on,
"file_path": str(self.file_path), # convert path to string
"tags": self.tags,
"config": self.config,
"has_test": self.has_test,
"resource_name": self.resource_name,
"name": self.name,
}


def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
"""Run a command in a subprocess, returning the stdout."""
Expand Down
42 changes: 39 additions & 3 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,23 +286,59 @@ def test_create_task_metadata_unsupported(caplog):
"my_model_run",
"cosmos.operators.local.DbtRunLocalOperator",
{"models": "my_model"},
{"model_config": {}},
{
"dbt_node_config": {
"unique_id": "model.my_folder.my_model",
"resource_type": "model",
"depends_on": [],
"file_path": ".",
"tags": [],
"config": {},
"has_test": False,
"resource_name": "my_model",
"name": "my_model",
}
},
),
(
f"{DbtResourceType.SOURCE.value}.my_folder.my_source",
DbtResourceType.SOURCE,
"my_source_run",
"cosmos.operators.local.DbtRunLocalOperator",
{"models": "my_source"},
{"source_config": {}},
{
"dbt_node_config": {
"unique_id": "model.my_folder.my_source",
"resource_type": "source",
"depends_on": [],
"file_path": ".",
"tags": [],
"config": {},
"has_test": False,
"resource_name": "my_source",
"name": "my_source",
}
},
),
(
f"{DbtResourceType.SNAPSHOT.value}.my_folder.my_snapshot",
DbtResourceType.SNAPSHOT,
"my_snapshot_snapshot",
"cosmos.operators.local.DbtSnapshotLocalOperator",
{"models": "my_snapshot"},
{"snapshot_config": {}},
{
"dbt_node_config": {
"unique_id": "snapshot.my_folder.my_snapshot",
"resource_type": "snapshot",
"depends_on": [],
"file_path": ".",
"tags": [],
"config": {},
"has_test": False,
"resource_name": "my_snapshot",
"name": "my_snapshot",
},
},
),
],
)
Expand Down
41 changes: 41 additions & 0 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,47 @@ def test_dbt_node_name_and_select(unique_id, expected_name, expected_select):
assert node.resource_name == expected_select


@pytest.mark.parametrize(
"unique_id,expected_dict",
[
(
"model.my_project.customers",
{
"unique_id": "model.my_project.customers",
"resource_type": "model",
"depends_on": [],
"file_path": ".",
"tags": [],
"config": {},
"has_test": False,
"resource_name": "customers",
"name": "customers",
},
),
(
"model.my_project.customers.v1",
{
"unique_id": "model.my_project.customers_v1",
"resource_type": "model",
"depends_on": [],
"file_path": ".",
"tags": [],
"config": {},
"has_test": False,
"resource_name": "customers.v1",
"name": "customers_v1",
},
),
],
)
def test_dbt_node_context_dict(
unique_id,
expected_dict,
):
node = DbtNode(unique_id=unique_id, resource_type=DbtResourceType.MODEL, depends_on=[], file_path="")
assert node.context_dict == expected_dict


@pytest.mark.parametrize(
"project_name,manifest_filepath,model_filepath",
[(DBT_PROJECT_NAME, SAMPLE_MANIFEST, "customers.sql"), ("jaffle_shop_python", SAMPLE_MANIFEST_PY, "customers.py")],
Expand Down

0 comments on commit 1fa1e91

Please sign in to comment.