Skip to content

Commit

Permalink
Fix #8652: Use seed value if rows not specified
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke committed Nov 13, 2023
1 parent 3b033ac commit 8b2948d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231113-154535.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Use seed value if rows not specified
time: 2023-11-13T15:45:35.008565Z
custom:
Author: aranke
Issue: "8652"
27 changes: 25 additions & 2 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from csv import DictReader
from pathlib import Path
from typing import List, Set, Dict, Any

from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore

from dbt.config import RuntimeConfig
from dbt.context.context_config import ContextConfig
from dbt.context.providers import generate_parse_exposure, get_rendered
Expand All @@ -14,7 +18,7 @@
UnitTestConfig,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTest
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput, DbtInternalError
from dbt.graph import UniqueId
from dbt.node_types import NodeType
from dbt.parser.schemas import (
Expand All @@ -27,7 +31,6 @@
ParseResult,
)
from dbt.utils import get_pseudo_test_path
from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore


class UnitTestManifestLoader:
Expand Down Expand Up @@ -203,6 +206,22 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None:
self.schema_parser = schema_parser
self.yaml = yaml

def load_rows_from_seed(self, seed_name):
rows = []

Check warning on line 210 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L210

Added line #L210 was not covered by tests

try:
seed_node = self.manifest.ref_lookup.perform_lookup(

Check warning on line 213 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L212-L213

Added lines #L212 - L213 were not covered by tests
f"seed.{self.project.project_name}.{seed_name}", self.manifest
)
seed_path = Path(seed_node.root_path) / seed_node.original_file_path
with open(seed_path, "r") as f:
for row in DictReader(f):
rows.append(row)
except DbtInternalError:
pass

Check warning on line 221 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L216-L221

Added lines #L216 - L221 were not covered by tests
finally:
return rows

Check warning on line 223 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L223

Added line #L223 was not covered by tests

def parse(self) -> ParseResult:
for data in self.get_key_dicts():
unit_test = self._get_unit_test(data)
Expand All @@ -214,8 +233,12 @@ def parse(self) -> ParseResult:
unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name]
unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config)

# self.manifest.ref_lookup.perform_lookup('seed.test.my_favorite_source', self.manifest)

# Check that format and type of rows matches for each given input
for input in unit_test.given:
if input.rows is None:
input.rows = self.load_rows_from_seed(input.input.split("'")[1])

Check warning on line 241 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L240-L241

Added lines #L240 - L241 were not covered by tests
input.validate_fixture("input", unit_test.name)
unit_test.expect.validate_fixture("expected", unit_test.name)

Expand Down
61 changes: 61 additions & 0 deletions tests/functional/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,64 @@ def test_basic(self, project):
# Select by model name
results = run_dbt(["unit-test", "--select", "my_incremental_model"], expect_pass=True)
assert len(results) == 2


my_new_model = """
select
my_favorite_seed.id,
a + b as c
from {{ ref('my_favorite_seed') }} as my_favorite_seed
inner join {{ ref('my_favorite_model') }} as my_favorite_model
on my_favorite_seed.id = my_second_favorite_model.id
"""

my_second_favorite_model = """
select
2 as id,
3 as b
"""

seed_my_favorite_seed = """id,a
1,5
2,4
3,3
4,2
5,1
"""

test_my_model_implicit_seed = """
unit_tests:
- name: t
model: my_new_model
given:
- input: ref('my_favorite_seed')
- input: ref('my_second_favorite_model')
rows:
- {id: 1, b: 2}
expect:
rows:
- {id: 1, c: 7}
"""


class TestUnitTestImplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
return {"my_favorite_seed.csv": seed_my_favorite_seed}

@pytest.fixture(scope="class")
def models(self):
return {
"my_new_model.sql": my_new_model,
"my_second_favorite_model.sql": my_second_favorite_model,
"schema.yml": test_my_model_implicit_seed,
}

def test_basic(self, project):
run_dbt(["seed"])
run_dbt(["run"])
# assert len(results) == 1

# Select by model name
results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True)
assert len(results) == 1

0 comments on commit 8b2948d

Please sign in to comment.