Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix source seed selection docs generate #9454

Merged
merged 8 commits into from
Jan 30, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240125-155641.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix seed and source selection in `dbt docs generate`
time: 2024-01-25T15:56:41.557934-05:00
custom:
Author: michelleark
Issue: "9161"
74 changes: 51 additions & 23 deletions core/dbt/task/docs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from dbt.task.compile import CompileTask

from dbt.adapters.factory import get_adapter

from dbt.graph.graph import UniqueId
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.artifacts.schemas.results import NodeStatus
Expand All @@ -30,7 +32,7 @@
from dbt_common.exceptions import DbtInternalError
from dbt.exceptions import AmbiguousCatalogMatchError
from dbt.graph import ResourceTypeSelector
from dbt.node_types import EXECUTABLE_NODE_TYPES
from dbt.node_types import EXECUTABLE_NODE_TYPES, NodeType
from dbt_common.events.functions import fire_event
from dbt.adapters.events.types import (
WriteCatalogFailure,
Expand Down Expand Up @@ -112,8 +114,13 @@
table.columns[column.name] = column

def make_unique_id_map(
self, manifest: Manifest
self, manifest: Manifest, selected_node_ids: Optional[Set[UniqueId]] = None
) -> Tuple[Dict[str, CatalogTable], Dict[str, CatalogTable]]:
"""
Create mappings between CatalogKeys and CatalogTables for nodes and sources, filtered by selected_node_ids.

By default, selected_node_ids is None and all nodes and sources defined in the manifest are included in the mappings.
"""
nodes: Dict[str, CatalogTable] = {}
sources: Dict[str, CatalogTable] = {}

Expand All @@ -123,7 +130,8 @@
key = table.key()
if key in node_map:
unique_id = node_map[key]
nodes[unique_id] = table.replace(unique_id=unique_id)
if selected_node_ids is None or unique_id in selected_node_ids:
nodes[unique_id] = table.replace(unique_id=unique_id)

unique_ids = source_map.get(table.key(), set())
for unique_id in unique_ids:
Expand All @@ -133,7 +141,7 @@
sources[unique_id].to_dict(omit_none=True),
table.to_dict(omit_none=True),
)
else:
elif selected_node_ids is None or unique_id in selected_node_ids:
sources[unique_id] = table.replace(unique_id=unique_id)
return nodes, sources

Expand Down Expand Up @@ -238,9 +246,11 @@
if self.manifest is None:
raise DbtInternalError("self.manifest was None in run!")

selected_node_ids: Optional[Set[UniqueId]] = None
if self.args.empty_catalog:
catalog_table: agate.Table = agate.Table([])
exceptions: List[Exception] = []
selected_node_ids = set()
else:
adapter = get_adapter(self.config)
with adapter.connection_named("generate_catalog"):
Expand All @@ -251,14 +261,19 @@
selected_node_ids = self.job_queue.get_selected_nodes()
selected_nodes = self._get_nodes_from_ids(self.manifest, selected_node_ids)

source_ids = self._get_nodes_from_ids(
self.manifest, self.manifest.sources.keys()
# Source selection is handled separately from main job_queue selection because
# SourceDefinition nodes cannot be safely compiled / run by the CompileRunner / CompileTask,
# but should still be included in the catalog based on the selection spec
selected_source_ids = self._get_selected_source_ids()
selected_source_nodes = self._get_nodes_from_ids(
self.manifest, selected_source_ids
)
selected_nodes.extend(source_ids)
selected_node_ids.update(selected_source_ids)
selected_nodes.extend(selected_source_nodes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can we use tuple unpacking here?

selected_nodes = (*selected_nodes, *selected_source_nodes)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This led to some unsavory mypy issues, so I left it as-is.


relations = {
adapter.Relation.create_from(adapter.config, node_id)
for node_id in selected_nodes
adapter.Relation.create_from(adapter.config, node)
for node in selected_nodes
}

# This generates the catalog as an agate.Table
Expand All @@ -285,7 +300,7 @@
if exceptions:
errors = [str(e) for e in exceptions]

nodes, sources = catalog.make_unique_id_map(self.manifest)
nodes, sources = catalog.make_unique_id_map(self.manifest, selected_node_ids)
results = self.get_catalog_results(
nodes=nodes,
sources=sources,
Expand Down Expand Up @@ -322,19 +337,6 @@
fire_event(CatalogWritten(path=os.path.abspath(catalog_path)))
return results

# Resolve unique ids to the corresponding manifest entries.
# - An id found in manifest.nodes is kept only when the node is relational and is not
#   an ephemeral model.
# - An id found in manifest.sources is always kept.
# - An id found in neither mapping is silently ignored.
# The returned list follows the iteration order of node_ids.
@staticmethod
def _get_nodes_from_ids(manifest: Manifest, node_ids: Iterable[str]) -> List[ResultNode]:
selected: List[ResultNode] = []
for unique_id in node_ids:
# manifest.nodes is checked first; sources are only consulted for ids that are not nodes.
if unique_id in manifest.nodes:
node = manifest.nodes[unique_id]
if node.is_relational and not node.is_ephemeral_model:
selected.append(node)
elif unique_id in manifest.sources:
source = manifest.sources[unique_id]
selected.append(source)
return selected

def get_node_selector(self) -> ResourceTypeSelector:
if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to perform node selection")
Expand Down Expand Up @@ -373,3 +375,29 @@
return True

return super().interpret_results(compile_results)

@staticmethod
def _get_nodes_from_ids(manifest: Manifest, node_ids: Iterable[str]) -> List[ResultNode]:
    """Resolve unique ids to the corresponding manifest entries.

    An id found in ``manifest.nodes`` is kept only when the node is relational and is
    not an ephemeral model. An id found in ``manifest.sources`` is always kept. Ids
    present in neither mapping are ignored. The result follows the iteration order of
    ``node_ids``.
    """
    matches: List[ResultNode] = []
    for node_id in node_ids:
        if node_id in manifest.nodes:
            candidate = manifest.nodes[node_id]
            # Skip nodes that are not relational, and ephemeral models.
            if not candidate.is_relational or candidate.is_ephemeral_model:
                continue
            matches.append(candidate)
        elif node_id in manifest.sources:
            matches.append(manifest.sources[node_id])
    return matches

# Return the unique ids of the sources matched by this task's selection spec.
# Builds a ResourceTypeSelector limited to NodeType.Source and reads the selected ids
# from the graph queue it produces for self.get_selection_spec().
# Raises DbtInternalError when self.manifest or self.graph is None.
def _get_selected_source_ids(self) -> Set[UniqueId]:
if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to perform node selection")

Check warning on line 394 in core/dbt/task/docs/generate.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/docs/generate.py#L394

Added line #L394 was not covered by tests

# Restrict selection to source nodes only.
source_selector = ResourceTypeSelector(
graph=self.graph,
manifest=self.manifest,
previous_state=self.previous_state,
resource_types=[NodeType.Source],
)

return source_selector.get_graph_queue(self.get_selection_spec()).get_selected_nodes()
80 changes: 68 additions & 12 deletions tests/functional/docs/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

# YAML source definitions used by the docs-generate tests; the schema value is the
# template expression {{ target.schema }}.
# NOTE(review): this rendering appears to list both the pre-change and post-change
# entries of the diff -- confirm against the real file before relying on the exact list.
sample_config = """
sources:
- name: my_seed
- name: my_source_schema
schema: "{{ target.schema }}"
tables:
- name: sample_seed
- name: second_seed
- name: fake_seed
- name: sample_source
- name: second_source
- name: non_existent_source
- name: source_from_seed
"""


Expand Down Expand Up @@ -76,11 +77,18 @@ def test_select_limits_no_match(self, project):
run_dbt(["run"])
catalog = run_dbt(["docs", "generate", "--select", "my_missing_model"])
assert len(catalog.nodes) == 0
assert len(catalog.sources) == 0


class TestGenerateCatalogWithSources(TestBaseGenerate):
def test_catalog_with_sources(self, project):
# populate sources other than non_existent_source
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# build nodes
run_dbt(["build"])

catalog = run_dbt(["docs", "generate"])

# 2 seeds + 2 models
Expand All @@ -91,7 +99,7 @@ def test_catalog_with_sources(self, project):

class TestGenerateCatalogWithExternalNodes(TestBaseGenerate):
@mock.patch("dbt.plugins.get_plugin_manager")
def test_catalog_with_sources(self, get_plugin_manager, project):
def test_catalog_with_external_node(self, get_plugin_manager, project):
project.run_sql("create table {}.external_model (id int)".format(project.test_schema))

run_dbt(["build"])
Expand All @@ -112,13 +120,61 @@ def test_catalog_with_sources(self, get_plugin_manager, project):


# Checks that `dbt docs generate --select source:...` limits the catalog to the
# selected source and leaves unselected nodes out.
class TestGenerateSelectSource(TestBaseGenerate):
# Three seeds sharing the same CSV contents; one of them (source_from_seed) has the
# same name as a table listed under the test's source definition.
@pytest.fixture(scope="class")
def seeds(self):
return {
"sample_seed.csv": sample_seed,
"second_seed.csv": sample_seed,
"source_from_seed.csv": sample_seed,
}

def test_select_source(self, project):
run_dbt(["build"])
# NOTE(review): the lines from here through the second TODO assertion look like the
# pre-change side of this diff -- confirm before treating them as live test code.
catalog = run_dbt(["docs", "generate", "--select", "source:test.my_seed.sample_seed"])

# 2 seeds
# TODO: Filtering doesn't work for seeds
assert len(catalog.nodes) == 2
# 2 sources
# TODO: Filtering doesn't work for sources
assert len(catalog.sources) == 2
# create the tables backing sample_source and second_source
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# 2 existing sources, 1 selected
catalog = run_dbt(
["docs", "generate", "--select", "source:test.my_source_schema.sample_source"]
)
assert len(catalog.sources) == 1
assert "source.test.my_source_schema.sample_source" in catalog.sources
# no nodes selected
assert len(catalog.nodes) == 0

# 2 existing sources, 1 selected that has the same relation as a seed
catalog = run_dbt(
["docs", "generate", "--select", "source:test.my_source_schema.source_from_seed"]
)
assert len(catalog.sources) == 1
assert "source.test.my_source_schema.source_from_seed" in catalog.sources
# the seed with the same relation was not selected, so it is not in the catalog
assert len(catalog.nodes) == 0


class TestGenerateSelectSeed(TestBaseGenerate):
    """Seed selection in `dbt docs generate` keeps only the selected seed in the catalog."""

    @pytest.fixture(scope="class")
    def seeds(self):
        """Provide three seeds that all share the same CSV contents."""
        seed_files = ("sample_seed.csv", "second_seed.csv", "source_from_seed.csv")
        return {file_name: sample_seed for file_name in seed_files}

    def test_select_seed(self, project):
        """Select one seed at a time and check that nothing else reaches the catalog."""
        run_dbt(["build"])

        # 3 seeds exist and each pass selects exactly one of them. The second selector
        # names a seed that has the same relation as a source; that source was not
        # selected, so it must stay out of the catalog.
        scenarios = (
            ("sample_seed", "seed.test.sample_seed"),
            ("source_from_seed", "seed.test.source_from_seed"),
        )
        for selector, expected_node_id in scenarios:
            catalog = run_dbt(["docs", "generate", "--select", selector])
            assert len(catalog.nodes) == 1
            assert expected_node_id in catalog.nodes
            # no sources selected
            assert len(catalog.sources) == 0
Loading