Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schema tests defined by macros #339

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dbt.project
import dbt.utils
import dbt.include

from dbt.model import Model
from dbt.utils import This, Var, is_enabled, get_materialization, NodeType, \
Expand Down Expand Up @@ -63,7 +64,8 @@ def recursively_parse_macros_for_node(node, flat_graph, context):
context.get(package_name, {}) \
.update(macro_map)

if package_name == node.get('package_name'):
if package_name in (node.get('package_name'),
dbt.include.GLOBAL_PROJECT_NAME):
context.update(macro_map)

return context
Expand Down
5 changes: 5 additions & 0 deletions dbt/include/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

import os

# Absolute filesystem path to this package's directory; dbt's bundled
# "global" modules (e.g. the built-in schema-test macros) live beneath it.
GLOBAL_DBT_MODULES_PATH = os.path.dirname(__file__)
# Project name of the global project shipped inside dbt itself; macros in
# this project are made available to every user package.
GLOBAL_PROJECT_NAME = 'dbt'
5 changes: 5 additions & 0 deletions dbt/include/global_project/dbt_project.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

# Project file for dbt's built-in "global" project, bundled with dbt itself.
name: dbt
version: 1.0

# Built-in macros (e.g. the default schema tests) are discovered here.
macro-paths: ["macros"]
31 changes: 31 additions & 0 deletions dbt/include/global_project/macros/schema_tests/accepted_values.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

{% macro test_accepted_values(model, field, values) %}
{# Schema test: counts distinct values of `field` in `model` that are not
   in the accepted `values` list. A non-zero count means the test fails. #}

with all_values as (

select distinct
{{ field }} as value_field

from {{ model }}

),

validation_errors as (

{# render the accepted values as a quoted, comma-separated IN-list #}
select
value_field

from all_values
where value_field not in (
{% for value in values -%}

'{{ value }}' {% if not loop.last -%} , {%- endif %}

{%- endfor %}
)
)

select count(*)
from validation_errors

{% endmacro %}
18 changes: 18 additions & 0 deletions dbt/include/global_project/macros/schema_tests/not_null.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

{% macro test_not_null(model, arg) %}
{# Schema test: counts rows of `model` where column `arg` is null.
   A non-zero count means the test fails. #}

with validation as (

select
{{ arg }} as not_null_field

from {{ model }}

)

select count(*)
from validation
where not_null_field is null

{% endmacro %}

27 changes: 27 additions & 0 deletions dbt/include/global_project/macros/schema_tests/relationships.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

{% macro test_relationships(model, field, to, from) %}
{# Schema test (referential integrity): counts non-null values of child
   column `from` in `model` that have no matching `field` value in the
   parent relation `to`. A non-zero count means the test fails.
   NOTE(review): `from` doubles as a Jinja tag keyword — confirm it is
   accepted as a macro argument name by the template engine in use. #}

with parent as (

select
{{ field }} as id

from {{ to }}

),

child as (

select
{{ from }} as id

from {{ model }}

)

select count(*)
from child
where id is not null
and id not in (select id from parent)

{% endmacro %}
28 changes: 28 additions & 0 deletions dbt/include/global_project/macros/schema_tests/unique.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

{% macro test_unique(model, arg) %}
{# Schema test: counts distinct non-null values of column `arg` in `model`
   that occur more than once. A non-zero count means the test fails.
   Null values are excluded and therefore never violate uniqueness. #}

with validation as (

select
{{ arg }} as unique_field

from {{ model }}
where {{ arg }} is not null

),

validation_errors as (

select
unique_field

from validation
group by unique_field
having count(*) > 1

)

select count(*)
from validation_errors

{% endmacro %}
135 changes: 52 additions & 83 deletions dbt/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import os
import yaml
import re

import dbt.flags
import dbt.model
Expand All @@ -14,54 +15,9 @@
import dbt.contracts.project

from dbt.utils import NodeType
from dbt.compat import basestring, to_string
from dbt.logger import GLOBAL_LOGGER as logger

QUERY_VALIDATE_NOT_NULL = """
with validation as (
select {field} as f
from {ref}
)
select count(*) from validation where f is null
"""


QUERY_VALIDATE_UNIQUE = """
with validation as (
select {field} as f
from {ref}
where {field} is not null
),
validation_errors as (
select f from validation group by f having count(*) > 1
)
select count(*) from validation_errors
"""


QUERY_VALIDATE_ACCEPTED_VALUES = """
with all_values as (
select distinct {field} as f
from {ref}
),
validation_errors as (
select f from all_values where f not in ({values_csv})
)
select count(*) from validation_errors
"""


QUERY_VALIDATE_REFERENTIAL_INTEGRITY = """
with parent as (
select {parent_field} as id
from {parent_ref}
), child as (
select {child_field} as id
from {child_ref}
)
select count(*) from child
where id not in (select id from parent) and id is not null
"""


def get_path(resource_type, package_name, resource_name):
return "{}.{}.{}".format(resource_type, package_name, resource_name)
Expand Down Expand Up @@ -328,6 +284,15 @@ def parse_schema_tests(tests, root_project, projects):
if configs is None:
continue

if not isinstance(configs, (list, tuple)):

dbt.utils.compiler_warning(
model_name,
"Invalid test config given in {} near {}".format(
test.get('path'),
configs))
continue

for config in configs:
to_add = parse_schema_test(
test, model_name, config, test_type,
Expand All @@ -341,57 +306,61 @@ def parse_schema_tests(tests, root_project, projects):
return to_return


def parse_schema_test(test_base, model_name, test_config, test_type,
root_project_config, package_project_config,
all_projects):
if test_type == 'not_null':
raw_sql = QUERY_VALIDATE_NOT_NULL.format(
ref="{{ref('"+model_name+"')}}", field=test_config)
name_key = test_config
def get_nice_schema_test_name(test_type, test_name, args):
    """Build a unique, human-readable name for a schema test node.

    The name becomes the compiled filename and appears in test output,
    e.g. ``accepted_values_my_model_field__a__b``.

    :param test_type: the schema test macro suffix (e.g. 'unique')
    :param test_name: usually the model name the test applies to
    :param args: dict of test arguments; values may be scalars, lists,
        tuples, or dicts
    :return: '<test_type>_<test_name>_<sanitized args joined by __>'
    """
    flat_args = []
    # sort keys so the generated name is deterministic across runs
    for arg_name in sorted(args):
        arg_val = args[arg_name]

        # flatten dict/list/tuple values; wrap scalars in a list
        if isinstance(arg_val, dict):
            parts = arg_val.values()
        elif isinstance(arg_val, (list, tuple)):
            parts = arg_val
        else:
            parts = [arg_val]

        flat_args.extend([str(part) for part in parts])

    # replace anything that isn't filename/identifier-safe with '_'
    clean_flat_args = [re.sub('[^0-9a-zA-Z_]+', '_', arg)
                       for arg in flat_args]
    unique = "__".join(clean_flat_args)
    return '{}_{}_{}'.format(test_type, test_name, unique)

child_field = test_config.get('from')
parent_field = test_config.get('field')
parent_model = test_config.get('to')

raw_sql = QUERY_VALIDATE_REFERENTIAL_INTEGRITY.format(
child_field=child_field,
child_ref="{{ref('"+model_name+"')}}",
parent_field=parent_field,
parent_ref=("{{ref('"+parent_model+"')}}"))
def as_kwarg(key, value):
    """Render ``key=value`` for splicing into a Jinja macro call.

    Values that look like inline function calls -- ``ref(...)`` or
    ``var(...)`` -- are emitted verbatim so Jinja evaluates them;
    everything else is emitted as a Python literal (via repr) so that
    strings end up quoted in the rendered template.
    """
    test_value = to_string(value)
    is_function = re.match(r'^\s*(ref|var)\(.+\)$', test_value) is not None

    # if the value is a function, don't wrap it in quotes!
    if is_function:
        formatted_value = value
    else:
        formatted_value = repr(value)

    return "{key}={value}".format(key=key, value=formatted_value)

raw_sql = QUERY_VALIDATE_ACCEPTED_VALUES.format(
ref="{{ref('"+model_name+"')}}",
field=test_config.get('field', ''),
values_csv="'{}'".format(
"','".join([str(v) for v in test_config.get('values', [])])))

name_key = test_config.get('field')
def parse_schema_test(test_base, model_name, test_config, test_type,
root_project_config, package_project_config,
all_projects):

if isinstance(test_config, (basestring, int, float, bool)):
test_args = {'arg': test_config}
else:
raise dbt.exceptions.ValidationException(
'Unknown schema test type {}'.format(test_type))
test_args = test_config

name = '{}_{}_{}'.format(test_type, model_name, name_key)
# sort the dict so the keys are rendered deterministically (for tests)
kwargs = [as_kwarg(key, test_args[key]) for key in sorted(test_args)]

raw_sql = "{{{{ {macro}(model=ref('{model}'), {kwargs}) }}}}".format(**{
'model': model_name,
'macro': "test_{}".format(test_type),
'kwargs': ", ".join(kwargs)
})

name = get_nice_schema_test_name(test_type, model_name, test_args)

pseudo_path = dbt.utils.get_pseudo_test_path(name, test_base.get('path'),
'schema_test')

to_return = {
'name': name,
'resource_type': test_base.get('resource_type'),
Expand Down
34 changes: 27 additions & 7 deletions dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

import dbt.project
from dbt.include import GLOBAL_DBT_MODULES_PATH

from dbt.compat import basestring
from dbt.logger import GLOBAL_LOGGER as logger
Expand Down Expand Up @@ -42,26 +43,32 @@ def __repr__(self):
return self.schema_table(self.schema, self.table)


def compiler_error(model, msg):
def get_model_name_or_none(model):
    """Return a printable name for *model*, whatever form it takes.

    Accepts None (missing model -> '<None>'), a plain string (used as-is),
    a parsed-node dict (uses its 'name' entry), or a Model-like object
    (uses its ``nice_name`` attribute).
    """
    if model is None:
        name = '<None>'
    elif isinstance(model, basestring):
        # basestring comes from dbt.compat for py2/py3 compatibility
        name = model
    elif isinstance(model, dict):
        name = model.get('name')
    else:
        name = model.nice_name
    return name


def compiler_error(model, msg):
    """Abort compilation by raising a RuntimeError naming the failing model."""
    model_name = get_model_name_or_none(model)
    message = ("! Compilation error while compiling model {}:\n! {}\n"
               .format(model_name, msg))
    raise RuntimeError(message)


def compiler_warning(model, msg):
    """Log a compilation warning for *model* without aborting.

    Uses get_model_name_or_none so string/dict/None models are handled,
    instead of assuming a Model object with a ``nice_name`` attribute.
    """
    name = get_model_name_or_none(model)
    logger.info(
        "* Compilation warning while compiling model {}:\n* {}\n"
        .format(name, msg)
    )


Expand Down Expand Up @@ -153,9 +160,21 @@ def find_model_by_fqn(models, fqn):


def dependency_projects(project):
for obj in os.listdir(project['modules-path']):
full_obj = os.path.join(project['modules-path'], obj)
if os.path.isdir(full_obj):
module_paths = [
GLOBAL_DBT_MODULES_PATH,
project['modules-path']
]

for module_path in module_paths:
for obj in os.listdir(module_path):
full_obj = os.path.join(module_path, obj)

if not os.path.isdir(full_obj) or obj.startswith('__'):
# exclude non-dirs and dirs that start with __
# the latter could be something like __pycache__
# for the global dbt modules dir
continue

try:
yield dbt.project.read_project(
os.path.join(full_obj, 'dbt_project.yml'),
Expand All @@ -164,7 +183,8 @@ def dependency_projects(project):
args=project.args)
except dbt.project.DbtProjectError as e:
logger.info(
"Error reading dependency project at {}".format(full_obj)
"Error reading dependency project at {}".format(
full_obj)
)
logger.info(str(e))

Expand Down
Loading