Skip to content

Commit

Permalink
Reorganize dbt.tests.tables. Cleanup adapter handling
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Feb 19, 2022
1 parent 5109f8c commit 30b5707
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 86 deletions.
15 changes: 9 additions & 6 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dbt.flags as flags

from dbt.config.runtime import RuntimeConfig
from dbt.adapters.factory import get_adapter, register_adapter
from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters
from dbt.events.functions import setup_event_logger

import yaml
Expand Down Expand Up @@ -151,19 +151,22 @@ def selectors_yml(project_root, selectors):


@pytest.fixture
def schema(unique_schema, project_root, profiles_root):
def schema(unique_schema, project_root, profiles_root, profiles_yml, dbt_project_yml):
# Dummy args just to get adapter up and running
args = Namespace(profiles_dir=str(profiles_root), project_dir=str(project_root))
# The profiles.yml and dbt_project.yml should already be written out
args = Namespace(
profiles_dir=str(profiles_root), project_dir=str(project_root), target=None, profile=None
)
runtime_config = RuntimeConfig.from_args(args)

register_adapter(runtime_config)
adapter = get_adapter(runtime_config)
# execute(adapter, "drop schema if exists {} cascade".format(unique_schema))
execute(adapter, "drop schema if exists {} cascade".format(unique_schema))
execute(adapter, "create schema {}".format(unique_schema))
yield adapter
adapter = get_adapter(runtime_config)
# adapter.cleanup_connections()
execute(adapter, "drop schema if exists {} cascade".format(unique_schema))
adapter.cleanup_connections()
reset_adapters()


def execute(adapter, sql, connection_name="__test"):
Expand Down
136 changes: 63 additions & 73 deletions core/dbt/tests/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
from unittest.mock import patch
from contextlib import contextmanager

# This code was copied from the earlier test framework in test/integration/base.py
# The goal is to vastly simplify this and replace it with calls to macros.
# For now, we use this to get the tests converted in a more straightforward way.
# Assertions:
# assert_tables_equal (old: assertTablesEqual)
# assert_many_relations_equal (old: assertManyRelationsEqual)
# assert_many_tables_equal (old: assertManyTablesEqual)
# assert_table_does_not_exist (old: assertTableDoesNotExist)
# assert_table_does_exist (old: assertTableDoesExist)


class TableComparison:
def __init__(self, adapter, unique_schema, database):
Expand All @@ -15,13 +25,7 @@ def __init__(self, adapter, unique_schema, database):
else:
self.quoting = {"database": False, "schema": False, "identifier": False}

def _assert_tables_equal_sql(self, relation_a, relation_b, columns=None):
if columns is None:
columns = self.get_relation_columns(relation_a)
column_names = [c[0] for c in columns]
sql = self.adapter.get_rows_different_sql(relation_a, relation_b, column_names)
return sql

# assertion used in tests
def assert_tables_equal(
self,
table_a,
Expand Down Expand Up @@ -54,35 +58,7 @@ def assert_tables_equal(
assert result[0] == 0, "row_count_difference nonzero: " + sql
assert result[1] == 0, "num_mismatched nonzero: " + sql

def _make_relation(self, identifier, schema=None, database=None):
if schema is None:
schema = self.unique_schema
if database is None:
database = self.default_database
return self.adapter.Relation.create(
database=database, schema=schema, identifier=identifier, quote_policy=self.quoting
)

def get_many_relation_columns(self, relations):
"""Returns a dict of (datbase, schema) -> (dict of (table_name -> list of columns))"""
schema_fqns = {}
for rel in relations:
this_schema = schema_fqns.setdefault((rel.database, rel.schema), [])
this_schema.append(rel.identifier)

column_specs = {}
for key, tables in schema_fqns.items():
database, schema = key
columns = self.get_many_table_columns(tables, schema, database=database)
table_columns = {}
for col in columns:
table_columns.setdefault(col[0], []).append(col[1:])
for rel_name, columns in table_columns.items():
key = (database, schema, rel_name)
column_specs[key] = columns

return column_specs

# assertion used in tests
def assert_many_relations_equal(self, relations, default_schema=None, default_database=None):
if default_schema is None:
default_schema = self.unique_schema
Expand Down Expand Up @@ -139,6 +115,7 @@ def assert_many_relations_equal(self, relations, default_schema=None, default_da
assert result[0] == 0, "row_count_difference nonzero: " + sql
assert result[1] == 0, "num_mismatched nonzero: " + sql

# assertion used in tests
def assert_many_tables_equal(self, *args):
schema = self.unique_schema

Expand Down Expand Up @@ -172,42 +149,18 @@ def assert_many_tables_equal(self, *args):
assert result[0] == 0, "row_count_difference nonzero: " + sql
assert result[1] == 0, "num_mismatched nonzero: " + sql

def _assert_table_row_counts_equal(self, relation_a, relation_b):
cmp_query = """
with table_a as (
select count(*) as num_rows from {}
), table_b as (
select count(*) as num_rows from {}
)
select table_a.num_rows - table_b.num_rows as difference
from table_a, table_b
""".format(
str(relation_a), str(relation_b)
)

res = run_sql(cmp_query, self.unique_schema, database=self.default_database, fetch="one")

msg = (
f"Row count of table {relation_a.identifier} doesn't match row count of "
f"table {relation_b.identifier}. ({res[0]} rows different"
)
assert int(res[0]) == 0, msg

# assertion used in tests
def assert_table_does_not_exist(self, table, schema=None, database=None):
columns = self.get_table_columns(table, schema, database)
assert len(columns) == 0

# assertion used in tests
def assert_table_does_exist(self, table, schema=None, database=None):
columns = self.get_table_columns(table, schema, database)

assert len(columns) > 0

# called by assert_tables_equal
def _assert_table_columns_equal(self, relation_a, relation_b):
table_a_result = self.get_relation_columns(relation_a)
table_b_result = self.get_relation_columns(relation_b)
Expand All @@ -232,7 +185,6 @@ def _assert_table_columns_equal(self, relation_a, relation_b):
def get_relation_columns(self, relation):
with self.get_connection():
columns = self.adapter.get_columns_in_relation(relation)

return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0])

def get_table_columns(self, table, schema=None, database=None):
Expand All @@ -247,6 +199,7 @@ def get_table_columns(self, table, schema=None, database=None):
)
return self.get_relation_columns(relation)

# called by assert_many_table_equal
def get_table_columns_as_dict(self, tables, schema=None):
col_matrix = self.get_many_table_columns(tables, schema)
res = {}
Expand All @@ -263,7 +216,7 @@ def get_table_columns_as_dict(self, tables, schema=None):
def column_schema(self):
return "table_name, column_name, data_type, character_maximum_length"

# This should be overridden for Snowflake
# This should be overridden for Snowflake. Called by get_many_table_columns.
def get_many_table_columns_information_schema(self, tables, schema, database=None):
columns = self.column_schema

Expand Down Expand Up @@ -293,12 +246,7 @@ def get_many_table_columns_information_schema(self, tables, schema, database=Non
columns = run_sql(sql, self.unique_schema, database=self.default_database, fetch="all")
return list(map(self.filter_many_columns, columns))

def get_many_table_columns(self, tables, schema, database=None):
result = self.get_many_table_columns_information_schema(tables, schema, database)
result.sort(key=lambda x: "{}.{}".format(x[0], x[1]))
return result

# Snoflake needs a static char_size
# Snowflake needs a static char_size
def filter_many_columns(self, column):
if len(column) == 3:
table_name, column_name, data_type = column
Expand All @@ -310,8 +258,7 @@ def filter_many_columns(self, column):
@contextmanager
def get_connection(self, name=None):
"""Create a test connection context where all executed macros, etc will
get self.adapter as the adapter.
use the adapter created in the schema fixture.
This allows tests to run normal adapter macros as if reset_adapters()
were not called by handle_and_check (for asserts, etc)
"""
Expand All @@ -322,6 +269,49 @@ def get_connection(self, name=None):
conn = self.adapter.connections.get_thread_connection()
yield conn

def _make_relation(self, identifier, schema=None, database=None):
if schema is None:
schema = self.unique_schema
if database is None:
database = self.default_database
return self.adapter.Relation.create(
database=database, schema=schema, identifier=identifier, quote_policy=self.quoting
)

# called by get_many_relation_columns
def get_many_table_columns(self, tables, schema, database=None):
result = self.get_many_table_columns_information_schema(tables, schema, database)
result.sort(key=lambda x: "{}.{}".format(x[0], x[1]))
return result

# called by assert_many_relations_equal
def get_many_relation_columns(self, relations):
"""Returns a dict of (datbase, schema) -> (dict of (table_name -> list of columns))"""
schema_fqns = {}
for rel in relations:
this_schema = schema_fqns.setdefault((rel.database, rel.schema), [])
this_schema.append(rel.identifier)

column_specs = {}
for key, tables in schema_fqns.items():
database, schema = key
columns = self.get_many_table_columns(tables, schema, database=database)
table_columns = {}
for col in columns:
table_columns.setdefault(col[0], []).append(col[1:])
for rel_name, columns in table_columns.items():
key = (database, schema, rel_name)
column_specs[key] = columns

return column_specs

def _assert_tables_equal_sql(self, relation_a, relation_b, columns=None):
if columns is None:
columns = self.get_relation_columns(relation_a)
column_names = [c[0] for c in columns]
sql = self.adapter.get_rows_different_sql(relation_a, relation_b, column_names)
return sql


# needs overriding for presto
def _ilike(target, value):
Expand Down
10 changes: 3 additions & 7 deletions tests/functional/basic/test_simple_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from dbt.tests.util import run_dbt, copy_file
from dbt.tests.tables import TableComparison, get_tables_in_schema
from dbt.adapters.factory import get_adapter_by_type


ephemeral_copy_sql = """
Expand Down Expand Up @@ -174,9 +173,8 @@ def test_simple_reference(project):
results = run_dbt()
assert len(results) == 8

adapter = get_adapter_by_type("postgres")
table_comp = TableComparison(
adapter=adapter, unique_schema=project.test_schema, database=project.database
adapter=project.adapter, unique_schema=project.test_schema, database=project.database
)

# Copies should match
Expand Down Expand Up @@ -227,9 +225,8 @@ def test_simple_reference_with_models_and_children(project):
results = run_dbt(["run", "--models", "materialized_copy+", "ephemeral_copy+"])
assert len(results) == 3

adapter = get_adapter_by_type("postgres")
table_comp = TableComparison(
adapter=adapter, unique_schema=project.test_schema, database=project.database
adapter=project.adapter, unique_schema=project.test_schema, database=project.database
)

# Copies should match
Expand Down Expand Up @@ -268,9 +265,8 @@ def test_simple_ref_with_models(project):
assert len(results) == 1

# Copies should match
adapter = get_adapter_by_type("postgres")
table_comp = TableComparison(
adapter=adapter, unique_schema=project.test_schema, database=project.database
adapter=project.adapter, unique_schema=project.test_schema, database=project.database
)
table_comp.assert_tables_equal("users", "materialized_copy")

Expand Down

0 comments on commit 30b5707

Please sign in to comment.