From 30b570795165fdd4bc71e25716e7c832f6fc56b3 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Sat, 19 Feb 2022 18:43:06 -0500 Subject: [PATCH] Reorganize dbt.tests.tables. Cleanup adapter handling --- core/dbt/tests/fixtures/project.py | 15 +- core/dbt/tests/tables.py | 136 ++++++++---------- .../functional/basic/test_simple_reference.py | 10 +- 3 files changed, 75 insertions(+), 86 deletions(-) diff --git a/core/dbt/tests/fixtures/project.py b/core/dbt/tests/fixtures/project.py index 3f83cd0a6df..b31777e6bde 100644 --- a/core/dbt/tests/fixtures/project.py +++ b/core/dbt/tests/fixtures/project.py @@ -6,7 +6,7 @@ import dbt.flags as flags from dbt.config.runtime import RuntimeConfig -from dbt.adapters.factory import get_adapter, register_adapter +from dbt.adapters.factory import get_adapter, register_adapter, reset_adapters from dbt.events.functions import setup_event_logger import yaml @@ -151,19 +151,22 @@ def selectors_yml(project_root, selectors): @pytest.fixture -def schema(unique_schema, project_root, profiles_root): +def schema(unique_schema, project_root, profiles_root, profiles_yml, dbt_project_yml): # Dummy args just to get adapter up and running - args = Namespace(profiles_dir=str(profiles_root), project_dir=str(project_root)) + # The profiles.yml and dbt_project.yml should already be written out + args = Namespace( + profiles_dir=str(profiles_root), project_dir=str(project_root), target=None, profile=None + ) runtime_config = RuntimeConfig.from_args(args) register_adapter(runtime_config) adapter = get_adapter(runtime_config) - # execute(adapter, "drop schema if exists {} cascade".format(unique_schema)) + execute(adapter, "drop schema if exists {} cascade".format(unique_schema)) execute(adapter, "create schema {}".format(unique_schema)) yield adapter - adapter = get_adapter(runtime_config) - # adapter.cleanup_connections() execute(adapter, "drop schema if exists {} cascade".format(unique_schema)) + adapter.cleanup_connections() + reset_adapters() def execute(adapter, sql, connection_name="__test"): diff --git a/core/dbt/tests/tables.py b/core/dbt/tests/tables.py index 66ba1ada07e..6a0a690daa1 100644 --- a/core/dbt/tests/tables.py +++ b/core/dbt/tests/tables.py @@ -3,6 +3,16 @@ from unittest.mock import patch from contextlib import contextmanager +# This code was copied from the earlier test framework in test/integration/base.py +# The goal is to vastly simplify this and replace it with calls to macros. +# For now, we use this to get the tests converted in a more straightforward way. 
+# Assertions: +# assert_tables_equal (old: assertTablesEqual) +# assert_many_relations_equal (old: assertManyRelationsEqual) +# assert_many_tables_equal (old: assertManyTablesEqual) +# assert_table_does_not_exist (old: assertTableDoesNotExist) +# assert_table_does_exist (old: assertTableDoesExist) + class TableComparison: def __init__(self, adapter, unique_schema, database): @@ -15,13 +25,7 @@ def __init__(self, adapter, unique_schema, database): else: self.quoting = {"database": False, "schema": False, "identifier": False} - def _assert_tables_equal_sql(self, relation_a, relation_b, columns=None): - if columns is None: - columns = self.get_relation_columns(relation_a) - column_names = [c[0] for c in columns] - sql = self.adapter.get_rows_different_sql(relation_a, relation_b, column_names) - return sql - + # assertion used in tests def assert_tables_equal( self, table_a, @@ -54,35 +58,7 @@ def assert_tables_equal( assert result[0] == 0, "row_count_difference nonzero: " + sql assert result[1] == 0, "num_mismatched nonzero: " + sql - def _make_relation(self, identifier, schema=None, database=None): - if schema is None: - schema = self.unique_schema - if database is None: - database = self.default_database - return self.adapter.Relation.create( - database=database, schema=schema, identifier=identifier, quote_policy=self.quoting - ) - - def get_many_relation_columns(self, relations): - """Returns a dict of (datbase, schema) -> (dict of (table_name -> list of columns))""" - schema_fqns = {} - for rel in relations: - this_schema = schema_fqns.setdefault((rel.database, rel.schema), []) - this_schema.append(rel.identifier) - - column_specs = {} - for key, tables in schema_fqns.items(): - database, schema = key - columns = self.get_many_table_columns(tables, schema, database=database) - table_columns = {} - for col in columns: - table_columns.setdefault(col[0], []).append(col[1:]) - for rel_name, columns in table_columns.items(): - key = (database, schema, rel_name) - column_specs[key] = columns - - return column_specs - + # assertion used in tests def assert_many_relations_equal(self, relations, default_schema=None, default_database=None): if default_schema is None: default_schema = self.unique_schema @@ -139,6 +115,7 @@ def assert_many_relations_equal(self, relations, default_schema=None, default_da assert result[0] == 0, "row_count_difference nonzero: " + sql assert result[1] == 0, "num_mismatched nonzero: " + sql + # assertion used in tests def assert_many_tables_equal(self, *args): schema = self.unique_schema @@ -172,42 +149,18 @@ def assert_many_tables_equal(self, *args): assert result[0] == 0, "row_count_difference nonzero: " + sql assert result[1] == 0, "num_mismatched nonzero: " + sql - def _assert_table_row_counts_equal(self, relation_a, relation_b): - cmp_query = """ - with table_a as ( - - select count(*) as num_rows from {} - - ), table_b as ( - - select count(*) as num_rows from {} - - ) - - select table_a.num_rows - table_b.num_rows as difference - from table_a, table_b - - """.format( - str(relation_a), str(relation_b) - ) - - res = run_sql(cmp_query, self.unique_schema, database=self.default_database, fetch="one") - - msg = ( - f"Row count of table {relation_a.identifier} doesn't match row count of " - f"table {relation_b.identifier}. 
({res[0]} rows different" - ) - assert int(res[0]) == 0, msg - + # assertion used in tests def assert_table_does_not_exist(self, table, schema=None, database=None): columns = self.get_table_columns(table, schema, database) assert len(columns) == 0 + # assertion used in tests def assert_table_does_exist(self, table, schema=None, database=None): columns = self.get_table_columns(table, schema, database) assert len(columns) > 0 + # called by assert_tables_equal def _assert_table_columns_equal(self, relation_a, relation_b): table_a_result = self.get_relation_columns(relation_a) table_b_result = self.get_relation_columns(relation_b) @@ -232,7 +185,6 @@ def _assert_table_columns_equal(self, relation_a, relation_b): def get_relation_columns(self, relation): with self.get_connection(): columns = self.adapter.get_columns_in_relation(relation) - return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) def get_table_columns(self, table, schema=None, database=None): @@ -247,6 +199,7 @@ def get_table_columns(self, table, schema=None, database=None): ) return self.get_relation_columns(relation) + # called by assert_many_table_equal def get_table_columns_as_dict(self, tables, schema=None): col_matrix = self.get_many_table_columns(tables, schema) res = {} @@ -263,7 +216,7 @@ def get_table_columns_as_dict(self, tables, schema=None): def column_schema(self): return "table_name, column_name, data_type, character_maximum_length" - # This should be overridden for Snowflake + # This should be overridden for Snowflake. Called by get_many_table_columns. def get_many_table_columns_information_schema(self, tables, schema, database=None): columns = self.column_schema @@ -293,12 +246,7 @@ def get_many_table_columns_information_schema(self, tables, schema, database=Non columns = run_sql(sql, self.unique_schema, database=self.default_database, fetch="all") return list(map(self.filter_many_columns, columns)) - def get_many_table_columns(self, tables, schema, database=None): - result = self.get_many_table_columns_information_schema(tables, schema, database) - result.sort(key=lambda x: "{}.{}".format(x[0], x[1])) - return result - - # Snoflake needs a static char_size + # Snowflake needs a static char_size def filter_many_columns(self, column): if len(column) == 3: table_name, column_name, data_type = column @@ -310,8 +258,7 @@ def filter_many_columns(self, column): @contextmanager def get_connection(self, name=None): """Create a test connection context where all executed macros, etc will - get self.adapter as the adapter. - + use the adapter created in the schema fixture. 
This allows tests to run normal adapter macros as if reset_adapters() were not called by handle_and_check (for asserts, etc) """ @@ -322,6 +269,49 @@ def get_connection(self, name=None): conn = self.adapter.connections.get_thread_connection() yield conn + def _make_relation(self, identifier, schema=None, database=None): + if schema is None: + schema = self.unique_schema + if database is None: + database = self.default_database + return self.adapter.Relation.create( + database=database, schema=schema, identifier=identifier, quote_policy=self.quoting + ) + + # called by get_many_relation_columns + def get_many_table_columns(self, tables, schema, database=None): + result = self.get_many_table_columns_information_schema(tables, schema, database) + result.sort(key=lambda x: "{}.{}".format(x[0], x[1])) + return result + + # called by assert_many_relations_equal + def get_many_relation_columns(self, relations): + """Returns a dict of (datbase, schema) -> (dict of (table_name -> list of columns))""" + schema_fqns = {} + for rel in relations: + this_schema = schema_fqns.setdefault((rel.database, rel.schema), []) + this_schema.append(rel.identifier) + + column_specs = {} + for key, tables in schema_fqns.items(): + database, schema = key + columns = self.get_many_table_columns(tables, schema, database=database) + table_columns = {} + for col in columns: + table_columns.setdefault(col[0], []).append(col[1:]) + for rel_name, columns in table_columns.items(): + key = (database, schema, rel_name) + column_specs[key] = columns + + return column_specs + + def _assert_tables_equal_sql(self, relation_a, relation_b, columns=None): + if columns is None: + columns = self.get_relation_columns(relation_a) + column_names = [c[0] for c in columns] + sql = self.adapter.get_rows_different_sql(relation_a, relation_b, column_names) + return sql + # needs overriding for presto def _ilike(target, value): diff --git a/tests/functional/basic/test_simple_reference.py b/tests/functional/basic/test_simple_reference.py index 030df491616..41dcba864ed 100644 --- a/tests/functional/basic/test_simple_reference.py +++ b/tests/functional/basic/test_simple_reference.py @@ -2,7 +2,6 @@ import os from dbt.tests.util import run_dbt, copy_file from dbt.tests.tables import TableComparison, get_tables_in_schema -from dbt.adapters.factory import get_adapter_by_type ephemeral_copy_sql = """ @@ -174,9 +173,8 @@ def test_simple_reference(project): results = run_dbt() assert len(results) == 8 - adapter = get_adapter_by_type("postgres") table_comp = TableComparison( - adapter=adapter, unique_schema=project.test_schema, database=project.database + adapter=project.adapter, unique_schema=project.test_schema, database=project.database ) # Copies should match @@ -227,9 +225,8 @@ def test_simple_reference_with_models_and_children(project): results = run_dbt(["run", "--models", "materialized_copy+", "ephemeral_copy+"]) assert len(results) == 3 - adapter = get_adapter_by_type("postgres") table_comp = TableComparison( - adapter=adapter, unique_schema=project.test_schema, database=project.database + adapter=project.adapter, unique_schema=project.test_schema, database=project.database ) # Copies should match @@ -268,9 +265,8 @@ def test_simple_ref_with_models(project): assert len(results) == 1 # Copies should match - adapter = get_adapter_by_type("postgres") table_comp = TableComparison( - adapter=adapter, unique_schema=project.test_schema, database=project.database + adapter=project.adapter, unique_schema=project.test_schema, database=project.database ) 
table_comp.assert_tables_equal("users", "materialized_copy")
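
For reference, a minimal sketch of how a converted functional test is expected to use the reorganized TableComparison together with the project fixture after this change. The run_dbt and TableComparison imports, the project fixture attributes (project.adapter, project.test_schema, project.database), and the assertion names are taken from the code in this patch; the model names, model SQL, and the test itself are hypothetical and only illustrate the intended usage.

    import pytest

    from dbt.tests.util import run_dbt
    from dbt.tests.tables import TableComparison


    # Hypothetical models: a small inline "base" table and a model that copies it.
    base_sql = """
    select 1 as id, 'alice'::text as name
    union all
    select 2 as id, 'bob'::text as name
    """

    copied_sql = """
    select * from {{ ref('base') }}
    """


    @pytest.fixture
    def models():
        # Overrides the framework's models fixture, following the convention
        # used by the functional tests touched in this patch.
        return {"base.sql": base_sql, "copied.sql": copied_sql}


    def test_copied_matches_base(project):
        # Build both models in the unique test schema created by the schema fixture.
        results = run_dbt(["run"])
        assert len(results) == 2

        # The adapter now comes from the project fixture rather than
        # get_adapter_by_type("postgres").
        table_comp = TableComparison(
            adapter=project.adapter, unique_schema=project.test_schema, database=project.database
        )

        table_comp.assert_table_does_exist("copied")
        table_comp.assert_tables_equal("base", "copied")

Because the schema fixture now drops the unique schema, calls adapter.cleanup_connections(), and calls reset_adapters() on teardown, tests written this way do not need any adapter setup or cleanup of their own.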