-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow use of sources as unit testing inputs #9059
Changes from 3 commits
e92f753
17571dd
27b99bd
98a1b8b
c5b4428
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
kind: Features | ||
body: Support source inputs in unit tests | ||
time: 2023-11-11T19:11:50.870494-05:00 | ||
custom: | ||
Author: gshank | ||
Issue: "8507" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
UnitTestDefinition, | ||
DependsOn, | ||
UnitTestConfig, | ||
UnitTestSourceDefinition, | ||
) | ||
from dbt.contracts.graph.unparsed import UnparsedUnitTest | ||
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput | ||
|
@@ -105,44 +106,52 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): | |
- {id: 2, b: 2} | ||
""" | ||
# Add the model "input" nodes, consisting of all referenced models in the unit test. | ||
# This creates a model for every input in every test, so there may be multiple | ||
# input models substituting for the same input ref'd model. | ||
# This creates an ephemeral model for every input in every test, so there may be multiple | ||
# input models substituting for the same input ref'd model. Note that since these are | ||
# always "ephemeral" they just wrap the tested_node SQL in additional CTEs. No actual table | ||
# or view is created. | ||
for given in test_case.given: | ||
# extract the original_input_node from the ref in the "input" key of the given list | ||
original_input_node = self._get_original_input_node(given.input, tested_node) | ||
|
||
original_input_node_columns = None | ||
if ( | ||
original_input_node.resource_type == NodeType.Model | ||
and original_input_node.config.contract.enforced | ||
): | ||
original_input_node_columns = { | ||
column.name: column.data_type for column in original_input_node.columns | ||
} | ||
MichelleArk marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# TODO: include package_name? | ||
input_name = f"{unit_test_node.name}__{original_input_node.name}" | ||
input_unique_id = f"model.{package_name}.{input_name}" | ||
input_node = ModelNode( | ||
raw_code=self._build_fixture_raw_code( | ||
given.get_rows( | ||
self.root_project.project_root, self.root_project.fixture_paths | ||
), | ||
original_input_node_columns, | ||
project_root = self.root_project.project_root | ||
common_fields = { | ||
"resource_type": NodeType.Model, | ||
"package_name": package_name, | ||
"original_file_path": original_input_node.original_file_path, | ||
"unique_id": f"model.{package_name}.{input_name}", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may need to include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've changed this to include the source_name. This does make for pretty long unique_ids. Do we have any concerns about that? It's not like we're using that name to construct tables or anything... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So far I've noticed this issue creep up in #9015. I think we could shorten the node name for CTE generation (since it doesn't need to be unique) but keep the unique_id longer |
||
"config": ModelConfig(materialized="ephemeral"), | ||
"database": original_input_node.database, | ||
"alias": original_input_node.identifier, | ||
"schema": original_input_node.schema, | ||
"fqn": original_input_node.fqn, | ||
"checksum": FileHash.empty(), | ||
"raw_code": self._build_fixture_raw_code( | ||
given.get_rows(project_root, self.root_project.fixture_paths), None | ||
), | ||
resource_type=NodeType.Model, | ||
package_name=package_name, | ||
path=original_input_node.path, | ||
original_file_path=original_input_node.original_file_path, | ||
unique_id=input_unique_id, | ||
name=input_name, | ||
config=ModelConfig(materialized="ephemeral"), | ||
database=original_input_node.database, | ||
schema=original_input_node.schema, | ||
alias=original_input_node.alias, | ||
fqn=input_unique_id.split("."), | ||
checksum=FileHash.empty(), | ||
) | ||
} | ||
|
||
if original_input_node.resource_type == NodeType.Model: | ||
input_node = ModelNode( | ||
**common_fields, | ||
name=input_name, | ||
path=original_input_node.path, | ||
) | ||
elif original_input_node.resource_type == NodeType.Source: | ||
# We are reusing the database/schema/identifier from the original source, | ||
# but that shouldn't matter since this acts as an ephemeral model which just | ||
# wraps a CTE around the unit test node. | ||
input_node = UnitTestSourceDefinition( | ||
**common_fields, | ||
name=original_input_node.name, # must be the same name for source lookup to work | ||
path=input_name + ".sql", # for writing out compiled_code | ||
source_name=original_input_node.source_name, # needed for source lookup | ||
) | ||
# Sources need to go in the sources dictionary in order to create the right lookup | ||
self.unit_test_manifest.sources[input_node.unique_id] = input_node # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we anticipate any issues by having the sources dictionary contain a unique_id key that is prefixed with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It doesn't seem to care. We don't actually check the unique_id prefix that I can recall. If somebody starts parsing the unit_test_manifest, I suppose it might be confusing. But right now we're putting it in two places, so one of them will be wrong. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This probably isn't worth spending tons of time on.. but I think it could be possible to get around having to add the node to manifest.sources and do the lookup from the .nodes collection in UnitTestRuntimeSourceResolver since the unique_id will include source_name. kind of like what's done here: https://github.com/dbt-labs/dbt-core/blob/unit_testing_feature_branch/core/dbt/context/providers.py#L578 Not entirely sure what's more readable or less complex in this case. I can imagine having to maintain UnitTestSourceDefinitions across both dictionaries could be error-prone though..
Given that UnitTestSourceDefinition is a ModelNode, I think having it in nodes is 'more' correct There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the lookup behavior of sources and nodes is subtly different with regard to the meaning of package=None, so I don't think looking up sources as though they were nodes is worth it. |
||
|
||
# Both ModelNode and UnitTestSourceDefinition need to go in nodes dictionary | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for my own understanding - is this to enable cte injection? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. There's code in compilation.py that looks up the existence of the cte in the nodes dictionary: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In theory we could also check for a UnitTestSourceDefinition and in sources, but that didn't feel like an improvement. |
||
self.unit_test_manifest.nodes[input_node.unique_id] = input_node | ||
|
||
# Populate this_input_node_unique_id if input fixture represents node being tested | ||
|
@@ -153,6 +162,8 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): | |
unit_test_node.depends_on.nodes.append(input_node.unique_id) | ||
|
||
def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str: | ||
# We're not currently using column_name_to_data_types, but leaving here for | ||
# possible future use. | ||
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format( | ||
rows=rows, column_name_to_data_types=column_name_to_data_types | ||
) | ||
|
@@ -178,18 +189,21 @@ def _get_original_input_node(self, input: str, tested_node: ModelNode): | |
raise InvalidUnitTestGivenInput(input=input) | ||
|
||
if statically_parsed["refs"]: | ||
for ref in statically_parsed["refs"]: | ||
name = ref.get("name") | ||
package = ref.get("package") | ||
version = ref.get("version") | ||
# TODO: disabled lookup, versioned lookup, public models | ||
original_input_node = self.manifest.ref_lookup.find( | ||
name, package, version, self.manifest | ||
) | ||
ref = list(statically_parsed["refs"])[0] | ||
name = ref.get("name") | ||
package = ref.get("package") | ||
version = ref.get("version") | ||
# TODO: disabled lookup, versioned lookup, public models | ||
original_input_node = self.manifest.ref_lookup.find( | ||
name, package, version, self.manifest | ||
) | ||
elif statically_parsed["sources"]: | ||
input_package_name, input_source_name = statically_parsed["sources"][0] | ||
source = list(statically_parsed["sources"])[0] | ||
input_source_name, input_name = source | ||
original_input_node = self.manifest.source_lookup.find( | ||
input_source_name, input_package_name, self.manifest | ||
f"{input_source_name}.{input_name}", | ||
self.root_project.project_name, | ||
gshank marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.manifest, | ||
) | ||
else: | ||
raise InvalidUnitTestGivenInput(input=input) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pytest | ||
from dbt.tests.util import run_dbt | ||
|
||
raw_customers_csv = """id,first_name,last_name,email | ||
1,Michael,Perez,[email protected] | ||
2,Shawn,Mccoy,[email protected] | ||
3,Kathleen,Payne,[email protected] | ||
4,Jimmy,Cooper,[email protected] | ||
5,Katherine,Rice,[email protected] | ||
6,Sarah,Ryan,[email protected] | ||
7,Martin,Mcdonald,[email protected] | ||
8,Frank,Robinson,[email protected] | ||
9,Jennifer,Franklin,[email protected] | ||
10,Henry,Welch,[email protected] | ||
""" | ||
|
||
schema_sources_yml = """ | ||
sources: | ||
- name: seed_sources | ||
schema: "{{ target.schema }}" | ||
tables: | ||
- name: raw_customers | ||
columns: | ||
- name: id | ||
tests: | ||
- not_null: | ||
severity: "{{ 'error' if target.name == 'prod' else 'warn' }}" | ||
- unique | ||
- name: first_name | ||
- name: last_name | ||
- name: email | ||
unit_tests: | ||
- name: test_customers | ||
model: customers | ||
given: | ||
- input: source('seed_sources', 'raw_customers') | ||
rows: | ||
- {id: 1, first_name: Emily} | ||
expect: | ||
rows: | ||
- {id: 1, first_name: Emily} | ||
""" | ||
|
||
customers_sql = """ | ||
select * from {{ source('seed_sources', 'raw_customers') }} | ||
""" | ||
|
||
|
||
class TestUnitTestSourceInput: | ||
@pytest.fixture(scope="class") | ||
def seeds(self): | ||
return { | ||
"raw_customers.csv": raw_customers_csv, | ||
} | ||
|
||
@pytest.fixture(scope="class") | ||
def models(self): | ||
return { | ||
"customers.sql": customers_sql, | ||
"sources.yml": schema_sources_yml, | ||
} | ||
|
||
def test_source_input(self, project): | ||
results = run_dbt(["seed"]) | ||
results = run_dbt(["run"]) | ||
len(results) == 1 | ||
|
||
results = run_dbt(["unit-test"]) | ||
assert len(results) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we simplify the logic here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How? We can't set the resource_type to Source because that breaks execution.