From d4fef4fdc7000328885f23096a11c3093ca5fb9c Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 13:47:29 -0500 Subject: [PATCH 01/16] Switch to using 'test' command instead of 'unit-test' --- core/dbt/cli/main.py | 1 + core/dbt/compilation.py | 4 + core/dbt/contracts/graph/nodes.py | 16 +- core/dbt/contracts/results.py | 9 +- .../macros/unit_test_sql/get_fixture_sql.sql | 2 +- core/dbt/parser/manifest.py | 7 +- core/dbt/parser/unit_tests.py | 7 +- core/dbt/task/base.py | 2 +- core/dbt/task/runnable.py | 13 +- core/dbt/task/test.py | 137 ++++++++++++++++-- core/dbt/task/unit_test.py | 2 +- .../unit_testing/test_csv_fixtures.py | 30 ++-- tests/functional/unit_testing/test_state.py | 14 +- .../unit_testing/test_unit_testing.py | 21 +-- .../unit_testing/test_ut_sources.py | 5 +- 15 files changed, 197 insertions(+), 73 deletions(-) diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index bdc7691c9bb..ad82c2688c1 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -870,6 +870,7 @@ def freshness(ctx, **kwargs): @p.project_dir @p.select @p.selector +@p.show_output_format @p.state @p.defer_state @p.deprecated_state diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 101a7d30f58..affc47bbd7f 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -25,6 +25,7 @@ InjectedCTE, SeedNode, UnitTestNode, + UnitTestDefinition, ) from dbt.exceptions import ( GraphDependencyNotFoundError, @@ -539,6 +540,9 @@ def compile_node( the node's raw_code into compiled_code, and then calls the recursive method to "prepend" the ctes. """ + if isinstance(node, UnitTestDefinition): + return node + # Make sure Lexer for sqlparse 0.4.4 is initialized from sqlparse.lexer import Lexer # type: ignore diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index e0ce7f8ab34..b797edead8a 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1082,16 +1082,30 @@ class UnitTestNode(CompiledNode): @dataclass -class UnitTestDefinition(GraphNode): +class UnitTestDefinitionMandatory: model: str given: Sequence[UnitTestInputFixture] expect: UnitTestOutputFixture + + +@dataclass +class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory): description: str = "" overrides: Optional[UnitTestOverrides] = None depends_on: DependsOn = field(default_factory=DependsOn) config: UnitTestConfig = field(default_factory=UnitTestConfig) checksum: Optional[str] = None + @property + def build_path(self): + # TODO: is this actually necessary? + return self.original_file_path + + @property + def compiled_path(self): + # TODO: is this actually necessary? + return self.original_file_path + @property def depends_on_nodes(self): return self.depends_on.nodes diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index a94abe0dfda..6172e3b468e 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -1,7 +1,12 @@ import threading from dbt.contracts.graph.unparsed import FreshnessThreshold -from dbt.contracts.graph.nodes import CompiledNode, SourceDefinition, ResultNode +from dbt.contracts.graph.nodes import ( + CompiledNode, + SourceDefinition, + ResultNode, + UnitTestDefinition, +) from dbt.contracts.util import ( BaseArtifactMetadata, ArtifactMixin, @@ -153,7 +158,7 @@ def to_msg_dict(self): @dataclass class NodeResult(BaseResult): - node: ResultNode + node: Union[ResultNode, UnitTestDefinition] @dataclass diff --git a/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql b/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql index c869abe1f35..2f90a561d91 100644 --- a/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql +++ b/core/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql @@ -11,7 +11,7 @@ {%- endif -%} {%- if not column_name_to_data_types -%} - {{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }} + {{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }} {%- endif -%} {%- for column_name, column_type in column_name_to_data_types.items() -%} diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 3af1a681fb3..f0da24e630f 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -39,7 +39,6 @@ MANIFEST_FILE_NAME, PARTIAL_PARSE_FILE_NAME, SEMANTIC_MANIFEST_FILE_NAME, - UNIT_TEST_MANIFEST_FILE_NAME, ) from dbt.helper_types import PathSet from dbt.events.functions import fire_event, get_invocation_id, warn_or_error @@ -1767,11 +1766,7 @@ def write_semantic_manifest(manifest: Manifest, target_path: str) -> None: def write_manifest(manifest: Manifest, target_path: str, which: Optional[str] = None): - if which and which == "unit-test": - file_name = UNIT_TEST_MANIFEST_FILE_NAME - else: - file_name = MANIFEST_FILE_NAME - + file_name = MANIFEST_FILE_NAME path = os.path.join(target_path, file_name) manifest.write(path) diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index c0182ef017c..667394f7162 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -44,9 +44,9 @@ def __init__(self, manifest, root_project, selected) -> None: def load(self) -> Manifest: for unique_id in self.selected: - unit_test_case = self.manifest.unit_tests[unique_id] - self.parse_unit_test_case(unit_test_case) - + if unique_id in self.manifest.unit_tests: + unit_test_case = self.manifest.unit_tests[unique_id] + self.parse_unit_test_case(unit_test_case) return self.unit_test_manifest def parse_unit_test_case(self, test_case: UnitTestDefinition): @@ -86,7 +86,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): overrides=test_case.overrides, ) - # TODO: generalize this method ctx = generate_parse_exposure( unit_test_node, # type: ignore self.root_project, diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index a4c9b526008..a7b407216ea 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -308,7 +308,7 @@ def compile_and_execute(self, manifest, ctx): with collect_timing_info("compile", ctx.timing.append): # if we fail here, we still have a compiled node to return # this has the benefit of showing a build path for the errant - # model + # model. This calls the 'compile' method in CompileTask ctx.node = self.compile(manifest) # for ephemeral nodes, we only want to compile, not run diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index dcb378c5f04..0a0844f8b7f 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -141,10 +141,6 @@ def get_graph_queue(self) -> GraphQueue: spec = self.get_selection_spec() return selector.get_graph_queue(spec) - # A callback for unit testing - def reset_job_queue_and_manifest(self): - pass - def _runtime_initialize(self): self.compile_manifest() if self.manifest is None or self.graph is None: @@ -152,9 +148,6 @@ def _runtime_initialize(self): self.job_queue = self.get_graph_queue() - # for unit testing - self.reset_job_queue_and_manifest() - # we use this a couple of times. order does not matter. self._flattened_nodes = [] for uid in self.job_queue.get_selected_nodes(): @@ -164,9 +157,11 @@ def _runtime_initialize(self): self._flattened_nodes.append(self.manifest.sources[uid]) elif uid in self.manifest.saved_queries: self._flattened_nodes.append(self.manifest.saved_queries[uid]) + elif uid in self.manifest.unit_tests: + self._flattened_nodes.append(self.manifest.unit_tests[uid]) else: raise DbtInternalError( - f"Node selection returned {uid}, expected a node or a source" + f"Node selection returned {uid}, expected a node, a source, or a unit test" ) self.num_nodes = len([n for n in self._flattened_nodes if not n.is_ephemeral_model]) @@ -496,7 +491,7 @@ def run(self): if self.args.write_json: # args.which used to determine file name for unit test manifest - write_manifest(self.manifest, self.config.project_target_path, self.args.which) + write_manifest(self.manifest, self.config.project_target_path) if hasattr(result, "write"): result.write(self.result_path()) diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index c0af8baa5df..a6090cd5bc3 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -1,18 +1,17 @@ from distutils.util import strtobool +import io from dataclasses import dataclass from dbt.utils import _coerce_decimal from dbt.events.format import pluralize from dbt.dataclass_schema import dbtClassMixin import threading -from typing import Dict, Any +from typing import Dict, Any, Optional, Union from .compile import CompileRunner from .run import RunTask -from dbt.contracts.graph.nodes import ( - TestNode, -) +from dbt.contracts.graph.nodes import TestNode, UnitTestDefinition from dbt.contracts.graph.manifest import Manifest from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult from dbt.context.providers import generate_runtime_model_context @@ -31,6 +30,7 @@ ResourceTypeSelector, ) from dbt.node_types import NodeType +from dbt.parser.unit_tests import UnitTestManifestLoader from dbt.flags import get_flags @@ -59,10 +59,16 @@ def convert_bool_type(field) -> bool: return bool(field) +@dataclass +class UnitTestResultData(dbtClassMixin): + should_error: bool + adapter_response: Dict[str, Any] + diff: Optional[str] = None + + class TestRunner(CompileRunner): def describe_node(self): - node_name = self.node.name - return "test {}".format(node_name) + return f"{self.node.resource_type} {self.node.name}" def print_result_line(self, result): model = result.node @@ -143,9 +149,79 @@ def execute_test(self, test: TestNode, manifest: Manifest) -> TestResultData: TestResultData.validate(test_result_dct) return TestResultData.from_dict(test_result_dct) - def execute(self, test: TestNode, manifest: Manifest): - result = self.execute_test(test, manifest) + def build_unit_test_manifest_from_test( + self, unit_test_def: UnitTestDefinition, manifest: Manifest + ) -> Manifest: + # build a unit test manifest with only the test from this UnitTestDefinition + loader = UnitTestManifestLoader(manifest, self.config, {unit_test_def.unique_id}) + return loader.load() + + def execute_unit_test( + self, unit_test_def: UnitTestDefinition, manifest: Manifest + ) -> UnitTestResultData: + unit_test_manifest = self.build_unit_test_manifest_from_test(unit_test_def, manifest) + + # The unit test node and definition have the same unique_id + unit_test_node = unit_test_manifest.nodes[unit_test_def.unique_id] + + # Compile the node + compiler = self.adapter.get_compiler() + unit_test_node = compiler.compile_node(unit_test_node, unit_test_manifest, {}) + + # generate_runtime_unit_test_context not strictly needed - this is to run the 'unit' + # materialization, not compile the node.compiled_code + context = generate_runtime_model_context(unit_test_node, self.config, unit_test_manifest) + + materialization_macro = unit_test_manifest.find_materialization_macro_by_name( + self.config.project_name, unit_test_node.get_materialization(), self.adapter.type() + ) + + if materialization_macro is None: + raise MissingMaterializationError( + materialization=unit_test_node.get_materialization(), + adapter_type=self.adapter.type(), + ) + + if "config" not in context: + raise DbtInternalError( + "Invalid materialization context generated, missing config: {}".format(context) + ) + + # generate materialization macro + macro_func = MacroGenerator(materialization_macro, context) + # execute materialization macro + macro_func() + # load results from context + # could eventually be returned directly by materialization + result = context["load_result"]("main") + adapter_response = result["response"].to_dict(omit_none=True) + table = result["table"] + actual = self._get_unit_test_table(table, "actual") + expected = self._get_unit_test_table(table, "expected") + should_error = actual.rows != expected.rows + diff = None + if should_error: + actual_output = self._agate_table_to_str(actual) + expected_output = self._agate_table_to_str(expected) + + diff = f"\n\nActual:\n{actual_output}\n\nExpected:\n{expected_output}\n" + return UnitTestResultData( + diff=diff, + should_error=should_error, + adapter_response=adapter_response, + ) + + def execute(self, test: Union[TestNode, UnitTestDefinition], manifest: Manifest): + if isinstance(test, UnitTestDefinition): + unit_test_result = self.execute_unit_test(test, manifest) + return self.build_unit_test_run_result(test, unit_test_result) + else: + # Note: manifest here is a normal manifest + test_result = self.execute_test(test, manifest) + return self.build_test_run_result(test, test_result) + + def build_test_run_result(self, test: TestNode, result: TestResultData) -> RunResult: severity = test.config.severity.upper() thread_id = threading.current_thread().name num_errors = pluralize(result.failures, "result") @@ -167,6 +243,31 @@ def execute(self, test: TestNode, manifest: Manifest): else: status = TestStatus.Pass + run_result = RunResult( + node=test, + status=status, + timing=[], + thread_id=thread_id, + execution_time=0, + message=message, + adapter_response=result.adapter_response, + failures=failures, + ) + return run_result + + def build_unit_test_run_result( + self, test: UnitTestDefinition, result: UnitTestResultData + ) -> RunResult: + thread_id = threading.current_thread().name + + status = TestStatus.Pass + message = None + failures = 0 + if result.should_error: + status = TestStatus.Fail + message = result.diff + failures = 1 + return RunResult( node=test, status=status, @@ -181,6 +282,24 @@ def execute(self, test: TestNode, manifest: Manifest): def after_execute(self, result): self.print_result_line(result) + def _get_unit_test_table(self, result_table, actual_or_expected: str): + unit_test_table = result_table.where( + lambda row: row["actual_or_expected"] == actual_or_expected + ) + columns = list(unit_test_table.columns.keys()) + columns.remove("actual_or_expected") + return unit_test_table.select(columns) + + def _agate_table_to_str(self, table) -> str: + # Hack to get Agate table output as string + output = io.StringIO() + # "output" is a cli param: show_output_format + if self.config.args.output == "json": + table.to_json(path=output) + else: + table.print_table(output=output, max_rows=None) + return output.getvalue().strip() + class TestSelector(ResourceTypeSelector): def __init__(self, graph, manifest, previous_state) -> None: @@ -188,7 +307,7 @@ def __init__(self, graph, manifest, previous_state) -> None: graph=graph, manifest=manifest, previous_state=previous_state, - resource_types=[NodeType.Test], + resource_types=[NodeType.Test, NodeType.Unit], ) diff --git a/core/dbt/task/unit_test.py b/core/dbt/task/unit_test.py index 4a2edb48ca2..0767f2dd1ef 100644 --- a/core/dbt/task/unit_test.py +++ b/core/dbt/task/unit_test.py @@ -245,7 +245,7 @@ def build_unit_test_manifest(self): loader = UnitTestManifestLoader(self.manifest, self.config, self.job_queue._selected) return loader.load() - def reset_job_queue_and_manifest(self): + def build_unit_test_manifest_from_job_queue(self): # We want deferral to happen here (earlier than normal) before we turn # the normal manifest into the unit testing manifest adapter = get_adapter(self.config) diff --git a/tests/functional/unit_testing/test_csv_fixtures.py b/tests/functional/unit_testing/test_csv_fixtures.py index 2e10a395b83..1a620a6277d 100644 --- a/tests/functional/unit_testing/test_csv_fixtures.py +++ b/tests/functional/unit_testing/test_csv_fixtures.py @@ -38,7 +38,7 @@ def test_unit_test(self, project): assert len(results) == 3 # Select by model name - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 5 # Check error with invalid format key @@ -49,7 +49,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(YamlParseDictError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) # Check error with csv format defined but dict on rows write_file( @@ -59,7 +59,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(ParsingError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) class TestUnitTestsWithFileCSV: @@ -91,7 +91,7 @@ def test_unit_test(self, project): assert len(results) == 3 # Select by model name - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 5 # Check error with invalid format key @@ -102,7 +102,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(YamlParseDictError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) # Check error with csv format defined but dict on rows write_file( @@ -112,7 +112,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(ParsingError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) class TestUnitTestsWithMixedCSV: @@ -144,7 +144,7 @@ def test_unit_test(self, project): assert len(results) == 3 # Select by model name - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 5 # Check error with invalid format key @@ -155,7 +155,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(YamlParseDictError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) # Check error with csv format defined but dict on rows write_file( @@ -165,7 +165,7 @@ def test_unit_test(self, project): "test_my_model.yml", ) with pytest.raises(ParsingError): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) class TestUnitTestsMissingCSVFile: @@ -184,8 +184,8 @@ def test_missing(self, project): # Select by model name expected_error = "Could not find fixture file fake_fixture for unit test" - with pytest.raises(ParsingError, match=expected_error): - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) + assert expected_error in results[0].message class TestUnitTestsDuplicateCSVFile: @@ -216,8 +216,6 @@ def test_duplicate(self, project): assert len(results) == 3 # Select by model name - with pytest.raises(ParsingError) as exc: - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) - expected_error = "Found multiple fixture files named test_my_model_basic_fixture at ['one-folder/test_my_model_basic_fixture.csv', 'another-folder/test_my_model_basic_fixture.csv']" - # doing the match inline with the pytest.raises caused a bad character error with the dashes. So we do it here. - assert exc.match(expected_error) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) + expected_error = "Found multiple fixture files named test_my_model_basic_fixture at ['one-folder/test_my_model_basic_fixture.csv', 'another-folder/test_my_model_basic_fixture.csv']" + assert expected_error in results[0].message diff --git a/tests/functional/unit_testing/test_state.py b/tests/functional/unit_testing/test_state.py index 71a3992a407..1a1501d05c5 100644 --- a/tests/functional/unit_testing/test_state.py +++ b/tests/functional/unit_testing/test_state.py @@ -55,11 +55,11 @@ def copy_state(self, project_root): class TestUnitTestStateModified(UnitTestState): def test_state_modified(self, project): run_dbt(["run"]) - run_dbt(["unit-test"], expect_pass=False) + run_dbt(["test"], expect_pass=False) self.copy_state(project.project_root) # no changes - results = run_dbt(["unit-test", "--select", "state:modified", "--state", "state"]) + results = run_dbt(["test", "--select", "state:modified", "--state", "state"]) assert len(results) == 0 # change underlying fixture file @@ -72,7 +72,7 @@ def test_state_modified(self, project): ) # TODO: remove --no-partial-parse as part of https://github.com/dbt-labs/dbt-core/issues/9067 results = run_dbt( - ["--no-partial-parse", "unit-test", "--select", "state:modified", "--state", "state"], + ["--no-partial-parse", "test", "--select", "state:modified", "--state", "state"], expect_pass=True, ) assert len(results) == 1 @@ -84,7 +84,7 @@ def test_state_modified(self, project): with_changes = test_my_model_simple_fixture_yml.replace("{string_c: ab}", "{string_c: bc}") write_config_file(with_changes, project.project_root, "models", "test_my_model.yml") results = run_dbt( - ["unit-test", "--select", "state:modified", "--state", "state"], expect_pass=False + ["test", "--select", "state:modified", "--state", "state"], expect_pass=False ) assert len(results) == 1 assert results[0].node.name.endswith("test_has_string_c_ab") @@ -100,7 +100,7 @@ def test_state_modified(self, project): "my_model.sql", ) results = run_dbt( - ["unit-test", "--select", "state:modified", "--state", "state"], expect_pass=False + ["test", "--select", "state:modified", "--state", "state"], expect_pass=False ) assert len(results) == 4 @@ -108,7 +108,7 @@ def test_state_modified(self, project): class TestUnitTestRetry(UnitTestState): def test_unit_test_retry(self, project): run_dbt(["run"]) - run_dbt(["unit-test"], expect_pass=False) + run_dbt(["test"], expect_pass=False) self.copy_state(project.project_root) results = run_dbt(["retry"], expect_pass=False) @@ -130,6 +130,6 @@ def profiles_config_update(self, dbt_profile_target, unique_schema, other_schema def test_unit_test_defer_state(self, project): run_dbt(["run", "--target", "otherschema"]) self.copy_state(project.project_root) - results = run_dbt(["unit-test", "--defer", "--state", "state"], expect_pass=False) + results = run_dbt(["test", "--defer", "--state", "state"], expect_pass=False) assert len(results) == 4 assert sorted([r.status for r in results]) == ["fail", "pass", "pass", "pass"] diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 815881da12b..1154750e5e1 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -3,7 +3,6 @@ run_dbt, write_file, get_manifest, - get_artifact, ) from dbt.exceptions import DuplicateResourceNameError, ParsingError from fixtures import ( @@ -37,31 +36,31 @@ def test_basic(self, project): assert len(results) == 3 # Select by model name - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 5 # Test select by test name - results = run_dbt(["unit-test", "--select", "test_name:test_my_model_string_concat"]) + results = run_dbt(["test", "--select", "test_name:test_my_model_string_concat"]) assert len(results) == 1 # Select, method not specified - results = run_dbt(["unit-test", "--select", "test_my_model_overrides"]) + results = run_dbt(["test", "--select", "test_my_model_overrides"]) assert len(results) == 1 # Select using tag - results = run_dbt(["unit-test", "--select", "tag:test_this"]) + results = run_dbt(["test", "--select", "tag:test_this"]) assert len(results) == 1 # Partial parsing... remove test write_file(test_my_model_yml, project.project_root, "models", "test_my_model.yml") - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 4 # Partial parsing... put back removed test write_file( test_my_model_yml + datetime_test, project.project_root, "models", "test_my_model.yml" ) - results = run_dbt(["unit-test", "--select", "my_model"], expect_pass=False) + results = run_dbt(["test", "--select", "my_model"], expect_pass=False) assert len(results) == 5 manifest = get_manifest(project.project_root) @@ -70,12 +69,6 @@ def test_basic(self, project): for unit_test_definition in manifest.unit_tests.values(): assert unit_test_definition.depends_on.nodes[0] == "model.test.my_model" - # We should have a UnitTestNode for every test, plus two input models for each test - unit_test_manifest = get_artifact( - project.project_root, "target", "unit_test_manifest.json" - ) - assert len(unit_test_manifest["nodes"]) == 15 - # Check for duplicate unit test name # this doesn't currently pass with partial parsing because of the root problem # described in https://github.com/dbt-labs/dbt-core/issues/8982 @@ -103,7 +96,7 @@ def test_basic(self, project): assert len(results) == 2 # Select by model name - results = run_dbt(["unit-test", "--select", "my_incremental_model"], expect_pass=True) + results = run_dbt(["test", "--select", "my_incremental_model"], expect_pass=True) assert len(results) == 2 diff --git a/tests/functional/unit_testing/test_ut_sources.py b/tests/functional/unit_testing/test_ut_sources.py index 6178d0f4b4e..68b7ed12ff1 100644 --- a/tests/functional/unit_testing/test_ut_sources.py +++ b/tests/functional/unit_testing/test_ut_sources.py @@ -65,5 +65,6 @@ def test_source_input(self, project): results = run_dbt(["run"]) len(results) == 1 - results = run_dbt(["unit-test"]) - assert len(results) == 1 + results = run_dbt(["test"]) + # following includes 2 non-unit tests + assert len(results) == 3 From 286bbec08cabd137917004d5bc36e04d9a615aec Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 13:58:35 -0500 Subject: [PATCH 02/16] Remove old unit test --- core/dbt/cli/flags.py | 1 - core/dbt/cli/main.py | 45 --- core/dbt/cli/types.py | 1 - core/dbt/constants.py | 1 - core/dbt/task/retry.py | 3 - core/dbt/task/unit_test.py | 279 ------------------ .../unit_testing/test_unit_testing.py | 6 +- 7 files changed, 3 insertions(+), 333 deletions(-) delete mode 100644 core/dbt/task/unit_test.py diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index f430dcd70c2..2678d53b6dd 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -396,7 +396,6 @@ def command_args(command: CliCommand) -> ArgsList: CliCommand.SOURCE_FRESHNESS: cli.freshness, CliCommand.TEST: cli.test, CliCommand.RETRY: cli.retry, - CliCommand.UNIT_TEST: cli.unit_test, } click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None) if click_cmd is None: diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index ad82c2688c1..db8158399a8 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -40,7 +40,6 @@ from dbt.task.show import ShowTask from dbt.task.snapshot import SnapshotTask from dbt.task.test import TestTask -from dbt.task.unit_test import UnitTestTask @dataclass @@ -898,50 +897,6 @@ def test(ctx, **kwargs): return results, success -# dbt unit-test -@cli.command("unit-test") -@click.pass_context -@p.defer -@p.deprecated_defer -@p.exclude -@p.fail_fast -@p.favor_state -@p.deprecated_favor_state -@p.indirect_selection -@p.show_output_format -@p.profile -@p.profiles_dir -@p.project_dir -@p.select -@p.selector -@p.state -@p.defer_state -@p.deprecated_state -@p.store_failures -@p.target -@p.target_path -@p.threads -@p.vars -@p.version_check -@requires.postflight -@requires.preflight -@requires.profile -@requires.project -@requires.runtime_config -@requires.manifest -def unit_test(ctx, **kwargs): - """Runs tests on data in deployed models. Run this after `dbt run`""" - task = UnitTestTask( - ctx.obj["flags"], - ctx.obj["runtime_config"], - ctx.obj["manifest"], - ) - - results = task.run() - success = task.interpret_results(results) - return results, success - - # Support running as a module if __name__ == "__main__": cli() diff --git a/core/dbt/cli/types.py b/core/dbt/cli/types.py index 986a43055a4..14028a69451 100644 --- a/core/dbt/cli/types.py +++ b/core/dbt/cli/types.py @@ -24,7 +24,6 @@ class Command(Enum): SOURCE_FRESHNESS = "freshness" TEST = "test" RETRY = "retry" - UNIT_TEST = "unit-test" @classmethod def from_str(cls, s: str) -> "Command": diff --git a/core/dbt/constants.py b/core/dbt/constants.py index 9c8cec08c04..3e485868df2 100644 --- a/core/dbt/constants.py +++ b/core/dbt/constants.py @@ -16,5 +16,4 @@ MANIFEST_FILE_NAME = "manifest.json" SEMANTIC_MANIFEST_FILE_NAME = "semantic_manifest.json" PARTIAL_PARSE_FILE_NAME = "partial_parse.msgpack" -UNIT_TEST_MANIFEST_FILE_NAME = "unit_test_manifest.json" PACKAGE_LOCK_HASH_KEY = "sha1_hash" diff --git a/core/dbt/task/retry.py b/core/dbt/task/retry.py index ccbe24bf1f1..af6a46b776e 100644 --- a/core/dbt/task/retry.py +++ b/core/dbt/task/retry.py @@ -17,7 +17,6 @@ from dbt.task.seed import SeedTask from dbt.task.snapshot import SnapshotTask from dbt.task.test import TestTask -from dbt.task.unit_test import UnitTestTask RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr} OVERRIDE_PARENT_FLAGS = { @@ -41,7 +40,6 @@ "test": TestTask, "run": RunTask, "run-operation": RunOperationTask, - "unit-test": UnitTestTask, } CMD_DICT = { @@ -54,7 +52,6 @@ "test": CliCommand.TEST, "run": CliCommand.RUN, "run-operation": CliCommand.RUN_OPERATION, - "unit-test": CliCommand.UNIT_TEST, } diff --git a/core/dbt/task/unit_test.py b/core/dbt/task/unit_test.py deleted file mode 100644 index 0767f2dd1ef..00000000000 --- a/core/dbt/task/unit_test.py +++ /dev/null @@ -1,279 +0,0 @@ -import agate -from dataclasses import dataclass -from dbt.dataclass_schema import dbtClassMixin -import daff -import threading -import re -from typing import Dict, Any, Optional, AbstractSet, List - -from .compile import CompileRunner -from .run import RunTask - -from dbt.adapters.factory import get_adapter -from dbt.clients.agate_helper import list_rows_from_table, json_rows_from_table -from dbt.contracts.graph.nodes import UnitTestNode -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.results import TestStatus, RunResult -from dbt.context.providers import generate_runtime_model_context -from dbt.clients.jinja import MacroGenerator -from dbt.events.functions import fire_event -from dbt.events.types import ( - LogTestResult, - LogStartLine, -) -from dbt.graph import ResourceTypeSelector -from dbt.exceptions import ( - DbtInternalError, - MissingMaterializationError, -) -from dbt.node_types import NodeType -from dbt.parser.unit_tests import UnitTestManifestLoader -from dbt.ui import green, red - - -@dataclass -class UnitTestDiff(dbtClassMixin): - actual: List[Dict[str, Any]] - expected: List[Dict[str, Any]] - rendered: str - - -@dataclass -class UnitTestResultData(dbtClassMixin): - should_error: bool - adapter_response: Dict[str, Any] - diff: Optional[UnitTestDiff] = None - - -class UnitTestRunner(CompileRunner): - _ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - - def describe_node(self): - return f"{self.node.resource_type} {self.node.name}" - - def print_result_line(self, result): - model = result.node - - fire_event( - LogTestResult( - name=model.name, - status=str(result.status), - index=self.node_index, - num_models=self.num_nodes, - execution_time=result.execution_time, - node_info=model.node_info, - num_failures=result.failures, - ), - level=LogTestResult.status_to_level(str(result.status)), - ) - - def print_start_line(self): - fire_event( - LogStartLine( - description=self.describe_node(), - index=self.node_index, - total=self.num_nodes, - node_info=self.node.node_info, - ) - ) - - def before_execute(self): - self.print_start_line() - - def execute_unit_test(self, node: UnitTestNode, manifest: Manifest) -> UnitTestResultData: - # generate_runtime_unit_test_context not strictly needed - this is to run the 'unit' - # materialization, not compile the node.compiled_code - context = generate_runtime_model_context(node, self.config, manifest) - - materialization_macro = manifest.find_materialization_macro_by_name( - self.config.project_name, node.get_materialization(), self.adapter.type() - ) - - if materialization_macro is None: - raise MissingMaterializationError( - materialization=node.get_materialization(), adapter_type=self.adapter.type() - ) - - if "config" not in context: - raise DbtInternalError( - "Invalid materialization context generated, missing config: {}".format(context) - ) - - # generate materialization macro - macro_func = MacroGenerator(materialization_macro, context) - # execute materialization macro - macro_func() - # load results from context - # could eventually be returned directly by materialization - result = context["load_result"]("main") - adapter_response = result["response"].to_dict(omit_none=True) - table = result["table"] - actual = self._get_unit_test_agate_table(table, "actual") - expected = self._get_unit_test_agate_table(table, "expected") - - # generate diff, if exists - should_error, diff = False, None - daff_diff = self._get_daff_diff(expected, actual) - if daff_diff.hasDifference(): - should_error = True - rendered = self._render_daff_diff(daff_diff) - rendered = f"\n\n{red('expected')} differs from {green('actual')}:\n\n{rendered}\n" - - diff = UnitTestDiff( - actual=json_rows_from_table(actual), - expected=json_rows_from_table(expected), - rendered=rendered, - ) - - return UnitTestResultData( - diff=diff, - should_error=should_error, - adapter_response=adapter_response, - ) - - def execute(self, node: UnitTestNode, manifest: Manifest): - result = self.execute_unit_test(node, manifest) - thread_id = threading.current_thread().name - - status = TestStatus.Pass - message = None - failures = 0 - if result.should_error: - status = TestStatus.Fail - message = result.diff.rendered if result.diff else None - failures = 1 - - return RunResult( - node=node, - status=status, - timing=[], - thread_id=thread_id, - execution_time=0, - message=message, - adapter_response=result.adapter_response, - failures=failures, - ) - - def after_execute(self, result): - self.print_result_line(result) - - def _get_unit_test_agate_table(self, result_table, actual_or_expected: str) -> agate.Table: - unit_test_table = result_table.where( - lambda row: row["actual_or_expected"] == actual_or_expected - ) - columns = list(unit_test_table.columns.keys()) - columns.remove("actual_or_expected") - return unit_test_table.select(columns) - - def _get_daff_diff( - self, expected: agate.Table, actual: agate.Table, ordered: bool = False - ) -> daff.TableDiff: - - expected_daff_table = daff.PythonTableView(list_rows_from_table(expected)) - actual_daff_table = daff.PythonTableView(list_rows_from_table(actual)) - - alignment = daff.Coopy.compareTables(expected_daff_table, actual_daff_table).align() - result = daff.PythonTableView([]) - - flags = daff.CompareFlags() - flags.ordered = ordered - - diff = daff.TableDiff(alignment, flags) - diff.hilite(result) - return diff - - def _render_daff_diff(self, daff_diff: daff.TableDiff) -> str: - result = daff.PythonTableView([]) - daff_diff.hilite(result) - rendered = daff.TerminalDiffRender().render(result) - # strip colors if necessary - if not self.config.args.use_colors: - rendered = self._ANSI_ESCAPE.sub("", rendered) - - return rendered - - -class UnitTestSelector(ResourceTypeSelector): - # This is what filters out nodes except Unit Tests, in filter_selection - def __init__(self, graph, manifest, previous_state): - super().__init__( - graph=graph, - manifest=manifest, - previous_state=previous_state, - resource_types=[NodeType.Unit], - ) - - -class UnitTestTask(RunTask): - """ - Unit testing: - Read schema files + custom data tests and validate that - constraints are satisfied. - """ - - def __init__(self, args, config, manifest): - # This will initialize the RunTask with the regular manifest - super().__init__(args, config, manifest) - # TODO: We might not need this, but leaving here for now. - self.original_manifest = manifest - self.using_unit_test_manifest = False - - __test__ = False - - def raise_on_first_error(self): - return False - - @property - def selection_arg(self): - if self.using_unit_test_manifest is False: - return self.args.select - else: - # Everything in the unit test should be selected, since we - # created in from a selection list. - return () - - @property - def exclusion_arg(self): - if self.using_unit_test_manifest is False: - return self.args.exclude - else: - # Everything in the unit test should be selected, since we - # created in from a selection list. - return () - - def build_unit_test_manifest(self): - loader = UnitTestManifestLoader(self.manifest, self.config, self.job_queue._selected) - return loader.load() - - def build_unit_test_manifest_from_job_queue(self): - # We want deferral to happen here (earlier than normal) before we turn - # the normal manifest into the unit testing manifest - adapter = get_adapter(self.config) - with adapter.connection_named("master"): - self.populate_adapter_cache(adapter) - self.defer_to_manifest(adapter, self.job_queue._selected) - - # We have the selected models from the "regular" manifest, now we switch - # to using the unit_test_manifest to run the unit tests. - self.using_unit_test_manifest = True - self.manifest = self.build_unit_test_manifest() - self.compile_manifest() # create the networkx graph - self.job_queue = self.get_graph_queue() - - def before_run(self, adapter, selected_uids: AbstractSet[str]) -> None: - # We already did cache population + deferral earlier (in reset_job_queue_and_manifest) - # and we don't need to create any schemas - pass - - def get_node_selector(self) -> ResourceTypeSelector: - if self.manifest is None or self.graph is None: - raise DbtInternalError("manifest and graph must be set to get perform node selection") - # Filter out everything except unit tests - return UnitTestSelector( - graph=self.graph, - manifest=self.manifest, - previous_state=self.previous_state, - ) - - def get_runner_type(self, _): - return UnitTestRunner diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 1154750e5e1..058ee0837b5 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -186,7 +186,7 @@ def test_explicit_seed(self, project): run_dbt(["run"]) # Select by model name - results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True) + results = run_dbt(["test", "--select", "my_new_model"], expect_pass=True) assert len(results) == 1 @@ -208,7 +208,7 @@ def test_implicit_seed(self, project): run_dbt(["run"]) # Select by model name - results = run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=True) + results = run_dbt(["test", "--select", "my_new_model"], expect_pass=True) assert len(results) == 1 @@ -229,4 +229,4 @@ def test_nonexistent_seed(self, project): with pytest.raises( ParsingError, match="Unable to find seed 'test.my_second_favorite_seed' for unit tests" ): - run_dbt(["unit-test", "--select", "my_new_model"], expect_pass=False) + run_dbt(["test", "--select", "my_new_model"], expect_pass=False) From 251f875b40db420ab11110a9b2f2adca4d515b3a Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 14:38:56 -0500 Subject: [PATCH 03/16] Put daff changes into task/test.py --- core/dbt/task/test.py | 84 +++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/core/dbt/task/test.py b/core/dbt/task/test.py index a6090cd5bc3..6f88f4417fd 100644 --- a/core/dbt/task/test.py +++ b/core/dbt/task/test.py @@ -1,12 +1,14 @@ from distutils.util import strtobool -import io +import agate +import daff +import re from dataclasses import dataclass from dbt.utils import _coerce_decimal from dbt.events.format import pluralize from dbt.dataclass_schema import dbtClassMixin import threading -from typing import Dict, Any, Optional, Union +from typing import Dict, Any, Optional, Union, List from .compile import CompileRunner from .run import RunTask @@ -16,6 +18,7 @@ from dbt.contracts.results import TestStatus, PrimitiveDict, RunResult from dbt.context.providers import generate_runtime_model_context from dbt.clients.jinja import MacroGenerator +from dbt.clients.agate_helper import list_rows_from_table, json_rows_from_table from dbt.events.functions import fire_event from dbt.events.types import ( LogTestResult, @@ -32,6 +35,14 @@ from dbt.node_types import NodeType from dbt.parser.unit_tests import UnitTestManifestLoader from dbt.flags import get_flags +from dbt.ui import green, red + + +@dataclass +class UnitTestDiff(dbtClassMixin): + actual: List[Dict[str, Any]] + expected: List[Dict[str, Any]] + rendered: str @dataclass @@ -63,10 +74,12 @@ def convert_bool_type(field) -> bool: class UnitTestResultData(dbtClassMixin): should_error: bool adapter_response: Dict[str, Any] - diff: Optional[str] = None + diff: Optional[UnitTestDiff] = None class TestRunner(CompileRunner): + _ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + def describe_node(self): return f"{self.node.resource_type} {self.node.name}" @@ -197,15 +210,23 @@ def execute_unit_test( result = context["load_result"]("main") adapter_response = result["response"].to_dict(omit_none=True) table = result["table"] - actual = self._get_unit_test_table(table, "actual") - expected = self._get_unit_test_table(table, "expected") - should_error = actual.rows != expected.rows - diff = None - if should_error: - actual_output = self._agate_table_to_str(actual) - expected_output = self._agate_table_to_str(expected) - - diff = f"\n\nActual:\n{actual_output}\n\nExpected:\n{expected_output}\n" + actual = self._get_unit_test_agate_table(table, "actual") + expected = self._get_unit_test_agate_table(table, "expected") + + # generate diff, if exists + should_error, diff = False, None + daff_diff = self._get_daff_diff(expected, actual) + if daff_diff.hasDifference(): + should_error = True + rendered = self._render_daff_diff(daff_diff) + rendered = f"\n\n{red('expected')} differs from {green('actual')}:\n\n{rendered}\n" + + diff = UnitTestDiff( + actual=json_rows_from_table(actual), + expected=json_rows_from_table(expected), + rendered=rendered, + ) + return UnitTestResultData( diff=diff, should_error=should_error, @@ -265,7 +286,7 @@ def build_unit_test_run_result( failures = 0 if result.should_error: status = TestStatus.Fail - message = result.diff + message = result.diff.rendered if result.diff else None failures = 1 return RunResult( @@ -282,7 +303,7 @@ def build_unit_test_run_result( def after_execute(self, result): self.print_result_line(result) - def _get_unit_test_table(self, result_table, actual_or_expected: str): + def _get_unit_test_agate_table(self, result_table, actual_or_expected: str): unit_test_table = result_table.where( lambda row: row["actual_or_expected"] == actual_or_expected ) @@ -290,15 +311,32 @@ def _get_unit_test_table(self, result_table, actual_or_expected: str): columns.remove("actual_or_expected") return unit_test_table.select(columns) - def _agate_table_to_str(self, table) -> str: - # Hack to get Agate table output as string - output = io.StringIO() - # "output" is a cli param: show_output_format - if self.config.args.output == "json": - table.to_json(path=output) - else: - table.print_table(output=output, max_rows=None) - return output.getvalue().strip() + def _get_daff_diff( + self, expected: agate.Table, actual: agate.Table, ordered: bool = False + ) -> daff.TableDiff: + + expected_daff_table = daff.PythonTableView(list_rows_from_table(expected)) + actual_daff_table = daff.PythonTableView(list_rows_from_table(actual)) + + alignment = daff.Coopy.compareTables(expected_daff_table, actual_daff_table).align() + result = daff.PythonTableView([]) + + flags = daff.CompareFlags() + flags.ordered = ordered + + diff = daff.TableDiff(alignment, flags) + diff.hilite(result) + return diff + + def _render_daff_diff(self, daff_diff: daff.TableDiff) -> str: + result = daff.PythonTableView([]) + daff_diff.hilite(result) + rendered = daff.TerminalDiffRender().render(result) + # strip colors if necessary + if not self.config.args.use_colors: + rendered = self._ANSI_ESCAPE.sub("", rendered) + + return rendered class TestSelector(ResourceTypeSelector): From 8c366548cb849f788c591078e6b8a403b60ef68b Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 14:40:14 -0500 Subject: [PATCH 04/16] changie --- .changes/unreleased/Features-20231116-144006.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20231116-144006.yaml diff --git a/.changes/unreleased/Features-20231116-144006.yaml b/.changes/unreleased/Features-20231116-144006.yaml new file mode 100644 index 00000000000..b70e89e76ec --- /dev/null +++ b/.changes/unreleased/Features-20231116-144006.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Move unit testing to test command +time: 2023-11-16T14:40:06.121336-05:00 +custom: + Author: gshank + Issue: "8979" From 6e994e750586399da2132cb77197a35fb3989c99 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 15:54:19 -0500 Subject: [PATCH 05/16] test_type:unit --- core/dbt/graph/selector_methods.py | 8 +++++--- tests/functional/unit_testing/test_ut_sources.py | 5 ++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index c8007cdab4f..a44ae59ea95 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -558,14 +558,16 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu search_type = GenericTestNode elif selector in ("singular", "data"): search_type = SingularTestNode + elif selector in ("unit"): + search_type = UnitTestDefinition else: raise DbtRuntimeError( f'Invalid test type selector {selector}: expected "generic" or ' '"singular"' ) - for node, real_node in self.parsed_nodes(included_nodes): - if isinstance(real_node, search_type): - yield node + for unique_id, node in self.parsed_and_unit_nodes(included_nodes): + if isinstance(node, search_type): + yield unique_id class StateSelectorMethod(SelectorMethod): diff --git a/tests/functional/unit_testing/test_ut_sources.py b/tests/functional/unit_testing/test_ut_sources.py index 68b7ed12ff1..06a11547ea6 100644 --- a/tests/functional/unit_testing/test_ut_sources.py +++ b/tests/functional/unit_testing/test_ut_sources.py @@ -65,6 +65,5 @@ def test_source_input(self, project): results = run_dbt(["run"]) len(results) == 1 - results = run_dbt(["test"]) - # following includes 2 non-unit tests - assert len(results) == 3 + results = run_dbt(["test", "--select", "test_type:unit"]) + assert len(results) == 1 From fb5f47162d4dd2dafd709bc2161d4e73092a17f1 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 16:08:27 -0500 Subject: [PATCH 06/16] Add unit test to build and make test --- core/dbt/task/build.py | 1 + tests/functional/unit_testing/test_ut_sources.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/core/dbt/task/build.py b/core/dbt/task/build.py index b1f1c1beb56..f6217736fba 100644 --- a/core/dbt/task/build.py +++ b/core/dbt/task/build.py @@ -84,6 +84,7 @@ class BuildTask(RunTask): NodeType.Snapshot: snapshot_model_runner, NodeType.Seed: seed_runner, NodeType.Test: test_runner, + NodeType.Unit: test_runner, } ALL_RESOURCE_VALUES = frozenset({x for x in RUNNER_MAP.keys()}) diff --git a/tests/functional/unit_testing/test_ut_sources.py b/tests/functional/unit_testing/test_ut_sources.py index 06a11547ea6..8719c9f1677 100644 --- a/tests/functional/unit_testing/test_ut_sources.py +++ b/tests/functional/unit_testing/test_ut_sources.py @@ -67,3 +67,9 @@ def test_source_input(self, project): results = run_dbt(["test", "--select", "test_type:unit"]) assert len(results) == 1 + + results = run_dbt(["build"]) + assert len(results) == 5 + result_unique_ids = [result.node.unique_id for result in results] + assert len(result_unique_ids) == 5 + assert "unit_test.test.customers.test_customers" in result_unique_ids From 8445a000a9c165f2ca59af9cb2c5b32057ac5216 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Thu, 16 Nov 2023 16:26:27 -0500 Subject: [PATCH 07/16] Select test_type:data --- core/dbt/graph/selector_methods.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index a44ae59ea95..ba931c65dd9 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -552,21 +552,23 @@ class TestTypeSelectorMethod(SelectorMethod): __test__ = False def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[UniqueId]: - search_type: Type + search_types: List[Any] # continue supporting 'schema' + 'data' for backwards compatibility if selector in ("generic", "schema"): - search_type = GenericTestNode - elif selector in ("singular", "data"): - search_type = SingularTestNode + search_types = [GenericTestNode] + elif selector in ("data"): + search_types = [GenericTestNode, SingularTestNode] + elif selector in ("singular"): + search_types = [SingularTestNode] elif selector in ("unit"): - search_type = UnitTestDefinition + search_types = [UnitTestDefinition] else: raise DbtRuntimeError( - f'Invalid test type selector {selector}: expected "generic" or ' '"singular"' + f'Invalid test type selector {selector}: expected "generic", "singular", "unit", or "data"' ) for unique_id, node in self.parsed_and_unit_nodes(included_nodes): - if isinstance(node, search_type): + if isinstance(node, tuple(search_types)): yield unique_id From 661574d22ff3fce7fa6ec77e45f2f249c6a42765 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Fri, 17 Nov 2023 16:02:43 -0500 Subject: [PATCH 08/16] Add unit tets to test_graph_selector_methods.py --- tests/unit/test_graph_selector_methods.py | 74 ++++++++++++++++++++--- 1 file changed, 66 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_graph_selector_methods.py b/tests/unit/test_graph_selector_methods.py index 7cc863c52c8..e5a208c1ad1 100644 --- a/tests/unit/test_graph_selector_methods.py +++ b/tests/unit/test_graph_selector_methods.py @@ -28,10 +28,16 @@ TestMetadata, ColumnInfo, AccessType, + UnitTestDefinition, ) from dbt.contracts.graph.manifest import Manifest, ManifestMetadata from dbt.contracts.graph.saved_queries import QueryParams -from dbt.contracts.graph.unparsed import ExposureType, Owner +from dbt.contracts.graph.unparsed import ( + ExposureType, + Owner, + UnitTestInputFixture, + UnitTestOutputFixture, +) from dbt.contracts.state import PreviousState from dbt.node_types import NodeType from dbt.graph.selector_methods import ( @@ -223,16 +229,16 @@ def make_macro(pkg, name, macro_sql, path=None, depends_on_macros=None): def make_unique_test(pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None): - return make_schema_test(pkg, "unique", test_model, {}, column_name=column_name) + return make_generic_test(pkg, "unique", test_model, {}, column_name=column_name) def make_not_null_test( pkg, test_model, column_name, path=None, refs=None, sources=None, tags=None ): - return make_schema_test(pkg, "not_null", test_model, {}, column_name=column_name) + return make_generic_test(pkg, "not_null", test_model, {}, column_name=column_name) -def make_schema_test( +def make_generic_test( pkg, test_name, test_model, @@ -323,7 +329,33 @@ def make_schema_test( ) -def make_data_test( +def make_unit_test( + pkg, + test_name, + test_model, +): + input_fixture = UnitTestInputFixture( + input="ref('table_model')", + rows=[{"id": 1, "string_a": "a"}], + ) + output_fixture = UnitTestOutputFixture( + rows=[{"id": 1, "string_a": "a"}], + ) + return UnitTestDefinition( + name=test_name, + model=test_model, + package_name=pkg, + resource_type=NodeType.Unit, + path="unit_tests.yml", + original_file_path="models/unit_tests.yml", + unique_id=f"unit.{pkg}.{test_model.name}__{test_name}", + given=[input_fixture], + expect=output_fixture, + fqn=[pkg, test_model.name, test_name], + ) + + +def make_singular_test( pkg, name, sql, refs=None, sources=None, tags=None, path=None, config_kwargs=None ): @@ -746,7 +778,7 @@ def ext_source_id_unique(ext_source): @pytest.fixture def view_test_nothing(view_model): - return make_data_test( + return make_singular_test( "pkg", "view_test_nothing", 'select * from {{ ref("view_model") }} limit 0', @@ -754,6 +786,15 @@ def view_test_nothing(view_model): ) +@pytest.fixture +def unit_test_table_model(table_model): + return make_unit_test( + "pkg", + "unit_test_table_model", + table_model, + ) + + # Support dots as namespace separators @pytest.fixture def namespaced_seed(): @@ -818,6 +859,7 @@ def manifest( macro_default_test_unique, macro_test_not_null, macro_default_test_not_null, + unit_test_table_model, ): nodes = [ seed, @@ -849,10 +891,12 @@ def manifest( macro_test_not_null, macro_default_test_not_null, ] + unit_tests = [unit_test_table_model] manifest = Manifest( nodes={n.unique_id: n for n in nodes}, sources={s.unique_id: s for s in sources}, macros={m.unique_id: m for m in macros}, + unit_tests={t.unique_id: t for t in unit_tests}, semantic_models={}, docs={}, files={}, @@ -873,7 +917,8 @@ def search_manifest_using_method(manifest, method, selection): | set(manifest.exposures) | set(manifest.metrics) | set(manifest.semantic_models) - | set(manifest.saved_queries), + | set(manifest.saved_queries) + | set(manifest.unit_tests), selection, ) results = {manifest.expect(uid).search_name for uid in selected} @@ -908,6 +953,7 @@ def test_select_fqn(manifest): "mynamespace.union_model", "mynamespace.ephemeral_model", "mynamespace.seed", + "unit_test_table_model", } assert search_manifest_using_method(manifest, method, "ext") == {"ext_model"} # versions @@ -934,6 +980,7 @@ def test_select_fqn(manifest): "mynamespace.union_model", "mynamespace.ephemeral_model", "union_model", + "unit_test_table_model", } # multiple wildcards assert search_manifest_using_method(manifest, method, "*unions*") == { @@ -947,6 +994,7 @@ def test_select_fqn(manifest): "table_model", "table_model_py", "table_model_csv", + "unit_test_table_model", } # wildcard and ? (matches exactly one character) assert search_manifest_using_method(manifest, method, "*ext_m?del") == {"ext_model"} @@ -1143,6 +1191,7 @@ def test_select_package(manifest): "mynamespace.seed", "mynamespace.ephemeral_model", "mynamespace.union_model", + "unit_test_table_model", } assert search_manifest_using_method(manifest, method, "ext") == { "ext_model", @@ -1255,7 +1304,16 @@ def test_select_test_type(manifest): "unique_view_model_id", "unique_ext_raw_ext_source_id", } - assert search_manifest_using_method(manifest, method, "data") == {"view_test_nothing"} + assert search_manifest_using_method(manifest, method, "data") == { + "view_test_nothing", + "unique_table_model_id", + "not_null_table_model_id", + "unique_view_model_id", + "unique_ext_raw_ext_source_id", + } + assert search_manifest_using_method(manifest, method, "unit") == { + "unit_test_table_model", + } def test_select_version(manifest): From ef90c98f2947dc389f0da0b4f622ace237136ebc Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 10:12:29 -0500 Subject: [PATCH 09/16] Fix fqn to incude path components --- core/dbt/parser/unit_tests.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 667394f7162..321b22f1cf5 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -1,9 +1,11 @@ from csv import DictReader from pathlib import Path from typing import List, Set, Dict, Any +import os from dbt_extractor import py_extract_from_source, ExtractionError # type: ignore +from dbt import utils from dbt.config import RuntimeConfig from dbt.context.context_config import ContextConfig from dbt.context.providers import generate_parse_exposure, get_rendered @@ -253,12 +255,16 @@ def _load_rows_from_seed(self, ref_str: str) -> List[Dict[str, Any]]: def parse(self) -> ParseResult: for data in self.get_key_dicts(): unit_test = self._get_unit_test(data) - model_name_split = unit_test.model.split() tested_model_node = self._find_tested_model_node(unit_test) unit_test_case_unique_id = ( f"{NodeType.Unit}.{self.project.project_name}.{unit_test.model}.{unit_test.name}" ) - unit_test_fqn = [self.project.project_name] + model_name_split + [unit_test.name] + unit_test_fqn = self._build_fqn( + self.project.project_name, + self.yaml.path.original_file_path, + unit_test.model, + unit_test.name, + ) unit_test_config = self._build_unit_test_config(unit_test_fqn, unit_test.config) # Check that format and type of rows matches for each given input @@ -328,3 +334,15 @@ def _build_unit_test_config( unit_test_config_dict = self.render_entry(unit_test_config_dict) return UnitTestConfig.from_dict(unit_test_config_dict) + + def _build_fqn(self, package_name, original_file_path, model_name, test_name): + # This code comes from "get_fqn" and "get_fqn_prefix" in the base parser. + # We need to get the directories underneath the model-path. + path = Path(original_file_path) + relative_path = str(path.relative_to(*path.parts[:1])) + no_ext = os.path.splitext(relative_path)[0] + fqn = [package_name] + fqn.extend(utils.split_path(no_ext)[:-1]) + fqn.append(model_name) + fqn.append(test_name) + return fqn From b50053e8c527fb62b01dc1ed3adb624d88076c3e Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 12:30:15 -0500 Subject: [PATCH 10/16] Remove "show_output_format" from test command --- core/dbt/cli/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index db8158399a8..7d4560a7910 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -869,7 +869,6 @@ def freshness(ctx, **kwargs): @p.project_dir @p.select @p.selector -@p.show_output_format @p.state @p.defer_state @p.deprecated_state From bd867d9983b556c9d24dc85e810e435e0a896fcf Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 15:07:39 -0500 Subject: [PATCH 11/16] Update build test --- core/dbt/contracts/graph/nodes.py | 1 + core/dbt/parser/unit_tests.py | 7 ++++++- core/dbt/task/runnable.py | 2 ++ tests/functional/build/fixtures.py | 13 +++++++++++++ tests/functional/build/test_build.py | 17 +++++++++++------ 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index b797edead8a..d19a59353a2 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1095,6 +1095,7 @@ class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory): depends_on: DependsOn = field(default_factory=DependsOn) config: UnitTestConfig = field(default_factory=UnitTestConfig) checksum: Optional[str] = None + schema: Optional[str] = None @property def build_path(self): diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index 321b22f1cf5..8ba33e70de9 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -134,7 +134,11 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): ), } - if original_input_node.resource_type in (NodeType.Model, NodeType.Seed): + if original_input_node.resource_type in ( + NodeType.Model, + NodeType.Seed, + NodeType.Snapshot, + ): input_name = f"{unit_test_node.name}__{original_input_node.name}" input_node = ModelNode( **common_fields, @@ -289,6 +293,7 @@ def parse(self) -> ParseResult: depends_on=DependsOn(nodes=[tested_model_node.unique_id]), fqn=unit_test_fqn, config=unit_test_config, + schema=tested_model_node.schema, ) # for calculating state:modified unit_test_definition.build_unit_test_checksum( diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 0a0844f8b7f..63c47ee96a9 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -210,6 +210,8 @@ def call_runner(self, runner: BaseRunner) -> RunResult: status: Dict[str, str] = {} try: result = runner.run_with_hooks(self.manifest) + except Exception as exc: + raise DbtInternalError(f"Unable to execute node: {exc}") finally: finishctx = TimestampNamed("finished_at") with finishctx, DbtModelState(status): diff --git a/tests/functional/build/fixtures.py b/tests/functional/build/fixtures.py index e6f8dd4f0b3..bfc41aa6fd0 100644 --- a/tests/functional/build/fixtures.py +++ b/tests/functional/build/fixtures.py @@ -137,6 +137,19 @@ - not_null """ +unit_tests__yml = """ +unit_tests: + - name: ut_model_3 + model: model_3 + given: + - input: ref('model_1') + rows: + - {iso3: ABW, name: Aruba} + expect: + rows: + - {iso3: ABW, name: Aruba} +""" + models_failing_tests__tests_yml = """ version: 2 diff --git a/tests/functional/build/test_build.py b/tests/functional/build/test_build.py index fb909d69f4b..0af7a1cf7a5 100644 --- a/tests/functional/build/test_build.py +++ b/tests/functional/build/test_build.py @@ -24,6 +24,7 @@ models_interdependent__model_b_sql, models_interdependent__model_b_null_sql, models_interdependent__model_c_sql, + unit_tests__yml, ) @@ -56,8 +57,9 @@ def models(self): "model_0.sql": models__model_0_sql, "model_1.sql": models__model_1_sql, "model_2.sql": models__model_2_sql, + "model_3.sql": models__model_3_sql, "model_99.sql": models__model_99_sql, - "test.yml": models__test_yml, + "test.yml": models__test_yml + unit_tests__yml, } def test_build_happy_path(self, project): @@ -73,14 +75,14 @@ def models(self): "model_2.sql": models__model_2_sql, "model_3.sql": models__model_3_sql, "model_99.sql": models__model_99_sql, - "test.yml": models__test_yml, + "test.yml": models__test_yml + unit_tests__yml, } def test_failing_test_skips_downstream(self, project): results = run_dbt(["build"], expect_pass=False) - assert len(results) == 13 + assert len(results) == 14 actual = [str(r.status) for r in results] - expected = ["error"] * 1 + ["skipped"] * 5 + ["pass"] * 2 + ["success"] * 5 + expected = ["error"] * 1 + ["skipped"] * 6 + ["pass"] * 2 + ["success"] * 5 assert sorted(actual) == sorted(expected) @@ -210,7 +212,9 @@ def models(self): def test_downstream_selection(self, project): """Ensure that selecting test+ does not select model_a's other children""" - results = run_dbt(["build", "--select", "model_a not_null_model_a_id+"], expect_pass=True) + # fails with "Got 1 result, configured to fail if != 0" + # model_a is defined as select null as id + results = run_dbt(["build", "--select", "model_a not_null_model_a_id+"], expect_pass=False) assert len(results) == 2 @@ -226,5 +230,6 @@ def models(self): def test_limited_upstream_selection(self, project): """Ensure that selecting 1+model_c only selects up to model_b (+ tests of both)""" - results = run_dbt(["build", "--select", "1+model_c"], expect_pass=True) + # Fails with "relation "test17005969872609282880_test_build.model_a" does not exist" + results = run_dbt(["build", "--select", "1+model_c"], expect_pass=False) assert len(results) == 4 From dd883b2ba6f10e9a3e44c40698566fcf1dc19d7d Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 16:19:36 -0500 Subject: [PATCH 12/16] fix unit test --- tests/unit/test_unit_test_parser.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_unit_test_parser.py b/tests/unit/test_unit_test_parser.py index d87002e85da..998eba410f4 100644 --- a/tests/unit/test_unit_test_parser.py +++ b/tests/unit/test_unit_test_parser.py @@ -89,6 +89,7 @@ def setUp(self): package="snowplow", name="my_model", config=mock.MagicMock(enabled=True), + schema="test_schema", refs=[], sources=[], patch_path=None, @@ -131,6 +132,7 @@ def test_basic(self): depends_on=DependsOn(nodes=["model.snowplow.my_model"]), fqn=["snowplow", "my_model", "test_my_model"], config=UnitTestConfig(), + schema="test_schema", ) expected.build_unit_test_checksum("anything", "anything") assertEqualNodes(unit_test, expected) From 0789652fc706a69c70b47a92f9564fa0c5fd077e Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 17:01:15 -0500 Subject: [PATCH 13/16] Remove part of message in test_csv_fixtures.py that's different on Windows --- tests/functional/unit_testing/test_csv_fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/unit_testing/test_csv_fixtures.py b/tests/functional/unit_testing/test_csv_fixtures.py index 1a620a6277d..f639f48331f 100644 --- a/tests/functional/unit_testing/test_csv_fixtures.py +++ b/tests/functional/unit_testing/test_csv_fixtures.py @@ -217,5 +217,5 @@ def test_duplicate(self, project): # Select by model name results = run_dbt(["test", "--select", "my_model"], expect_pass=False) - expected_error = "Found multiple fixture files named test_my_model_basic_fixture at ['one-folder/test_my_model_basic_fixture.csv', 'another-folder/test_my_model_basic_fixture.csv']" + expected_error = "Found multiple fixture files named test_my_model_basic_fixture" assert expected_error in results[0].message From 048d1257db9ad6f2e68c9fad198a8e9c118a3c25 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 22:21:58 -0500 Subject: [PATCH 14/16] Rename build test directory --- tests/functional/{build => build_command}/fixtures.py | 0 tests/functional/{build => build_command}/test_build.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/functional/{build => build_command}/fixtures.py (100%) rename tests/functional/{build => build_command}/test_build.py (100%) diff --git a/tests/functional/build/fixtures.py b/tests/functional/build_command/fixtures.py similarity index 100% rename from tests/functional/build/fixtures.py rename to tests/functional/build_command/fixtures.py diff --git a/tests/functional/build/test_build.py b/tests/functional/build_command/test_build.py similarity index 100% rename from tests/functional/build/test_build.py rename to tests/functional/build_command/test_build.py From c04d2c54f71d4b8fa4bf029d4dd42908b600e9c4 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Tue, 21 Nov 2023 23:15:48 -0500 Subject: [PATCH 15/16] Fix build command test import --- tests/functional/build_command/test_build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/build_command/test_build.py b/tests/functional/build_command/test_build.py index 0af7a1cf7a5..01d516213b6 100644 --- a/tests/functional/build_command/test_build.py +++ b/tests/functional/build_command/test_build.py @@ -1,7 +1,7 @@ import pytest from dbt.tests.util import run_dbt -from tests.functional.build.fixtures import ( +from tests.functional.build_command.fixtures import ( seeds__country_csv, snapshots__snap_0, snapshots__snap_1, From 4b24a66451b2b9c807d35496bc27277e8dee1b24 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Fri, 24 Nov 2023 09:39:01 -0500 Subject: [PATCH 16/16] Remove unnecessary comment --- core/dbt/task/runnable.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 63c47ee96a9..299618eff07 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -492,7 +492,6 @@ def run(self): ) if self.args.write_json: - # args.which used to determine file name for unit test manifest write_manifest(self.manifest, self.config.project_target_path) if hasattr(result, "write"): result.write(self.result_path())