Skip to content

Commit

Permalink
convert test to data_test (#9201)
Browse files Browse the repository at this point in the history
* convert test to data_test

* generate proto types

* fixing tests

* add tests

* add more tests

* test cleanup

* WIP

* fix graph

* fix testing manifest

* set resource type back to test and reset unique id

* reset expected run results

* cleanup

* changie

* modify to only look for tests under columns in schema files

* stop using dashes
  • Loading branch information
emmyoop authored Dec 7, 2023
1 parent ca82f54 commit a570a2c
Show file tree
Hide file tree
Showing 110 changed files with 1,734 additions and 1,275 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20231205-131717.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Convert the `tests` config to `data_tests` in both dbt_project.yml and schema files.
in schema files.
time: 2023-12-05T13:17:17.647765-06:00
custom:
Author: emmyoop
Issue: "8699"
2 changes: 1 addition & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
def print_compile_stats(stats):
names = {
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Test: "data test",
NodeType.Unit: "unit test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
Expand Down
14 changes: 8 additions & 6 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
seeds: Dict[str, Any]
snapshots: Dict[str, Any]
sources: Dict[str, Any]
tests: Dict[str, Any]
data_tests: Dict[str, Any]
unit_tests: Dict[str, Any]
metrics: Dict[str, Any]
semantic_models: Dict[str, Any]
Expand All @@ -454,7 +454,9 @@ def create_project(self, rendered: RenderComponents) -> "Project":
seeds = cfg.seeds
snapshots = cfg.snapshots
sources = cfg.sources
tests = cfg.tests
# the `tests` config is deprecated but still allowed. Copy it into
# `data_tests` to simplify logic throughout the rest of the system.
data_tests = cfg.data_tests if "data_tests" in rendered.project_dict else cfg.tests
unit_tests = cfg.unit_tests
metrics = cfg.metrics
semantic_models = cfg.semantic_models
Expand Down Expand Up @@ -516,7 +518,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
selectors=selectors,
query_comment=query_comment,
sources=sources,
tests=tests,
data_tests=data_tests,
unit_tests=unit_tests,
metrics=metrics,
semantic_models=semantic_models,
Expand Down Expand Up @@ -627,7 +629,7 @@ class Project:
seeds: Dict[str, Any]
snapshots: Dict[str, Any]
sources: Dict[str, Any]
tests: Dict[str, Any]
data_tests: Dict[str, Any]
unit_tests: Dict[str, Any]
metrics: Dict[str, Any]
semantic_models: Dict[str, Any]
Expand Down Expand Up @@ -713,8 +715,8 @@ def to_project_config(self, with_packages=False):
"seeds": self.seeds,
"snapshots": self.snapshots,
"sources": self.sources,
"tests": self.tests,
"unit-tests": self.unit_tests,
"data_tests": self.data_tests,
"unit_tests": self.unit_tests,
"metrics": self.metrics,
"semantic-models": self.semantic_models,
"saved-queries": self.saved_queries,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def should_render_keypath(self, keypath: Keypath) -> bool:
if first == "vars":
return False

if first in {"seeds", "models", "snapshots", "tests"}:
if first in {"seeds", "models", "snapshots", "tests", "data_tests"}:
keypath_parts = {(k.lstrip("+ ") if isinstance(k, str) else k) for k in keypath}
# model-level hooks
late_rendered_hooks = {"pre-hook", "post-hook", "pre_hook", "post_hook"}
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def from_parts(
selectors=project.selectors,
query_comment=project.query_comment,
sources=project.sources,
tests=project.tests,
data_tests=project.data_tests,
unit_tests=project.unit_tests,
metrics=project.metrics,
semantic_models=project.semantic_models,
Expand Down Expand Up @@ -324,7 +324,7 @@ def get_resource_config_paths(self) -> Dict[str, PathSet]:
"seeds": self._get_config_paths(self.seeds),
"snapshots": self._get_config_paths(self.snapshots),
"sources": self._get_config_paths(self.sources),
"tests": self._get_config_paths(self.tests),
"data_tests": self._get_config_paths(self.data_tests),
"unit_tests": self._get_config_paths(self.unit_tests),
"metrics": self._get_config_paths(self.metrics),
"semantic_models": self._get_config_paths(self.semantic_models),
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/context/context_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
elif resource_type == NodeType.Source:
model_configs = unrendered.get("sources")
elif resource_type == NodeType.Test:
model_configs = unrendered.get("tests")
model_configs = unrendered.get("data_tests")
elif resource_type == NodeType.Metric:
model_configs = unrendered.get("metrics")
elif resource_type == NodeType.SemanticModel:
Expand Down Expand Up @@ -73,7 +73,7 @@ def get_config_dict(self, resource_type: NodeType) -> Dict[str, Any]:
elif resource_type == NodeType.Source:
model_configs = self.project.sources
elif resource_type == NodeType.Test:
model_configs = self.project.tests
model_configs = self.project.data_tests
elif resource_type == NodeType.Metric:
model_configs = self.project.metrics
elif resource_type == NodeType.SemanticModel:
Expand Down
40 changes: 20 additions & 20 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def remote(cls, contents: str, project_name: str, language: str) -> "SourceFile"
class SchemaSourceFile(BaseSourceFile):
dfy: Dict[str, Any] = field(default_factory=dict)
# these are in the manifest.nodes dictionary
tests: Dict[str, Any] = field(default_factory=dict)
data_tests: Dict[str, Any] = field(default_factory=dict)
sources: List[str] = field(default_factory=list)
exposures: List[str] = field(default_factory=list)
metrics: List[str] = field(default_factory=list)
Expand Down Expand Up @@ -273,31 +273,31 @@ def append_patch(self, yaml_key, unique_id):
def add_test(self, node_unique_id, test_from):
name = test_from["name"]
key = test_from["key"]
if key not in self.tests:
self.tests[key] = {}
if name not in self.tests[key]:
self.tests[key][name] = []
self.tests[key][name].append(node_unique_id)
if key not in self.data_tests:
self.data_tests[key] = {}
if name not in self.data_tests[key]:
self.data_tests[key][name] = []
self.data_tests[key][name].append(node_unique_id)

# this is only used in unit tests
# this is only used in tests/unit
def remove_tests(self, yaml_key, name):
if yaml_key in self.tests:
if name in self.tests[yaml_key]:
del self.tests[yaml_key][name]
if yaml_key in self.data_tests:
if name in self.data_tests[yaml_key]:
del self.data_tests[yaml_key][name]

# this is only used in tests (unit + functional)
# this is only used in the tests directory (unit + functional)
def get_tests(self, yaml_key, name):
if yaml_key in self.tests:
if name in self.tests[yaml_key]:
return self.tests[yaml_key][name]
if yaml_key in self.data_tests:
if name in self.data_tests[yaml_key]:
return self.data_tests[yaml_key][name]
return []

def get_key_and_name_for_test(self, test_unique_id):
yaml_key = None
block_name = None
for key in self.tests.keys():
for name in self.tests[key]:
for unique_id in self.tests[key][name]:
for key in self.data_tests.keys():
for name in self.data_tests[key]:
for unique_id in self.data_tests[key][name]:
if unique_id == test_unique_id:
yaml_key = key
block_name = name
Expand All @@ -306,9 +306,9 @@ def get_key_and_name_for_test(self, test_unique_id):

def get_all_test_ids(self):
test_ids = []
for key in self.tests.keys():
for name in self.tests[key]:
test_ids.extend(self.tests[key][name])
for key in self.data_tests.keys():
for name in self.data_tests[key]:
test_ids.extend(self.data_tests[key][name])
return test_ids

def add_env_var(self, var, yaml_key, name):
Expand Down
43 changes: 34 additions & 9 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mashumaro.types import SerializableType
from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, Iterator, Literal

from dbt import deprecations
from dbt.dataclass_schema import dbtClassMixin, ExtensibleDbtClassMixin

from dbt.clients.system import write_file
Expand Down Expand Up @@ -43,10 +44,7 @@
from dbt.contracts.graph.semantic_layer_common import WhereFilterIntersection
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
from dbt.events.functions import warn_or_error
from dbt.exceptions import (
ParsingError,
ContractBreakingChangeError,
)
from dbt.exceptions import ParsingError, ContractBreakingChangeError, ValidationError
from dbt.events.types import (
SeedIncreased,
SeedExceedsLimitSamePath,
Expand Down Expand Up @@ -1237,6 +1235,24 @@ def get_full_source_name(self):
def get_source_representation(self):
return f'source("{self.source.name}", "{self.table.name}")'

def validate_data_tests(self):
"""
sources parse tests differently than models, so we need to do some validation
here where it's done in the PatchParser for other nodes
"""
for column in self.columns:
if column.tests and column.data_tests:
raise ValidationError(

Check warning on line 1245 in core/dbt/contracts/graph/nodes.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/nodes.py#L1245

Added line #L1245 was not covered by tests
"Invalid test config: cannot have both 'tests' and 'data_tests' defined"
)
if column.tests:
deprecations.warn(
"project-test-config",
deprecated_path="tests",
exp_path="data_tests",
)
column.data_tests = column.tests

@property
def quote_columns(self) -> Optional[bool]:
result = None
Expand All @@ -1251,14 +1267,23 @@ def columns(self) -> Sequence[UnparsedColumn]:
return [] if self.table.columns is None else self.table.columns

def get_tests(self) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
for test in self.tests:
yield normalize_test(test), None
self.validate_data_tests()
for data_test in self.data_tests:
yield normalize_test(data_test), None

for column in self.columns:
if column.tests is not None:
for test in column.tests:
yield normalize_test(test), column
if column.data_tests is not None:
for data_test in column.data_tests:
yield normalize_test(data_test), column

@property
def data_tests(self) -> List[TestDef]:
if self.table.data_tests is None:
return []

Check warning on line 1282 in core/dbt/contracts/graph/nodes.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/nodes.py#L1282

Added line #L1282 was not covered by tests
else:
return self.table.data_tests

# deprecated
@property
def tests(self) -> List[TestDef]:
if self.table.tests is None:
Expand Down
11 changes: 8 additions & 3 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class HasColumnProps(AdditionalPropertiesMixin, ExtensibleDbtClassMixin, Replace

@dataclass
class HasColumnAndTestProps(HasColumnProps):
data_tests: List[TestDef] = field(default_factory=list)
tests: List[TestDef] = field(default_factory=list)


Expand Down Expand Up @@ -145,7 +146,7 @@ class UnparsedVersion(dbtClassMixin):
config: Dict[str, Any] = field(default_factory=dict)
constraints: List[Dict[str, Any]] = field(default_factory=list)
docs: Docs = field(default_factory=Docs)
tests: Optional[List[TestDef]] = None
data_tests: Optional[List[TestDef]] = None
columns: Sequence[Union[dbt.helper_types.IncludeExclude, UnparsedColumn]] = field(
default_factory=list
)
Expand Down Expand Up @@ -248,7 +249,11 @@ def get_tests_for_version(self, version: NodeVersion) -> List[TestDef]:
f"get_tests_for_version called for version '{version}' not in version map"
)
unparsed_version = self._version_map[version]
return unparsed_version.tests if unparsed_version.tests is not None else self.tests
return (
unparsed_version.data_tests
if unparsed_version.data_tests is not None
else self.data_tests
)


@dataclass
Expand Down Expand Up @@ -401,7 +406,7 @@ class SourceTablePatch(dbtClassMixin):
freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold)
external: Optional[ExternalTable] = None
tags: Optional[List[str]] = None
tests: Optional[List[TestDef]] = None
data_tests: Optional[List[TestDef]] = None
columns: Optional[Sequence[UnparsedColumn]] = None

def to_patch_dict(self) -> Dict[str, Any]:
Expand Down
15 changes: 12 additions & 3 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dbt import deprecations
from dbt.contracts.util import Replaceable, Mergeable, list_str, Identifier
from dbt.contracts.connection import QueryComment, UserConfigContract
from dbt.helper_types import NoValue
Expand Down Expand Up @@ -194,7 +195,7 @@ class Project(dbtClassMixin, Replaceable):
source_paths: Optional[List[str]] = None
model_paths: Optional[List[str]] = None
macro_paths: Optional[List[str]] = None
data_paths: Optional[List[str]] = None
data_paths: Optional[List[str]] = None # deprecated
seed_paths: Optional[List[str]] = None
test_paths: Optional[List[str]] = None
analysis_paths: Optional[List[str]] = None
Expand All @@ -216,7 +217,8 @@ class Project(dbtClassMixin, Replaceable):
snapshots: Dict[str, Any] = field(default_factory=dict)
analyses: Dict[str, Any] = field(default_factory=dict)
sources: Dict[str, Any] = field(default_factory=dict)
tests: Dict[str, Any] = field(default_factory=dict)
tests: Dict[str, Any] = field(default_factory=dict) # deprecated
data_tests: Dict[str, Any] = field(default_factory=dict)
unit_tests: Dict[str, Any] = field(default_factory=dict)
metrics: Dict[str, Any] = field(default_factory=dict)
semantic_models: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -260,7 +262,6 @@ class Config(dbtMashConfig):
"semantic_models": "semantic-models",
"saved_queries": "saved-queries",
"dbt_cloud": "dbt-cloud",
"unit_tests": "unit-tests",
}

@classmethod
Expand All @@ -282,6 +283,14 @@ def validate(cls, data):
raise ValidationError(
f"Invalid dbt_cloud config. Expected a 'dict' but got '{type(data['dbt_cloud'])}'"
)
if data.get("tests", None) and data.get("data_tests", None):
raise ValidationError(
"Invalid project config: cannot have both 'tests' and 'data_tests' defined"
)
if "tests" in data:
deprecations.warn(
"project-test-config", deprecated_path="tests", exp_path="data_tests"
)


@dataclass
Expand Down
Loading

0 comments on commit a570a2c

Please sign in to comment.