Skip to content

Commit

Permalink
simplify fix, improve test setup to decouple seeds & sources
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Jan 29, 2024
1 parent 77d71b2 commit 6d17b50
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
17 changes: 5 additions & 12 deletions core/dbt/task/docs/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,10 @@ def add_column(self, data: PrimitiveDict):
table.columns[column.name] = column

def make_unique_id_map(
self, manifest: Manifest, selected_node_ids: Optional[Set[UniqueId]] = None
self, manifest: Manifest
) -> Tuple[Dict[str, CatalogTable], Dict[str, CatalogTable]]:
"""
Create mappings between CatalogKeys and CatalogTables for nodes and sources, filtered by selected_node_ids.
By default, selected_node_ids is None and all nodes and sources defined in the manifest are included in the mappings.
Create mappings between CatalogKeys and CatalogTables for nodes and sources.
"""
nodes: Dict[str, CatalogTable] = {}
sources: Dict[str, CatalogTable] = {}
Expand All @@ -130,8 +128,7 @@ def make_unique_id_map(
key = table.key()
if key in node_map:
unique_id = node_map[key]
if selected_node_ids is None or unique_id in selected_node_ids:
nodes[unique_id] = table.replace(unique_id=unique_id)
nodes[unique_id] = table.replace(unique_id=unique_id)

unique_ids = source_map.get(table.key(), set())
for unique_id in unique_ids:
Expand All @@ -142,8 +139,7 @@ def make_unique_id_map(
table.to_dict(omit_none=True),
)
else:
if selected_node_ids is None or unique_id in selected_node_ids:
sources[unique_id] = table.replace(unique_id=unique_id)
sources[unique_id] = table.replace(unique_id=unique_id)
return nodes, sources


Expand Down Expand Up @@ -247,11 +243,9 @@ def run(self) -> CatalogArtifact:
if self.manifest is None:
raise DbtInternalError("self.manifest was None in run!")

selected_node_ids: Optional[Set[UniqueId]] = None
if self.args.empty_catalog:
catalog_table: agate.Table = agate.Table([])
exceptions: List[Exception] = []
selected_node_ids = set()
else:
adapter = get_adapter(self.config)
with adapter.connection_named("generate_catalog"):
Expand All @@ -269,7 +263,6 @@ def run(self) -> CatalogArtifact:
selected_source_nodes = self._get_nodes_from_ids(
self.manifest, selected_source_ids
)
selected_node_ids.update(selected_source_ids)
selected_nodes.extend(selected_source_nodes)

relations = {
Expand Down Expand Up @@ -301,7 +294,7 @@ def run(self) -> CatalogArtifact:
if exceptions:
errors = [str(e) for e in exceptions]

nodes, sources = catalog.make_unique_id_map(self.manifest, selected_node_ids)
nodes, sources = catalog.make_unique_id_map(self.manifest)
results = self.get_catalog_results(
nodes=nodes,
sources=sources,
Expand Down
25 changes: 17 additions & 8 deletions tests/functional/docs/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

sample_config = """
sources:
- name: my_seed
- name: my_source_schema
schema: "{{ target.schema }}"
tables:
- name: sample_seed
- name: second_seed
- name: fake_seed
- name: sample_source
- name: second_source
- name: non_existent_source
"""


Expand Down Expand Up @@ -81,7 +81,13 @@ def test_select_limits_no_match(self, project):

class TestGenerateCatalogWithSources(TestBaseGenerate):
def test_catalog_with_sources(self, project):
# populate sources other than non_existent_source
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# build nodes
run_dbt(["build"])

catalog = run_dbt(["docs", "generate"])

# 2 seeds + 2 models
Expand All @@ -92,7 +98,7 @@ def test_catalog_with_sources(self, project):

class TestGenerateCatalogWithExternalNodes(TestBaseGenerate):
@mock.patch("dbt.plugins.get_plugin_manager")
def test_catalog_with_sources(self, get_plugin_manager, project):
def test_catalog_with_external_node(self, get_plugin_manager, project):
project.run_sql("create table {}.external_model (id int)".format(project.test_schema))

run_dbt(["build"])
Expand All @@ -114,12 +120,15 @@ def test_catalog_with_sources(self, get_plugin_manager, project):

class TestGenerateSelectSource(TestBaseGenerate):
def test_select_source(self, project):
run_dbt(["build"])
project.run_sql("create table {}.sample_source (id int)".format(project.test_schema))
project.run_sql("create table {}.second_source (id int)".format(project.test_schema))

# 2 sources, 1 selected
catalog = run_dbt(["docs", "generate", "--select", "source:test.my_seed.sample_seed"])
catalog = run_dbt(
["docs", "generate", "--select", "source:test.my_source_schema.sample_source"]
)
assert len(catalog.sources) == 1
assert "source.test.my_seed.sample_seed" in catalog.sources
assert "source.test.my_source_schema.sample_source" in catalog.sources
# no nodes selected
assert len(catalog.nodes) == 0

Expand Down

0 comments on commit 6d17b50

Please sign in to comment.