From b1dbd3968e6e3155d283302121bf18e677833823 Mon Sep 17 00:00:00 2001
From: Kyle Speer <54034650+kspeer825@users.noreply.github.com>
Date: Fri, 5 Feb 2021 11:36:19 -0500
Subject: [PATCH] Move tests, fix config (#48)

* run-test not run-a-test

* specify env in project dir

* use v4 image

* remove SCENARIOS usage

* rename *_test_*.py -> test_*.py for auto fields, discoveries, bookmarks

* rename remaining tests

* add context user, fix db imports in tests

* fix base imports in tests

* fix the test_run call so we only run each test once

* fix run sync method

* auto fields test passing

* update feature tests

Co-authored-by: Kyle Speer
---
 .circleci/config.yml                    |  30 +-
 tests/base.py                           | 260 +++
 tests/database.py                       | 844 +++++++++++++++++++++++
 tests/spec.py                           | 155 +++++
 tests/test_automatic_fields.py          | 223 ++++++
 tests/test_bookmarks.py                 | 299 ++++
 tests/test_discovery_data_types.py      | 370 ++++++++++
 tests/test_discovery_multiple_dbs.py    | 272 ++++++++
 tests/test_discovery_names.py           | 264 +++++++
 tests/test_discovery_pks.py             | 280 ++++++++
 tests/test_discovery_unsupported_pks.py |  59 ++
 tests/test_full_replication.py          | 259 +++++++
 tests/test_pagination.py                |  68 ++
 tests/test_saas_stream.py               | 186 +++++
 tests/test_start_date.py                | 123 ++++
 tests/test_sync_full.py                 | 722 +++++++++++++++++++
 tests/test_sync_full_datetime.py        | 312 +++++++++
 tests/test_sync_full_decimal.py         | 424 ++++++++++++
 tests/test_sync_full_float.py           | 226 ++++++
 tests/test_sync_full_integers.py        | 360 ++++++++++
 tests/test_sync_full_multiple_dbs.py    | 391 +++++++++++
 tests/test_sync_full_names.py           | 394 +++++++++++
 tests/test_sync_full_others.py          | 340 +++++++++
 tests/test_sync_full_pks.py             | 587 ++++++++++++++++
 tests/test_sync_full_strings.py         | 394 +++++++++++
 tests/test_sync_incremental_datetime.py | 537 +++++++++++++++
 tests/test_sync_incremental_decimal.py  | 518 ++++++++++++++
 tests/test_sync_incremental_float.py    | 353 ++++++++++
 tests/test_sync_incremental_integers.py | 399 +++++++++++
 tests/test_sync_incremental_others.py   | 572 ++++++++++++++++
 tests/test_sync_incremental_pks.py      | 805 ++++++++++++++++++++++
 tests/test_sync_logical_datetime.py     | 550 +++++++++++++++
 tests/test_sync_logical_decimal.py      | 555 +++++++++++++++
 tests/test_sync_logical_float.py        | 414 +++++++++++
 tests/test_sync_logical_integers.py     | 450 ++++++++++++
 tests/test_sync_logical_multiple_dbs.py | 652 ++++++++++++++++++
 tests/test_sync_logical_names.py        | 607 ++++++++++++++++
 tests/test_sync_logical_others.py       | 568 +++++++++++++++
 tests/test_sync_logical_pks.py          | 875 ++++++++++++++++++++++++
 39 files changed, 15685 insertions(+), 12 deletions(-)
 create mode 100644 tests/base.py
 create mode 100644 tests/database.py
 create mode 100644 tests/spec.py
 create mode 100644 tests/test_automatic_fields.py
 create mode 100644 tests/test_bookmarks.py
 create mode 100644 tests/test_discovery_data_types.py
 create mode 100644 tests/test_discovery_multiple_dbs.py
 create mode 100644 tests/test_discovery_names.py
 create mode 100644 tests/test_discovery_pks.py
 create mode 100644 tests/test_discovery_unsupported_pks.py
 create mode 100644 tests/test_full_replication.py
 create mode 100644 tests/test_pagination.py
 create mode 100644 tests/test_saas_stream.py
 create mode 100644 tests/test_start_date.py
 create mode 100644 tests/test_sync_full.py
 create mode 100644 tests/test_sync_full_datetime.py
 create mode 100644 tests/test_sync_full_decimal.py
 create mode 100644 tests/test_sync_full_float.py
 create mode 100644 tests/test_sync_full_integers.py
 create mode 100644 tests/test_sync_full_multiple_dbs.py
 create mode 100644 tests/test_sync_full_names.py
 create mode 100644 tests/test_sync_full_others.py
 create mode 100644 tests/test_sync_full_pks.py
 create mode 100644 tests/test_sync_full_strings.py
 create mode 100644 tests/test_sync_incremental_datetime.py
 create mode 100644 tests/test_sync_incremental_decimal.py
 create mode 100644 tests/test_sync_incremental_float.py
 create mode 100644 tests/test_sync_incremental_integers.py
 create mode 100644 tests/test_sync_incremental_others.py
 create mode 100644 tests/test_sync_incremental_pks.py
 create mode 100644 tests/test_sync_logical_datetime.py
 create mode 100644 tests/test_sync_logical_decimal.py
 create mode 100644 tests/test_sync_logical_float.py
 create mode 100644 tests/test_sync_logical_integers.py
 create mode 100644 tests/test_sync_logical_multiple_dbs.py
 create mode 100644 tests/test_sync_logical_names.py
 create mode 100644 tests/test_sync_logical_others.py
 create mode 100644 tests/test_sync_logical_pks.py
diff --git a/.circleci/config.yml b/.circleci/config.yml
index c1bad05..e5a843e 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -31,7 +31,7 @@ jobs:
             - /root/.m2
   tap_tester:
     docker:
-      - image: 218546966473.dkr.ecr.us-east-1.amazonaws.com/circle-ci:tap-tester-clj
+      - image: 218546966473.dkr.ecr.us-east-1.amazonaws.com/circle-ci:tap-tester-clj-v4
       - image: mcr.microsoft.com/mssql/server:2017-latest
         environment:
           ACCEPT_EULA: Y
@@ -51,26 +51,28 @@ jobs:
       - run:
          name: 'Tap Tester'
          command: |
-            aws s3 cp s3://com-stitchdata-dev-deployment-assets/environments/tap-tester/sandbox tap-tester.env
-            source tap-tester.env
+            cd /root/project
+            aws s3 cp s3://com-stitchdata-dev-deployment-assets/environments/tap-tester/sandbox dev_env.sh
+            source dev_env.sh
             aws s3 cp s3://com-stitchdata-dev-deployment-assets/environments/tap-mssql/sandbox tap-mssql.env
             source tap-mssql.env
-            source /usr/local/share/virtualenvs/tap-tester/bin/activate
-            cd /root/project/
             lein deps
-            run-a-test --tap=/root/project/bin/tap-mssql \
-                       --target=target-stitch \
-                       --orchestrator=stitch-orchestrator \
-                       --email=harrison+sandboxtest@stitchdata.com \
-                       --password=$SANDBOX_PASSWORD \
-                       --client-id=50 \
-                       tap_tester.suites.mssql
+            source /usr/local/share/virtualenvs/tap-tester/bin/activate
+            run-test --tap=/root/project/bin/tap-mssql \
+                     --target=target-stitch \
+                     --orchestrator=stitch-orchestrator \
+                     --email=harrison+sandboxtest@stitchdata.com \
+                     --password=$SANDBOX_PASSWORD \
+                     --client-id=50 \
+                     tests
+
 workflows:
   version: 2
   build_and_test:
     jobs:
       - build
       - tap_tester:
+          context: circleci-user
           requires:
             - build
   build_daily:
@@ -83,3 +85,7 @@
               - master
     jobs:
       - build
+      - tap_tester:
+          context: circleci-user
+          requires:
+            - build
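A note on the `run-test ... tests` argument above: the suite reference changed from `tap_tester.suites.mssql` to the local `tests` directory, which is why the `*_test_*.py -> test_*.py` renames in the commit log matter. Assuming the runner performs standard unittest-style discovery (its internals are not part of this patch), only files whose names start with `test` are collected:

# Hypothetical sketch of test discovery, not part of the patch: unittest's
# default pattern is "test*.py", so a file named automatic_fields_test.py
# would be silently skipped, while test_automatic_fields.py is picked up.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test*.py")
unittest.TextTestRunner(verbosity=2).run(suite)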
diff --git a/tests/base.py b/tests/base.py
new file mode 100644
index 0000000..795ed03
--- /dev/null
+++ b/tests/base.py
@@ -0,0 +1,260 @@
+"""
+Setup expectations for test subclasses.
+Run discovery as a prerequisite for most tests.
+"""
+import unittest
+import os
+from datetime import datetime as dt
+from datetime import timezone as tz
+
+from tap_tester import connections, menagerie, runner
+
+from spec import TapSpec
+
+
+class BaseTapTest(TapSpec, unittest.TestCase):
+    """
+    Setup expectations for test subclasses.
+    Run discovery as a prerequisite for most tests.
+    """
+
+    @staticmethod
+    def name():
+        """The name of the test within the suite"""
+        return "tap_tester_{}".format(TapSpec.tap_name())
+
+    def environment_variables(self):
+        return ({p for p in
+                 self.CONFIGURATION_ENVIRONMENT['properties'].values()} |
+                {c for c in self.CONFIGURATION_ENVIRONMENT['credentials'].values()})
+
+    def expected_streams(self):
+        """A set of expected stream ids"""
+        return set(self.expected_metadata().keys())
+
+    def child_streams(self):
+        """
+        Return a set of streams that are child streams
+        based on having foreign key metadata
+        """
+        return {stream for stream, metadata in self.expected_metadata().items()
+                if metadata.get(self.FOREIGN_KEYS)}
+
+    def expected_primary_keys_by_stream_id(self):
+        """
+        Return a dictionary with key of table name (stream_id)
+        and value as a set of primary key fields
+        """
+        return {table: properties.get(self.PRIMARY_KEYS, set())
+                for table, properties in self.expected_metadata().items()}
+
+    def expected_replication_keys(self):
+        """
+        Return a dictionary with key of table name
+        and value as a set of replication key fields
+        """
+        return {table: properties.get(self.REPLICATION_KEYS, set())
+                for table, properties
+                in self.expected_metadata().items()}
+
+    def add_expected_metadata(cls, database_name, schema_name, table_name, column_name,
+                              column_type, primary_key, values=None, view: bool = False):
+        column_metadata = cls.expected_column_metadata(cls, column_name, column_type, primary_key)
+
+        data = {
+            cls.DATABASE_NAME: database_name,
+            cls.SCHEMA: "public" if schema_name is None else schema_name,
+            cls.STREAM: table_name,
+            cls.VIEW: view,
+            cls.PRIMARY_KEYS: primary_key,
+            cls.ROWS: 0,
+            cls.SELECTED: None,
+            cls.FIELDS: column_metadata,
+            cls.VALUES: values
+        }
+
+        cls.EXPECTED_METADATA["{}_{}_{}".format(
+            data[cls.DATABASE_NAME],
+            data[cls.SCHEMA],
+            data[cls.STREAM])] = data
+
+    def expected_column_metadata(self, column_name, column_type, primary_key):
+        column_metadata = [{x[0]: {self.DATATYPE: x[1]}} for x in list(zip(column_name, column_type))]
+
+        # primary keys have inclusion of automatic
+        for field in column_metadata:
+            if set(field.keys()).intersection(primary_key):
+                field[list(field.keys())[0]][self.INCLUSION] = self.AUTOMATIC_FIELDS
+                field[list(field.keys())[0]][self.DEFAULT_SELECT] = True
+
+        # other fields are available if supported, otherwise unavailable (unsupported)
+        for field in column_metadata:
+            if not set(field.keys()).intersection(primary_key):
+                if field[list(field.keys())[0]][self.DATATYPE] in self.SUPPORTED_DATATYPES:
+                    field[list(field.keys())[0]][self.INCLUSION] = self.AVAILABLE_FIELDS
+                    field[list(field.keys())[0]][self.DEFAULT_SELECT] = True
+                else:
+                    field[list(field.keys())[0]][self.INCLUSION] = self.UNAVAILABLE_FIELDS
+                    field[list(field.keys())[0]][self.DEFAULT_SELECT] = False
+
+        # floats and reals don't keep their precision, and a float with a
+        # mantissa of 24 bits or less is actually a real
+        for field in column_metadata:
+            datatype = field[list(field.keys())[0]][self.DATATYPE]
+            index = datatype.find("(")
+            if index > -1:
+                if datatype[:index] == "float":
+                    if int(datatype[index+1:-1]) <= 24:
+                        field[list(field.keys())[0]][self.DATATYPE] = "real"
+                    else:
+                        field[list(field.keys())[0]][self.DATATYPE] = "float"
+
+        # rowversion shows up as type timestamp; they are synonyms
+        for field in column_metadata:
+            datatype = field[list(field.keys())[0]][self.DATATYPE]
+            if datatype == "rowversion":
+                field[list(field.keys())[0]][self.DATATYPE] = "timestamp"
+
+        # TODO - BUG - Remove this if we determine sql-datatypes should include precision/scale
+        for field in column_metadata:
+            datatype = field[list(field.keys())[0]][self.DATATYPE]
+            index = datatype.find("(")
+            if index > -1:  # and "numeric" not in datatype and "decimal" not in datatype:
+                field[list(field.keys())[0]][self.DATATYPE] = datatype[:index]
+        return column_metadata
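The normalization in expected_column_metadata above is easier to follow with a concrete input. A hypothetical walk-through (column names invented; the metadata constants live in tests/spec.py, which this excerpt does not show):

# Illustrative only, not part of the patch.
columns = ["id", "temperature", "row_ver"]
types = ["int", "float(24)", "rowversion"]
primary_key = {"id"}
# Per the rules above, expected_column_metadata(columns, types, primary_key) yields:
#   id          -> inclusion "automatic" (primary key), datatype "int"
#   temperature -> datatype "float(24)" rewritten to "real" (mantissa of 24 bits or less)
#   row_ver     -> datatype "rowversion" rewritten to its synonym "timestamp"
# and the trailing TODO loop strips any remaining "(precision, scale)" suffix.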
and "decimal" not in datatype: + field[list(field.keys())[0]][self.DATATYPE] = datatype[:index] + return column_metadata + + def expected_foreign_keys(self): + """ + return a dictionary with key of table name + and value as a set of foreign key fields + """ + return {table: properties.get(self.FOREIGN_KEYS, set()) + for table, properties + in self.expected_metadata().items()} + + def expected_replication_method(self): + """return a dictionary with key of table name nd value of replication method""" + return {table: properties.get(self.REPLICATION_METHOD, None) + for table, properties + in self.expected_metadata().items()} + + def setUp(self): + """Verify that you have set the prerequisites to run the tap (creds, etc.)""" + missing_envs = [x for x in self.environment_variables() if os.getenv(x) is None] + if missing_envs: + raise Exception("Missing test-required environment variables: {}".format(missing_envs)) + + ######################### + # Helper Methods # + ######################### + + def create_connection(self, original_properties: bool = True): + """Create a new connection with the test name""" + + # Create the connection + conn_id = connections.ensure_connection(self, original_properties) + + # Run a check job using orchestrator (discovery) + check_job_name = runner.run_check_mode(self, conn_id) + + # Assert that the check job succeeded + exit_status = menagerie.get_exit_status(conn_id, check_job_name) + menagerie.verify_check_exit_status(self, exit_status, check_job_name) + return conn_id + + def run_sync(self, conn_id): + """ + Run a sync job and make sure it exited properly. + Return a dictionary with keys of streams synced + and values of records synced for each stream + """ + # Run a sync job using orchestrator + sync_job_name = runner.run_sync_mode(self, conn_id) + + # Verify tap and target exit codes + exit_status = menagerie.get_exit_status(conn_id, sync_job_name) + menagerie.verify_sync_exit_status(self, exit_status, sync_job_name) + + # Verify actual rows were synced + sync_record_count = runner.examine_target_output_file( + self, conn_id, self.expected_streams(), self.expected_primary_keys_by_stream_id()) + return sync_record_count + + @staticmethod + def local_to_utc(date: dt): + """Convert a datetime with timezone information to utc""" + utc = dt(date.year, date.month, date.day, date.hour, date.minute, + date.second, date.microsecond, tz.utc) + + if date.tzinfo and hasattr(date.tzinfo, "_offset"): + utc += date.tzinfo._offset + + return utc + + def max_bookmarks_by_stream(self, sync_records): + """ + Return the maximum value for the replication key for each stream + which is the bookmark expected value. + + Comparisons are based on the class of the bookmark value. 
+
+    def min_bookmarks_by_stream(self, sync_records):
+        """Return the minimum value for the replication key for each stream"""
+        min_bookmarks = {}
+        for stream, batch in sync_records.items():
+
+            upsert_messages = [m for m in batch.get('messages') if m['action'] == 'upsert']
+            stream_bookmark_key = self.expected_replication_keys().get(stream, set())
+            assert len(stream_bookmark_key) == 1  # There shouldn't be a compound replication key
+            (stream_bookmark_key, ) = stream_bookmark_key
+
+            bk_values = [message["data"].get(stream_bookmark_key) for message in upsert_messages]
+            min_bookmarks[stream] = {stream_bookmark_key: None}
+            for bk_value in bk_values:
+                if bk_value is None:
+                    continue
+
+                if min_bookmarks[stream][stream_bookmark_key] is None:
+                    min_bookmarks[stream][stream_bookmark_key] = bk_value
+
+                if bk_value < min_bookmarks[stream][stream_bookmark_key]:
+                    min_bookmarks[stream][stream_bookmark_key] = bk_value
+        return min_bookmarks
+
+    @staticmethod
+    def select_all_streams_and_fields(conn_id, catalogs, select_all_fields: bool = True,
+                                      additional_md=[], non_selected_properties=[]):
+        """Select all streams and all fields within streams"""
+        for catalog in catalogs:
+            schema = menagerie.get_annotated_schema(conn_id, catalog['stream_id'])
+
+            if not select_all_fields and not non_selected_properties:
+                # get a list of all properties so that none are selected
+                non_selected_properties = schema.get('annotated-schema', {}).get(
+                    'properties', {}).keys()
+
+            connections.select_catalog_and_fields_via_metadata(
+                conn_id, catalog, schema, additional_md, non_selected_properties)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.start_date = self.get_properties().get("start_date")
diff --git a/tests/database.py b/tests/database.py
new file mode 100644
index 0000000..09be739
--- /dev/null
+++ b/tests/database.py
@@ -0,0 +1,844 @@
+import os
+import socket
+from pprint import pprint
+from random import randint, sample
+
+import pyodbc
+
+USERNAME = os.getenv("STITCH_TAP_MSSQL_TEST_DATABASE_USER")
+PASSWORD = os.getenv("STITCH_TAP_MSSQL_TEST_DATABASE_PASSWORD")
+HOST = "localhost"
+
+LOWER_ALPHAS, UPPER_ALPHAS, DIGITS, OTHERS = set(), set(), set(), set()
+for letter in range(97, 123):
+    LOWER_ALPHAS.add(chr(letter))
+
+for letter in range(65, 91):
+    UPPER_ALPHAS.add(chr(letter))
+
+for digit in range(48, 58):
+    DIGITS.add(chr(digit))
+
+for invalid in set().union(range(32, 48), range(58, 65), range(91, 97), range(123, 127)):
+    OTHERS.add(chr(invalid))
+
+ALPHA_NUMERIC = LOWER_ALPHAS.union(UPPER_ALPHAS).union(DIGITS)
+
+
+def mssql_cursor_context_manager(*args):
+    """Connect to the test MSSQL server, execute each query in order, and
+    return the fetched results of the final statement (None if it produces
+    no result set)."""
+    server = "{},1433".format(HOST)
+    database = "master"
+    connection_string = (
+        "DRIVER={{ODBC Driver 17 for SQL Server}}"
+        ";SERVER={};DATABASE={};UID={};PWD={}".format(
+            server, database, USERNAME, PASSWORD))
+
+    print(connection_string.replace(PASSWORD, "[REDACTED]"))
+    connection = pyodbc.connect(connection_string, autocommit=True)
+    # https://github.com/mkleehammer/pyodbc/wiki/Unicode#configuring-specific-databases
+    # connection.setdecoding(pyodbc.SQL_CHAR, encoding='utf-8')
+    # connection.setdecoding(pyodbc.SQL_WCHAR, encoding='utf-8')
+    # connection.setencoding(encoding='utf-8')
+    # connection.add_output_converter(-155, handle_datetimeoffset)
+
+    with connection.cursor() as cursor:
+        for q in args:
+            print(q)
+            if isinstance(q, tuple):
+                cursor.executemany(*q)
+            else:
+                cursor.execute(q)
+        try:
+            results = cursor.fetchall()
+        except pyodbc.ProgrammingError:
+            results = None
+
+    connection.close()
+
+    return results
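A hypothetical usage of mssql_cursor_context_manager above (database and table names invented): statements run in order, a (sql, params) tuple is routed to cursor.executemany, and only the final statement's rows are returned.

# Invented example; assumes the flat "from database import ..." style the tests use.
from database import mssql_cursor_context_manager

rows = mssql_cursor_context_manager(
    "CREATE TABLE example_db.dbo.widgets (id int PRIMARY KEY, name varchar(30))",
    # a (sql, seq_of_params) tuple is dispatched to cursor.executemany(...)
    ("INSERT INTO example_db.dbo.widgets VALUES (?, ?)", [(1, "a"), (2, "b")]),
    "SELECT * FROM example_db.dbo.widgets",
)  # rows holds the result of the final SELECT only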
+
+
+def drop_all_user_databases():
+    """
+    Drop all user databases. Please run the PRINT first and make sure you're
+    not dropping anything you may regret. You may want to take backups of all
+    databases first, just in case.
+
+    DECLARE @sql NVARCHAR(MAX) = N'';
+
+    SELECT @sql += N'
+    DROP DATABASE ' + QUOTENAME(name) + N';'
+    FROM sys.databases
+    WHERE name NOT IN (N'master',N'tempdb',N'model',N'msdb');
+
+    PRINT @sql;
+    -- EXEC master.sys.sp_executesql @sql;
+    :return:
+    """
+    query_list = [
+        # "DECLARE @sql NVARCHAR(MAX) = N'';"
+        "SELECT N'DROP DATABASE ' + QUOTENAME(name) + N';' "
+        "FROM sys.databases "
+        "WHERE name NOT IN (N'master',N'tempdb',N'model',N'msdb',N'rdsadmin');"
+        # "PRINT @sql;"
+        # "EXEC master.sys.sp_executesql @sql;"
+    ]
+
+    results = mssql_cursor_context_manager(*query_list)
+    query_list = [x[0] for x in results]
+    if query_list:
+        mssql_cursor_context_manager(*query_list)
+
+
+def drop_database(db_name):
+    return ["DROP DATABASE IF EXISTS {};".format(db_name)]
+
+
+def create_database(db_name, collation: str = None):
+    """
+    CREATE DATABASE database_name
+    [ CONTAINMENT = { NONE | PARTIAL } ]
+    [ ON
+        [ PRIMARY ] <filespec> [ ,...n ]
+        [ , <filegroup> [ ,...n ] ]
+        [ LOG ON <filespec> [ ,...n ] ]
+    ]
+    [ COLLATE collation_name ]
+    [ WITH