Skip to content

Commit

Permalink
Fix source seed selection docs generate (#9454) (#9493)
Browse files Browse the repository at this point in the history
(cherry picked from commit 719a50c)
  • Loading branch information
MichelleArk authored Jan 30, 2024
1 parent 0d8e4af commit 1391363
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 33 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240125-155641.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix seed and source selection in `dbt docs generate`
time: 2024-01-25T15:56:41.557934-05:00
custom:
Author: michelleark
Issue: "9161"
72 changes: 50 additions & 22 deletions core/dbt/task/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .compile import CompileTask

from dbt.adapters.factory import get_adapter

from dbt.graph.graph import UniqueId
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.results import (
Expand Down Expand Up @@ -109,8 +111,13 @@ def add_column(self, data: PrimitiveDict):
table.columns[column.name] = column

def make_unique_id_map(
self, manifest: Manifest
self, manifest: Manifest, selected_node_ids: Optional[Set[UniqueId]] = None
) -> Tuple[Dict[str, CatalogTable], Dict[str, CatalogTable]]:
"""
Create mappings between CatalogKeys and CatalogTables for nodes and sources, filtered by selected_node_ids.
By default, selected_node_ids is None and all nodes and sources defined in the manifest are included in the mappings.
"""
nodes: Dict[str, CatalogTable] = {}
sources: Dict[str, CatalogTable] = {}

Expand All @@ -120,7 +127,8 @@ def make_unique_id_map(
key = table.key()
if key in node_map:
unique_id = node_map[key]
nodes[unique_id] = table.replace(unique_id=unique_id)
if selected_node_ids is None or unique_id in selected_node_ids:
nodes[unique_id] = table.replace(unique_id=unique_id)

unique_ids = source_map.get(table.key(), set())
for unique_id in unique_ids:
Expand All @@ -130,7 +138,7 @@ def make_unique_id_map(
sources[unique_id].to_dict(omit_none=True),
table.to_dict(omit_none=True),
)
else:
elif selected_node_ids is None or unique_id in selected_node_ids:
sources[unique_id] = table.replace(unique_id=unique_id)
return nodes, sources

Expand Down Expand Up @@ -235,9 +243,11 @@ def run(self) -> CatalogArtifact:
if self.manifest is None:
raise DbtInternalError("self.manifest was None in run!")

selected_node_ids: Optional[Set[UniqueId]] = None
if self.args.empty_catalog:
catalog_table: agate.Table = agate.Table([])
exceptions: List[Exception] = []
selected_node_ids = set()
else:
adapter = get_adapter(self.config)
with adapter.connection_named("generate_catalog"):
Expand All @@ -248,14 +258,19 @@ def run(self) -> CatalogArtifact:
selected_node_ids = self.job_queue.get_selected_nodes()
selected_nodes = self._get_nodes_from_ids(self.manifest, selected_node_ids)

source_ids = self._get_nodes_from_ids(
self.manifest, self.manifest.sources.keys()
# Source selection is handled separately from main job_queue selection because
# SourceDefinition nodes cannot be safely compiled / run by the CompileRunner / CompileTask,
# but should still be included in the catalog based on the selection spec
selected_source_ids = self._get_selected_source_ids()
selected_source_nodes = self._get_nodes_from_ids(
self.manifest, selected_source_ids
)
selected_nodes.extend(source_ids)
selected_node_ids.update(selected_source_ids)
selected_nodes.extend(selected_source_nodes)

relations = {
adapter.Relation.create_from(adapter.config, node_id)
for node_id in selected_nodes
adapter.Relation.create_from(adapter.config, node)
for node in selected_nodes
}

# This generates the catalog as an agate.Table
Expand All @@ -272,7 +287,7 @@ def run(self) -> CatalogArtifact:
if exceptions:
errors = [str(e) for e in exceptions]

nodes, sources = catalog.make_unique_id_map(self.manifest)
nodes, sources = catalog.make_unique_id_map(self.manifest, selected_node_ids)
results = self.get_catalog_results(
nodes=nodes,
sources=sources,
Expand Down Expand Up @@ -309,19 +324,6 @@ def run(self) -> CatalogArtifact:
fire_event(CatalogWritten(path=os.path.abspath(catalog_path)))
return results

@staticmethod
def _get_nodes_from_ids(manifest: Manifest, node_ids: Iterable[str]) -> List[ResultNode]:
selected: List[ResultNode] = []
for unique_id in node_ids:
if unique_id in manifest.nodes:
node = manifest.nodes[unique_id]
if node.is_relational and not node.is_ephemeral_model:
selected.append(node)
elif unique_id in manifest.sources:
source = manifest.sources[unique_id]
selected.append(source)
return selected

def get_node_selector(self) -> ResourceTypeSelector:
if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to perform node selection")
Expand Down Expand Up @@ -359,3 +361,29 @@ def interpret_results(self, results: Optional[CatalogResults]) -> bool:
return True

return super().interpret_results(compile_results)

@staticmethod
def _get_nodes_from_ids(manifest: Manifest, node_ids: Iterable[str]) -> List[ResultNode]:
selected: List[ResultNode] = []
for unique_id in node_ids:
if unique_id in manifest.nodes:
node = manifest.nodes[unique_id]
if node.is_relational and not node.is_ephemeral_model:
selected.append(node)
elif unique_id in manifest.sources:
source = manifest.sources[unique_id]
selected.append(source)
return selected

def _get_selected_source_ids(self) -> Set[UniqueId]:
if self.manifest is None or self.graph is None:
raise DbtInternalError("manifest and graph must be set to perform node selection")

source_selector = ResourceTypeSelector(
graph=self.graph,
manifest=self.manifest,
previous_state=self.previous_state,
resource_types=[NodeType.Source],
)

return source_selector.get_graph_queue(self.get_selection_spec()).get_selected_nodes()
78 changes: 67 additions & 11 deletions tests/functional/docs/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

sample_config = """
sources:
- name: my_seed
- name: my_source_schema
schema: "{{ target.schema }}"
tables:
- name: sample_seed
- name: second_seed
- name: fake_seed
- name: sample_source
- name: second_source
- name: non_existent_source
- name: source_from_seed
"""


Expand Down Expand Up @@ -74,11 +75,18 @@ def test_select_limits_no_match(self, project):
run_dbt(["run"])
catalog = run_dbt(["docs", "generate", "--select", "my_missing_model"])
assert len(catalog.nodes) == 0
assert len(catalog.sources) == 0


class TestGenerateCatalogWithSources(TestBaseGenerate):
def test_catalog_with_sources(self, project):
# populate sources other than non_existent_source
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# build nodes
run_dbt(["build"])

catalog = run_dbt(["docs", "generate"])

# 2 seeds + 2 models
Expand All @@ -88,13 +96,61 @@ def test_catalog_with_sources(self, project):


class TestGenerateSelectSource(TestBaseGenerate):
@pytest.fixture(scope="class")
def seeds(self):
return {
"sample_seed.csv": sample_seed,
"second_seed.csv": sample_seed,
"source_from_seed.csv": sample_seed,
}

def test_select_source(self, project):
run_dbt(["build"])
catalog = run_dbt(["docs", "generate", "--select", "source:test.my_seed.sample_seed"])

# 2 seeds
# TODO: Filtering doesn't work for seeds
assert len(catalog.nodes) == 2
# 2 sources
# TODO: Filtering doesn't work for sources
assert len(catalog.sources) == 2
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# 2 existing sources, 1 selected
catalog = run_dbt(
["docs", "generate", "--select", "source:test.my_source_schema.sample_source"]
)
assert len(catalog.sources) == 1
assert "source.test.my_source_schema.sample_source" in catalog.sources
# no nodes selected
assert len(catalog.nodes) == 0

# 2 existing sources sources, 1 selected that has relation as a seed
catalog = run_dbt(
["docs", "generate", "--select", "source:test.my_source_schema.source_from_seed"]
)
assert len(catalog.sources) == 1
assert "source.test.my_source_schema.source_from_seed" in catalog.sources
# seed with same relation that was not selected not in catalog
assert len(catalog.nodes) == 0


class TestGenerateSelectSeed(TestBaseGenerate):
@pytest.fixture(scope="class")
def seeds(self):
return {
"sample_seed.csv": sample_seed,
"second_seed.csv": sample_seed,
"source_from_seed.csv": sample_seed,
}

def test_select_seed(self, project):
run_dbt(["build"])

# 3 seeds, 1 selected
catalog = run_dbt(["docs", "generate", "--select", "sample_seed"])
assert len(catalog.nodes) == 1
assert "seed.test.sample_seed" in catalog.nodes
# no sources selected
assert len(catalog.sources) == 0

# 3 seeds, 1 selected that has same relation as a source
catalog = run_dbt(["docs", "generate", "--select", "source_from_seed"])
assert len(catalog.nodes) == 1
assert "seed.test.source_from_seed" in catalog.nodes
# source with same relation that was not selected not in catalog
assert len(catalog.sources) == 0

0 comments on commit 1391363

Please sign in to comment.