From e2af871a5adc1b6b0269ef900e04320550ac1bb8 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 4 Mar 2019 21:08:38 -0700 Subject: [PATCH] per-thread connections parsing now always opens a connection, instead of waiting to need it remove model_name/available_raw/etc --- core/dbt/adapters/base/connections.py | 260 ++++++++---------- core/dbt/adapters/base/impl.py | 148 +++++----- core/dbt/adapters/base/meta.py | 27 +- core/dbt/adapters/factory.py | 4 +- core/dbt/adapters/sql/connections.py | 52 ++-- core/dbt/adapters/sql/impl.py | 81 ++---- core/dbt/clients/system.py | 6 +- core/dbt/compat.py | 4 + core/dbt/context/common.py | 40 +-- core/dbt/context/parser.py | 16 +- core/dbt/context/runtime.py | 5 +- core/dbt/node_runners.py | 44 ++- core/dbt/parser/base.py | 7 +- core/dbt/ssh_forward.py | 10 - core/dbt/task/base.py | 2 +- core/dbt/task/compile.py | 5 +- core/dbt/task/generate.py | 8 +- core/dbt/task/run.py | 42 ++- core/dbt/utils.py | 6 +- .../dbt/adapters/bigquery/connections.py | 65 +++-- .../bigquery/dbt/adapters/bigquery/impl.py | 86 +++--- .../dbt/adapters/postgres/connections.py | 7 +- .../postgres/dbt/adapters/postgres/impl.py | 16 +- .../dbt/adapters/redshift/connections.py | 10 +- .../redshift/dbt/adapters/redshift/impl.py | 7 +- .../dbt/adapters/snowflake/connections.py | 21 +- .../test_graph_selection.py | 27 +- .../test_concurrent_transaction.py | 10 +- .../test_external_reference.py | 9 +- test/integration/base.py | 106 ++++--- test/unit/test_bigquery_adapter.py | 10 +- test/unit/test_graph.py | 5 +- test/unit/test_parser.py | 5 +- test/unit/test_postgres_adapter.py | 36 ++- test/unit/test_redshift_adapter.py | 24 +- test/unit/test_snowflake_adapter.py | 32 ++- test/unit/utils.py | 12 +- 37 files changed, 576 insertions(+), 679 deletions(-) delete mode 100644 core/dbt/ssh_forward.py diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index c65e932454e..169f7c7d46a 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -1,12 +1,13 @@ import abc import multiprocessing +import os import six import dbt.exceptions import dbt.flags from dbt.api import APIObject -from dbt.compat import abstractclassmethod +from dbt.compat import abstractclassmethod, get_ident from dbt.contracts.connection import Connection from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import translate_aliases @@ -71,6 +72,7 @@ class BaseConnectionManager(object): - open - begin - commit + - clear_transaction - execute You must also set the 'TYPE' class attribute with a class-unique constant @@ -80,83 +82,93 @@ class BaseConnectionManager(object): def __init__(self, profile): self.profile = profile - self.in_use = {} - self.available = [] + self.thread_connections = {} self.lock = multiprocessing.RLock() - self._set_initial_connections() - - def _set_initial_connections(self): - self.available = [] - # set up the array of connections in the 'init' state. - # we add a magic number, 2 because there are overhead connections, - # one for pre- and post-run hooks and other misc operations that occur - # before the run starts, and one for integration tests. 
- for idx in range(self.profile.threads + 2): - self.available.append(self._empty_connection()) - - def _empty_connection(self): - return Connection( - type=self.TYPE, - name=None, - state='init', - transaction_open=False, - handle=None, - credentials=self.profile.credentials - ) + + @staticmethod + def get_thread_identifier(): + # note that get_ident() may be re-used, but we should never experience + # that within a single process + return (os.getpid(), get_ident()) + + def get_thread_connection(self): + key = self.get_thread_identifier() + with self.lock: + if key not in self.thread_connections: + raise RuntimeError( + 'connection never acquired for thread {}, have {}' + .format(key, list(self.thread_connections)) + ) + return self.thread_connections[key] + + def get_if_exists(self): + key = self.get_thread_identifier() + with self.lock: + return self.thread_connections.get(key) + + def clear_thread_connection(self): + key = self.get_thread_identifier() + with self.lock: + if key in self.thread_connections: + del self.thread_connections[key] + + def clear_transaction(self): + """Clear any existing transactions.""" + conn = self.get_thread_connection() + if conn is not None: + self.begin() + self.commit() @abc.abstractmethod - def exception_handler(self, sql, connection_name='master'): + def exception_handler(self, sql): """Create a context manager that handles exceptions caused by database interactions. :param str sql: The SQL string that the block inside the context manager is executing. - :param str connection_name: The name of the connection being used :return: A context manager that handles exceptions raised by the underlying database. """ raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') - def get(self, name=None): - """This is thread-safe as long as two threads don't use the same - "name". - """ + def set_connection_name(self, name=None): if name is None: # if a name isn't specified, we'll re-use a single handle # named 'master' name = 'master' - with self.lock: - if name in self.in_use: - return self.in_use[name] + conn = self.get_if_exists() + thread_id_key = self.get_thread_identifier() - logger.debug('Acquiring new {} connection "{}".' - .format(self.TYPE, name)) - - if not self.available: - raise dbt.exceptions.InternalException( - 'Tried to request a new connection "{}" but ' - 'the maximum number of connections are already ' - 'allocated!'.format(name) - ) + if conn is None: + conn = Connection( + type=self.TYPE, + name=None, + state='init', + transaction_open=False, + handle=None, + credentials=self.profile.credentials + ) + self.thread_connections[thread_id_key] = conn - connection = self.available.pop() - # connection is temporarily neither in use nor available, but both - # collections are in a sane state, so we can release the lock. + if conn.name == name and conn.state == 'open': + return conn - # this potentially calls open(), but does so without holding the lock - connection = self.assign(connection, name) + logger.debug('Acquiring new {} connection "{}".' + .format(self.TYPE, name)) - with self.lock: - if name in self.in_use: - raise dbt.exceptions.InternalException( - 'Two threads concurrently tried to get the same name: {}' - .format(name) - ) - self.in_use[name] = connection + if conn.state == 'open': + logger.debug( + 'Re-using an available connection from the pool (formerly {}).' 
+ .format(conn.name)) + else: + logger.debug('Opening a new connection, currently in state {}' + .format(conn.state)) + self.open(conn) - return connection + conn.name = name + return conn @abc.abstractmethod def cancel_open(self): @@ -183,81 +195,39 @@ def open(cls, connection): '`open` is not implemented for this adapter!' ) - def assign(self, conn, name): - """Open a connection if it's not already open, and assign it name - regardless. - - The caller is responsible for putting the assigned connection into the - in_use collection. - - :param Connection conn: A connection, in any state. - :param str name: The name of the connection to set. - """ - if name is None: - name = 'master' - - conn.name = name - - if conn.state == 'open': - logger.debug('Re-using an available connection from the pool.') - else: - logger.debug('Opening a new connection, currently in state {}' - .format(conn.state)) - conn = self.open(conn) - - return conn - - def _release_connection(self, conn): - if conn.state == 'open': - if conn.transaction_open is True: - self._rollback(conn) - conn.name = None - else: - self.close(conn) - - def release(self, name): + def release(self): with self.lock: - if name not in self.in_use: + conn = self.get_if_exists() + if conn is None: return - to_release = self.in_use.pop(name) - # to_release is temporarily neither in use nor available, but both - # collections are in a sane state, so we can release the lock. - try: - self._release_connection(to_release) - except: - # if rollback or close failed, replace our busted connection with - # a new one - to_release = self._empty_connection() + if conn.state == 'open': + if conn.transaction_open is True: + self._rollback(conn) + else: + self.close(conn) + except Exception: + # if rollback or close failed, remove our busted connection + self.clear_thread_connection() raise - finally: - # now that this connection has been rolled back and the name reset, - # or the connection has been closed, put it back on the available - # list - with self.lock: - self.available.append(to_release) def cleanup_all(self): with self.lock: - for name, connection in self.in_use.items(): - if connection.state != 'closed': + for connection in self.thread_connections.values(): + if connection.state not in {'closed', 'init'}: logger.debug("Connection '{}' was left open." - .format(name)) + .format(connection.name)) else: logger.debug("Connection '{}' was properly closed." - .format(name)) - - conns_in_use = list(self.in_use.values()) - for conn in conns_in_use + self.available: - self.close(conn) + .format(connection.name)) + self.close(connection) # garbage collect these connections - self.in_use.clear() - self._set_initial_connections() + self.thread_connections.clear() @abc.abstractmethod - def begin(self, name): + def begin(self): """Begin a transaction. (passable) :param str name: The name of the connection to use. @@ -266,34 +236,32 @@ def begin(self, name): '`begin` is not implemented for this adapter!' ) - def get_if_exists(self, name): - if name is None: - name = 'master' - - if self.in_use.get(name) is None: - return - - return self.get(name) - @abc.abstractmethod - def commit(self, connection): - """Commit a transaction. (passable) - - :param str name: The name of the connection to use. - """ + def commit(self): + """Commit a transaction. (passable)""" raise dbt.exceptions.NotImplementedException( '`commit` is not implemented for this adapter!' 
) - def _rollback_handle(self, connection): + @classmethod + def _rollback_handle(cls, connection): """Perform the actual rollback operation.""" connection.handle.rollback() - def _rollback(self, connection): - """Roll back the given connection. + @classmethod + def _close_handle(cls, connection): + """Perform the actual close operation.""" + # On windows, sometimes connection handles don't have a close() attr. + if hasattr(connection.handle, 'close'): + logger.debug('On {}: Close'.format(connection.name)) + connection.handle.close() + else: + logger.debug('On {}: No close available on handle' + .format(connection.name)) - The connection does not have to be in in_use or available, so this - operation does not require the lock. + @classmethod + def _rollback(cls, connection): + """Roll back the given connection. """ if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -304,7 +272,7 @@ def _rollback(self, connection): 'it does not have one open!'.format(connection.name)) logger.debug('On {}: ROLLBACK'.format(connection.name)) - self._rollback_handle(connection) + cls._rollback_handle(connection) connection.transaction_open = False @@ -320,40 +288,28 @@ def close(cls, connection): return connection if connection.transaction_open and connection.handle: - connection.handle.rollback() + cls._rollback_handle(connection) connection.transaction_open = False - # On windows, sometimes connection handles don't have a close() attr. - if hasattr(connection.handle, 'close'): - connection.handle.close() - else: - logger.debug('On {}: No close available on handle' - .format(connection.name)) - + cls._close_handle(connection) connection.state = 'closed' return connection - def commit_if_has_connection(self, name): + def commit_if_has_connection(self): """If the named connection exists, commit the current transaction. :param str name: The name of the connection to use. """ - connection = self.in_use.get(name) + connection = self.get_if_exists() if connection: - self.commit(connection) - - def clear_transaction(self, conn_name='master'): - conn = self.begin(conn_name) - self.commit(conn) - return conn_name + self.commit() @abc.abstractmethod - def execute(self, sql, name=None, auto_begin=False, fetch=False): + def execute(self, sql, auto_begin=False, fetch=False): """Execute the given SQL. :param str sql: The sql to execute. - :param Optional[str] name: The name to use for the connection. :param bool auto_begin: If set, and dbt is not currently inside a transaction, automatically begin one. :param bool fetch: If set, fetch results. 
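The hunks above replace the old fixed pool of named connections (the `available`/`in_use` lists sized to threads + 2) with a registry keyed by the calling thread: `get_thread_identifier()` builds an `(os.getpid(), get_ident())` key, and `set_connection_name()` either reuses that thread's already-open connection or opens a new one under the requested name. For orientation, a minimal sketch of that bookkeeping follows; it is a simplified, hypothetical stand-in for `BaseConnectionManager`, written against Python 3's `threading.get_ident` instead of the `dbt.compat.get_ident` shim and `multiprocessing.RLock` used in the patch, and none of its names are part of dbt.

import os
import threading


class ThreadConnectionRegistry(object):
    """Simplified stand-in for BaseConnectionManager's thread bookkeeping."""

    def __init__(self):
        self._connections = {}          # (pid, thread ident) -> connection
        self._lock = threading.RLock()  # guards the dict, not the handles

    @staticmethod
    def _key():
        # thread idents may be recycled, but never while the thread is still
        # alive in the same process, so (pid, ident) names the caller
        # unambiguously
        return (os.getpid(), threading.get_ident())

    def get_if_exists(self):
        # return this thread's connection, or None if it never acquired one
        with self._lock:
            return self._connections.get(self._key())

    def set_connection(self, conn):
        # record conn as the calling thread's connection
        with self._lock:
            self._connections[self._key()] = conn
        return conn

    def clear(self):
        # forget the calling thread's connection, if any
        with self._lock:
            self._connections.pop(self._key(), None)

Because every lookup goes through the caller's own key, two threads can never hand each other a connection, which is why the old name-collision guard ("Two threads concurrently tried to get the same name") can be deleted.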
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index e290042ce80..a0dbdd80e7c 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -1,4 +1,5 @@ import abc +from contextlib import contextmanager import agate import pytz @@ -16,8 +17,7 @@ from dbt.schema import Column from dbt.utils import filter_null_values -from dbt.adapters.base.meta import AdapterMeta, available, available_raw, \ - available_deprecated +from dbt.adapters.base.meta import AdapterMeta, available, available_deprecated from dbt.adapters.base import BaseRelation from dbt.adapters.cache import RelationsCache @@ -193,29 +193,40 @@ def __init__(self, config): ### # Methods that pass through to the connection manager ### - def acquire_connection(self, name): - return self.connections.get(name) + def acquire_connection(self, name=None): + return self.connections.set_connection_name(name) - def release_connection(self, name): - return self.connections.release(name) + def release_connection(self): + return self.connections.release() def cleanup_connections(self): return self.connections.cleanup_all() - def clear_transaction(self, conn_name='master'): - return self.connections.clear_transaction(conn_name) + def clear_transaction(self): + self.connections.clear_transaction() - def commit_if_has_connection(self, name): - return self.connections.commit_if_has_connection(name) + def commit_if_has_connection(self): + return self.connections.commit_if_has_connection() + + def nice_connection_name(self): + conn = self.connections.get_thread_connection() + if conn is None or conn.name is None: + return '' + return conn.name + + @contextmanager + def connection_named(self, name): + try: + yield self.acquire_connection(name) + finally: + self.release_connection() @available - def execute(self, sql, model_name=None, auto_begin=False, fetch=False): + def execute(self, sql, auto_begin=False, fetch=False): """Execute the given SQL. This is a thin wrapper around ConnectionManager.execute. :param str sql: The sql to execute. - :param Optional[str] model_name: The model name to use for the - connection. :param bool auto_begin: If set, and dbt is not currently inside a transaction, automatically begin one. :param bool fetch: If set, fetch results. @@ -224,7 +235,6 @@ def execute(self, sql, model_name=None, auto_begin=False, fetch=False): """ return self.connections.execute( sql=sql, - name=model_name, auto_begin=auto_begin, fetch=fetch ) @@ -258,14 +268,15 @@ def check_internal_manifest(self): ### # Caching methods ### - def _schema_is_cached(self, database, schema, model_name=None): + def _schema_is_cached(self, database, schema): """Check if the schema is cached, and by default logs if it is not.""" + if dbt.flags.USE_CACHE is False: return False elif (database, schema) not in self.cache: logger.debug( 'On "{}": cache miss for schema "{}.{}", this is inefficient' - .format(model_name or '', database, schema) + .format(self.nice_connection_name(), database, schema) ) return False else: @@ -330,11 +341,12 @@ def set_relations_cache(self, manifest, clear=False): self.cache.clear() self._relations_cache_for_schemas(manifest) - def cache_new_relation(self, relation, model_name=None): + def cache_new_relation(self, relation): """Cache a new relation in dbt. 
It will show up in `list relations`.""" if relation is None: + name = self.nice_connection_name() dbt.exceptions.raise_compiler_error( - 'Attempted to cache a null relation for {}'.format(model_name) + 'Attempted to cache a null relation for {}'.format(name) ) if dbt.flags.USE_CACHE: self.cache.add(relation) @@ -364,11 +376,10 @@ def is_cancelable(cls): # Abstract methods about schemas ### @abc.abstractmethod - def list_schemas(self, database, model_name=None): + def list_schemas(self, database): """Get a list of existing schemas. :param str database: The name of the database to list under. - :param Optional[str] model_name: The name of the connection to query as :return: All schemas that currently exist in the database :rtype: List[str] """ @@ -376,7 +387,7 @@ def list_schemas(self, database, model_name=None): '`list_schemas` is not implemented for this adapter!' ) - def check_schema_exists(self, database, schema, model_name=None): + def check_schema_exists(self, database, schema): """Check if a schema exists. The default implementation of this is potentially unnecessarily slow, @@ -385,7 +396,7 @@ def check_schema_exists(self, database, schema, model_name=None): """ search = ( s.lower() for s in - self.list_schemas(database=database, model_name=model_name) + self.list_schemas(database=database) ) return schema.lower() in search @@ -394,14 +405,12 @@ def check_schema_exists(self, database, schema, model_name=None): ### @abc.abstractmethod @available - def drop_relation(self, relation, model_name=None): + def drop_relation(self, relation): """Drop the given relation. *Implementors must call self.cache.drop() to preserve cache state!* :param self.Relation relation: The relation to drop - :param Optional[str] model_name: The name of the model to use for the - connection. """ raise dbt.exceptions.NotImplementedException( '`drop_relation` is not implemented for this adapter!' @@ -409,27 +418,24 @@ def drop_relation(self, relation, model_name=None): @abc.abstractmethod @available - def truncate_relation(self, relation, model_name=None): + def truncate_relation(self, relation): """Truncate the given relation. :param self.Relation relation: The relation to truncate - :param Optional[str] model_name: The name of the model to use for the - connection.""" + """ raise dbt.exceptions.NotImplementedException( '`truncate_relation` is not implemented for this adapter!' ) @abc.abstractmethod @available - def rename_relation(self, from_relation, to_relation, model_name=None): + def rename_relation(self, from_relation, to_relation): """Rename the relation from from_relation to to_relation. Implementors must call self.cache.rename() to preserve cache state. :param self.Relation from_relation: The original relation name :param self.Relation to_relation: The new relation name - :param Optional[str] model_name: The name of the model to use for the - connection. """ raise dbt.exceptions.NotImplementedException( '`rename_relation` is not implemented for this adapter!' @@ -437,12 +443,10 @@ def rename_relation(self, from_relation, to_relation, model_name=None): @abc.abstractmethod @available - def get_columns_in_relation(self, relation, model_name=None): + def get_columns_in_relation(self, relation): """Get a list of the columns in the given Relation. :param self.Relation relation: The relation to query for. - :param Optional[str] model_name: The name of the model to use for the - connection. :return: Information about all columns in the given relation. 
:rtype: List[self.Column] """ @@ -451,7 +455,7 @@ def get_columns_in_relation(self, relation, model_name=None): ) @available_deprecated('get_columns_in_relation') - def get_columns_in_table(self, schema, identifier, model_name=None): + def get_columns_in_table(self, schema, identifier): """DEPRECATED: Get a list of the columns in the given table.""" relation = self.Relation.create( database=self.config.credentials.database, @@ -459,26 +463,23 @@ def get_columns_in_table(self, schema, identifier, model_name=None): identifier=identifier, quote_policy=self.config.quoting ) - return self.get_columns_in_relation(relation, model_name=model_name) + return self.get_columns_in_relation(relation) @abc.abstractmethod - def expand_column_types(self, goal, current, model_name=None): + def expand_column_types(self, goal, current): """Expand the current table's types to match the goal table. (passable) :param self.Relation goal: A relation that currently exists in the database with columns of the desired types. :param self.Relation current: A relation that currently exists in the database with columns of unspecified types. - :param Optional[str] model_name: The name of the model to use for the - connection. """ raise dbt.exceptions.NotImplementedException( '`expand_target_column_types` is not implemented for this adapter!' ) @abc.abstractmethod - def list_relations_without_caching(self, information_schema, schema, - model_name=None): + def list_relations_without_caching(self, information_schema, schema): """List relations in the given schema, bypassing the cache. This is used as the underlying behavior to fill the cache. @@ -486,8 +487,6 @@ def list_relations_without_caching(self, information_schema, schema, :param Relation information_schema: The information schema to list relations from. :param str schema: The name of the schema to list relations from. - :param Optional[str] model_name: The name of the model to use for the - connection. 
:return: The relations in schema :retype: List[self.Relation] """ @@ -500,7 +499,7 @@ def list_relations_without_caching(self, information_schema, schema, # Provided methods about relations ### @available - def get_missing_columns(self, from_relation, to_relation, model_name=None): + def get_missing_columns(self, from_relation, to_relation): """Returns dict of {column:type} for columns in from_table that are missing from to_relation """ @@ -520,12 +519,12 @@ def get_missing_columns(self, from_relation, to_relation, model_name=None): from_columns = { col.name: col for col in - self.get_columns_in_relation(from_relation, model_name=model_name) + self.get_columns_in_relation(from_relation) } to_columns = { col.name: col for col in - self.get_columns_in_relation(to_relation, model_name=model_name) + self.get_columns_in_relation(to_relation) } missing_columns = set(from_columns.keys()) - set(to_columns.keys()) @@ -536,8 +535,7 @@ def get_missing_columns(self, from_relation, to_relation, model_name=None): ] @available - def expand_target_column_types(self, temp_table, to_relation, - model_name=None): + def expand_target_column_types(self, temp_table, to_relation): if not isinstance(to_relation, self.Relation): dbt.exceptions.invalid_type_error( method_name='expand_target_column_types', @@ -552,10 +550,10 @@ def expand_target_column_types(self, temp_table, to_relation, type='table', quote_policy=self.config.quoting ) - self.expand_column_types(goal, to_relation, model_name) + self.expand_column_types(goal, to_relation) - def list_relations(self, database, schema, model_name=None): - if self._schema_is_cached(database, schema, model_name): + def list_relations(self, database, schema): + if self._schema_is_cached(database, schema): return self.cache.get_relations(database, schema) information_schema = self.Relation.create( @@ -566,11 +564,11 @@ def list_relations(self, database, schema, model_name=None): # we can't build the relations cache because we don't have a # manifest so we can't run any operations. 
relations = self.list_relations_without_caching( - information_schema, schema, model_name=model_name + information_schema, schema ) - logger.debug('with schema={}, model_name={}, relations={}' - .format(schema, model_name, relations)) + logger.debug('with database={}, schema={}, relations={}' + .format(database, schema, relations)) return relations def _make_match_kwargs(self, database, schema, identifier): @@ -603,8 +601,8 @@ def _make_match(self, relations_list, database, schema, identifier): return matches @available - def get_relation(self, database, schema, identifier, model_name=None): - relations_list = self.list_relations(database, schema, model_name) + def get_relation(self, database, schema, identifier): + relations_list = self.list_relations(database, schema) matches = self._make_match(relations_list, database, schema, identifier) @@ -625,11 +623,10 @@ def get_relation(self, database, schema, identifier, model_name=None): return None @available_deprecated('get_relation') - def already_exists(self, schema, name, model_name=None): + def already_exists(self, schema, name): """DEPRECATED: Return if a model already exists in the database""" database = self.config.credentials.database - relation = self.get_relation(database, schema, name, - model_name=model_name) + relation = self.get_relation(database, schema, name) return relation is not None ### @@ -638,30 +635,26 @@ def already_exists(self, schema, name, model_name=None): ### @abc.abstractmethod @available - def create_schema(self, database, schema, model_name=None): + def create_schema(self, database, schema): """Create the given schema if it does not exist. :param str schema: The schema name to create. - :param Optional[str] model_name: The name of the model to use for the - connection. """ raise dbt.exceptions.NotImplementedException( '`create_schema` is not implemented for this adapter!' ) @abc.abstractmethod - def drop_schema(self, database, schema, model_name=None): + def drop_schema(self, database, schema): """Drop the given schema (and everything in it) if it exists. :param str schema: The schema name to drop. - :param Optional[str] model_name: The name of the model to use for the - connection. """ raise dbt.exceptions.NotImplementedException( '`drop_schema` is not implemented for this adapter!' ) - @available_raw + @available @abstractclassmethod def quote(cls, identifier): """Quote the given identifier, as appropriate for the database. @@ -675,7 +668,7 @@ def quote(cls, identifier): ) @available - def quote_as_configured(self, identifier, quote_key, model_name=None): + def quote_as_configured(self, identifier, quote_key): """Quote or do not quote the given identifer as configured in the project config for the quote key. @@ -770,7 +763,7 @@ def convert_time_type(cls, agate_table, col_idx): raise dbt.exceptions.NotImplementedException( '`convert_time_type` is not implemented for this adapter!') - @available_raw + @available @classmethod def convert_type(cls, agate_table, col_idx): return cls.convert_agate_type(agate_table, col_idx) @@ -794,8 +787,7 @@ def convert_agate_type(cls, agate_table, col_idx): # Operations involving the manifest ### def execute_macro(self, macro_name, manifest=None, project=None, - context_override=None, kwargs=None, release=False, - connection_name=None): + context_override=None, kwargs=None, release=False): """Look macro_name up in the manifest and execute its results. :param str macro_name: The name of the macro to execute. 
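With `connection_name` dropped from the `execute_macro` signature above, a macro always runs on whatever connection the current thread has already named, typically via the `connection_named()` context manager added earlier in this file. A hypothetical caller-side sketch under those assumptions (the function name and the 'analytics'/'staging' identifiers are illustrative, not taken from this patch):

def build_schema(adapter, manifest):
    # name this thread's connection once; the context manager releases it
    with adapter.connection_named('create_analytics_schema'):
        return adapter.execute_macro(
            'create_schema',
            kwargs={
                'database_name': adapter.quote_as_configured('analytics',
                                                             'database'),
                'schema_name': adapter.quote_as_configured('staging',
                                                           'schema'),
            },
            manifest=manifest,
        )

The macro resolves its connection through the thread-keyed registry, so the only remaining per-call knob is `release=True` for callers that want the connection released as soon as the macro finishes.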
@@ -809,8 +801,6 @@ def execute_macro(self, macro_name, manifest=None, project=None, :param Optional[dict] kwargs: An optional dict of keyword args used to pass to the macro. :param bool release: If True, release the connection after executing. - :param Optional[str] connection_name: The connection name to use, or - use the macro name. Return an an AttrDict with three attributes: 'table', 'data', and 'status'. 'table' is an agate.Table. @@ -819,8 +809,6 @@ def execute_macro(self, macro_name, manifest=None, project=None, kwargs = {} if context_override is None: context_override = {} - if connection_name is None: - connection_name = macro_name if manifest is None: manifest = self._internal_manifest @@ -838,15 +826,13 @@ def execute_macro(self, macro_name, manifest=None, project=None, 'dbt could not find a macro with the name "{}" in {}' .format(macro_name, package_name) ) - # This causes a reference cycle, as dbt.context.runtime.generate() # ends up calling get_adapter, so the import has to be here. import dbt.context.runtime macro_context = dbt.context.runtime.generate_macro( macro, self.config, - manifest, - connection_name + manifest ) macro_context.update(context_override) @@ -856,7 +842,7 @@ def execute_macro(self, macro_name, manifest=None, project=None, result = macro_function(**kwargs) finally: if release: - self.release_connection(connection_name) + self.release_connection() return result @classmethod @@ -884,8 +870,7 @@ def cancel_open_connections(self): """Cancel all open connections.""" return self.connections.cancel_open() - def calculate_freshness(self, source, loaded_at_field, manifest=None, - connection_name=None): + def calculate_freshness(self, source, loaded_at_field, manifest=None): """Calculate the freshness of sources in dbt, and return it""" # in the future `source` will be a Relation instead of a string kwargs = { @@ -898,8 +883,7 @@ def calculate_freshness(self, source, loaded_at_field, manifest=None, FRESHNESS_MACRO_NAME, kwargs=kwargs, release=True, - manifest=manifest, - connection_name=connection_name + manifest=manifest ) # now we have a 1-row table of the maximum `loaded_at_field` value and # the current time according to the db. diff --git a/core/dbt/adapters/base/meta.py b/core/dbt/adapters/base/meta.py index b7968fe06ba..14201c93563 100644 --- a/core/dbt/adapters/base/meta.py +++ b/core/dbt/adapters/base/meta.py @@ -9,17 +9,6 @@ def available(func): arguments. """ func._is_available_ = True - func._model_name_ = True - return func - - -def available_raw(func): - """A decorator to indicate that a method on the adapter will be exposed to - the database wrapper, and the model name will be injected into the - arguments. - """ - func._is_available_ = True - func._model_name_ = False return func @@ -57,24 +46,16 @@ def __new__(mcls, name, bases, namespace, **kwargs): # dict mapping the method name to whether the model name should be # injected into the arguments. All methods in here are exposed to the # context. 
- available_model = set() - available_raw = set() + available = set() # collect base class data first for base in bases: - available_model.update(getattr(base, '_available_model_', set())) - available_raw.update(getattr(base, '_available_raw_', set())) + available.update(getattr(base, '_available_', set())) # override with local data if it exists for name, value in namespace.items(): if getattr(value, '_is_available_', False): - if getattr(value, '_model_name_', False): - available_raw.discard(name) - available_model.add(name) - else: - available_model.discard(name) - available_raw.add(name) + available.add(name) - cls._available_model_ = frozenset(available_model) - cls._available_raw_ = frozenset(available_raw) + cls._available_ = frozenset(available) return cls diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index 2cbe2dc7ac6..39ba9d070c8 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -1,5 +1,3 @@ -from dbt.logger import GLOBAL_LOGGER as logger - import dbt.exceptions from importlib import import_module from dbt.include.global_project import PACKAGES @@ -30,7 +28,7 @@ def get_relation_class_by_name(adapter_name): def load_plugin(adapter_name): try: - mod = import_module('.'+adapter_name, 'dbt.adapters') + mod = import_module('.' + adapter_name, 'dbt.adapters') except ImportError: raise dbt.exceptions.RuntimeException( "Could not find adapter type {}!".format(adapter_name) diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index a0c7bedf2ed..13bd312e876 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -31,30 +31,28 @@ def cancel(self, connection): def cancel_open(self): names = [] with self.lock: - for name, connection in self.in_use.items(): - if name == 'master': + for connection in self.thread_connections.values(): + if connection.name == 'master': continue self.cancel(connection) - names.append(name) + names.append(connection.name) return names - def add_query(self, sql, name=None, auto_begin=True, bindings=None, + def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): - connection = self.get(name) - connection_name = connection.name - + connection = self.get_thread_connection() if auto_begin and connection.transaction_open is False: - self.begin(connection_name) + self.begin() logger.debug('Using {} connection "{}".' 
- .format(self.TYPE, connection_name)) + .format(self.TYPE, connection.name)) - with self.exception_handler(sql, connection_name): + with self.exception_handler(sql): if abridge_sql_log: - logger.debug('On %s: %s....', connection_name, sql[0:512]) + logger.debug('On %s: %s....', connection.name, sql[0:512]) else: - logger.debug('On %s: %s', connection_name, sql) + logger.debug('On %s: %s', connection.name, sql) pre = time.time() cursor = connection.handle.cursor() @@ -90,9 +88,8 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) - def execute(self, sql, name=None, auto_begin=False, fetch=False): - self.get(name) - _, cursor = self.add_query(sql, name, auto_begin) + def execute(self, sql, auto_begin=False, fetch=False): + _, cursor = self.add_query(sql, auto_begin) status = self.get_status(cursor) if fetch: table = self.get_result_from_cursor(cursor) @@ -100,14 +97,14 @@ def execute(self, sql, name=None, auto_begin=False, fetch=False): table = dbt.clients.agate_helper.empty_table() return status, table - def add_begin_query(self, name): - return self.add_query('BEGIN', name, auto_begin=False) + def add_begin_query(self): + return self.add_query('BEGIN', auto_begin=False) - def add_commit_query(self, name): - return self.add_query('COMMIT', name, auto_begin=False) + def add_commit_query(self): + return self.add_query('COMMIT', auto_begin=False) - def begin(self, name): - connection = self.get(name) + def begin(self): + connection = self.get_thread_connection() if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -117,29 +114,24 @@ def begin(self, name): 'Tried to begin a new transaction on connection "{}", but ' 'it already had one open!'.format(connection.get('name'))) - self.add_begin_query(name) + self.add_begin_query() connection.transaction_open = True - self.in_use[name] = connection - return connection - def commit(self, connection): - + def commit(self): + connection = self.get_thread_connection() if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) - connection = self.get(connection.name) - if connection.transaction_open is False: raise dbt.exceptions.InternalException( 'Tried to commit transaction on connection "{}", but ' 'it does not have one open!'.format(connection.name)) logger.debug('On {}: COMMIT'.format(connection.name)) - self.add_commit_query(connection.name) + self.add_commit_query() connection.transaction_open = False - self.in_use[connection.name] = connection return connection diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 8a6ace3bef7..245b812def1 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -36,14 +36,12 @@ class SQLAdapter(BaseAdapter): - get_columns_in_relation """ @available - def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, + def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): """Add a query to the current transaction. A thin wrapper around ConnectionManager.add_query. :param str sql: The SQL query to add - :param Optional[str] model_name: The name of the connection the - transaction is on :param bool auto_begin: If set and there is no transaction in progress, begin a new one. 
:param Optional[List[object]]: An optional list of bindings for the @@ -51,8 +49,8 @@ def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, :param bool abridge_sql_log: If set, limit the raw sql logged to 512 characters """ - return self.connections.add_query(sql, model_name, auto_begin, - bindings, abridge_sql_log) + return self.connections.add_query(sql, auto_begin, bindings, + abridge_sql_log) @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -83,15 +81,15 @@ def convert_time_type(cls, agate_table, col_idx): def is_cancelable(cls): return True - def expand_column_types(self, goal, current, model_name=None): + def expand_column_types(self, goal, current): reference_columns = { c.name: c for c in - self.get_columns_in_relation(goal, model_name=model_name) + self.get_columns_in_relation(goal) } target_columns = { c.name: c for c - in self.get_columns_in_relation(current, model_name=model_name) + in self.get_columns_in_relation(current) } for column_name, reference_column in reference_columns.items(): @@ -104,14 +102,9 @@ def expand_column_types(self, goal, current, model_name=None): logger.debug("Changing col type from %s to %s in table %s", target_column.data_type, new_type, current) - self.alter_column_type(current, column_name, new_type, - model_name=model_name) + self.alter_column_type(current, column_name, new_type) - if model_name is None: - self.release_connection('master') - - def alter_column_type(self, relation, column_name, new_column_type, - model_name=None): + def alter_column_type(self, relation, column_name, new_column_type): """ 1. Create a new column (w/ temp name and correct type) 2. Copy data over to it @@ -125,11 +118,10 @@ def alter_column_type(self, relation, column_name, new_column_type, } self.execute_macro( ALTER_COLUMN_TYPE_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name + kwargs=kwargs ) - def drop_relation(self, relation, model_name=None): + def drop_relation(self, relation): if dbt.flags.USE_CACHE: self.cache.drop(relation) if relation.type is None: @@ -139,66 +131,54 @@ def drop_relation(self, relation, model_name=None): self.execute_macro( DROP_RELATION_MACRO_NAME, - kwargs={'relation': relation}, - connection_name=model_name + kwargs={'relation': relation} ) - def truncate_relation(self, relation, model_name=None): + def truncate_relation(self, relation): self.execute_macro( TRUNCATE_RELATION_MACRO_NAME, - kwargs={'relation': relation}, - connection_name=model_name + kwargs={'relation': relation} ) - def rename_relation(self, from_relation, to_relation, model_name=None): + def rename_relation(self, from_relation, to_relation): if dbt.flags.USE_CACHE: self.cache.rename(from_relation, to_relation) kwargs = {'from_relation': from_relation, 'to_relation': to_relation} self.execute_macro( RENAME_RELATION_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name + kwargs=kwargs ) - def get_columns_in_relation(self, relation, model_name=None): + def get_columns_in_relation(self, relation): return self.execute_macro( GET_COLUMNS_IN_RELATION_MACRO_NAME, - kwargs={'relation': relation}, - connection_name=model_name + kwargs={'relation': relation} ) - def create_schema(self, database, schema, model_name=None): + def create_schema(self, database, schema): logger.debug('Creating schema "%s"."%s".', database, schema) - if model_name is None: - model_name = 'master' kwargs = { 'database_name': self.quote_as_configured(database, 'database'), 'schema_name': self.quote_as_configured(schema, 'schema'), } - 
self.execute_macro(CREATE_SCHEMA_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name) - self.commit_if_has_connection(model_name) + self.execute_macro(CREATE_SCHEMA_MACRO_NAME, kwargs=kwargs) + self.commit_if_has_connection() - def drop_schema(self, database, schema, model_name=None): + def drop_schema(self, database, schema): logger.debug('Dropping schema "%s"."%s".', database, schema) kwargs = { 'database_name': self.quote_as_configured(database, 'database'), 'schema_name': self.quote_as_configured(schema, 'schema'), } self.execute_macro(DROP_SCHEMA_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name) + kwargs=kwargs) - def list_relations_without_caching(self, information_schema, schema, - model_name=None): + def list_relations_without_caching(self, information_schema, schema): kwargs = {'information_schema': information_schema, 'schema': schema} results = self.execute_macro( LIST_RELATIONS_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name, - release=True + kwargs=kwargs ) relations = [] @@ -219,19 +199,15 @@ def list_relations_without_caching(self, information_schema, schema, def quote(cls, identifier): return '"{}"'.format(identifier) - def list_schemas(self, database, model_name=None): + def list_schemas(self, database): results = self.execute_macro( LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database}, - connection_name=model_name, - # release when the model_name is none, as that implies we were - # called by node_runners.py. - release=(model_name is None) + kwargs={'database': database} ) return [row[0] for row in results] - def check_schema_exists(self, database, schema, model_name=None): + def check_schema_exists(self, database, schema): information_schema = self.Relation.create( database=database, schema=schema ).information_schema() @@ -239,7 +215,6 @@ def check_schema_exists(self, database, schema, model_name=None): kwargs = {'information_schema': information_schema, 'schema': schema} results = self.execute_macro( CHECK_SCHEMA_EXISTS_MACRO_NAME, - kwargs=kwargs, - connection_name=model_name + kwargs=kwargs ) return results[0][0] > 0 diff --git a/core/dbt/clients/system.py b/core/dbt/clients/system.py index 559069379d2..5a3b8353b95 100644 --- a/core/dbt/clients/system.py +++ b/core/dbt/clients/system.py @@ -250,7 +250,7 @@ def _handle_windows_error(exc, cwd, cmd): cls = dbt.exceptions.WorkingDirectoryError else: message = 'Unknown error: {} (errno={}: "{}")'.format( - str(exc), exc.errno, errno.errorcode.get(exc.errno, '') + str(exc), exc.errno, errno.errorcode.get(exc.errno, '') ) raise cls(cwd, cmd, message) @@ -312,7 +312,7 @@ def run_cmd(cwd, cmd, env=None): def download(url, path): response = requests.get(url) with open(path, 'wb') as handle: - for block in response.iter_content(1024*64): + for block in response.iter_content(1024 * 64): handle.write(block) @@ -382,7 +382,7 @@ def move(src, dst): except OSError: # probably different drives if os.path.isdir(src): - if _absnorm(dst+'\\').startswith(_absnorm(src+'\\')): + if _absnorm(dst + '\\').startswith(_absnorm(src + '\\')): # dst is inside src raise EnvironmentError( "Cannot move a directory '{}' into itself '{}'" diff --git a/core/dbt/compat.py b/core/dbt/compat.py index a3fe87d273f..2548476a124 100644 --- a/core/dbt/compat.py +++ b/core/dbt/compat.py @@ -1,3 +1,5 @@ +# flake8: noqa + import abc import codecs import json @@ -34,10 +36,12 @@ from SimpleHTTPServer import SimpleHTTPRequestHandler from SocketServer import TCPServer from Queue import PriorityQueue + from thread import get_ident else: 
from http.server import SimpleHTTPRequestHandler from socketserver import TCPServer from queue import PriorityQueue + from threading import get_ident def to_unicode(s): diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index c46a3fb8528..53e120a8985 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -1,5 +1,3 @@ -import copy -import functools import json import os @@ -51,27 +49,15 @@ def create(self, *args, **kwargs): class DatabaseWrapper(object): """ - Wrapper for runtime database interaction. Mostly a compatibility layer now. + Wrapper for runtime database interaction. Applies the runtime quote policy + via a relation proxy. """ - def __init__(self, connection_name, adapter): - self.connection_name = connection_name + def __init__(self, adapter): self.adapter = adapter self.Relation = RelationProxy(adapter) - def wrap(self, name): - func = getattr(self.adapter, name) - - @functools.wraps(func) - def wrapped(*args, **kwargs): - kwargs['model_name'] = self.connection_name - return func(*args, **kwargs) - - return wrapped - def __getattr__(self, name): - if name in self.adapter._available_model_: - return self.wrap(name) - elif name in self.adapter._available_raw_: + if name in self.adapter._available_: return getattr(self.adapter, name) else: raise AttributeError( @@ -88,7 +74,7 @@ def type(self): return self.adapter.type() def commit(self): - return self.adapter.commit_if_has_connection(self.connection_name) + return self.adapter.commit_if_has_connection() def _add_macro_map(context, package_name, macro_map): @@ -364,7 +350,7 @@ def get_datetime_module_context(): def generate_base(model, model_dict, config, manifest, source_config, - provider, connection_name): + provider, adapter=None): """Generate the common aspects of the config dict.""" if provider is None: raise dbt.exceptions.InternalException( @@ -377,6 +363,7 @@ def generate_base(model, model_dict, config, manifest, source_config, target['type'] = config.credentials.type target.pop('pass', None) target['name'] = target_name + adapter = get_adapter(config) context = {'env': target} @@ -384,7 +371,7 @@ def generate_base(model, model_dict, config, manifest, source_config, pre_hooks = None post_hooks = None - db_wrapper = DatabaseWrapper(connection_name, adapter) + db_wrapper = DatabaseWrapper(adapter) context = dbt.utils.merge(context, { "adapter": db_wrapper, @@ -443,7 +430,7 @@ def modify_generated_context(context, model, model_dict, config, manifest): return context -def generate_execute_macro(model, config, manifest, provider, connection_name): +def generate_execute_macro(model, config, manifest, provider): """Internally, macros can be executed like nodes, with some restrictions: - they don't have have all values available that nodes do: @@ -452,8 +439,8 @@ def generate_execute_macro(model, config, manifest, provider, connection_name): - they can't be configured with config() directives """ model_dict = model.serialize() - context = generate_base(model, model_dict, config, manifest, - None, provider, connection_name) + context = generate_base(model, model_dict, config, manifest, None, + provider) return modify_generated_context(context, model, model_dict, config, manifest) @@ -462,7 +449,7 @@ def generate_execute_macro(model, config, manifest, provider, connection_name): def generate_model(model, config, manifest, source_config, provider): model_dict = model.to_dict() context = generate_base(model, model_dict, config, manifest, - source_config, provider, model.get('name')) + 
source_config, provider) # operations (hooks) don't get a 'this' if model.resource_type != NodeType.Operation: this = get_this_relation(context['adapter'], config, model_dict) @@ -487,5 +474,4 @@ def generate(model, config, manifest, source_config=None, provider=None): or dbt.context.runtime.generate """ - return generate_model(model, config, manifest, source_config, - provider) + return generate_model(model, config, manifest, source_config, provider) diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 2a9d2a87881..a933047680d 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -1,6 +1,7 @@ import dbt.exceptions import dbt.context.common +from dbt.adapters.factory import get_adapter execute = False @@ -97,12 +98,17 @@ def get(self, name, validator=None, default=None): def generate(model, runtime_config, manifest, source_config): - return dbt.context.common.generate( - model, runtime_config, manifest, source_config, dbt.context.parser) + # during parsing, we don't have a connection, but we might need one, so we + # have to acquire it. + # In the future, it would be nice to lazily open the connection, as in some + # projects it would be possible to parse without connecting to the db + with get_adapter(runtime_config).connection_named(model.get('name')): + return dbt.context.common.generate( + model, runtime_config, manifest, source_config, dbt.context.parser + ) -def generate_macro(model, runtime_config, manifest, connection_name): +def generate_macro(model, runtime_config, manifest): return dbt.context.common.generate_execute_macro( - model, runtime_config, manifest, dbt.context.parser, - connection_name + model, runtime_config, manifest, dbt.context.parser ) diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index 40dcb77e73f..2fc7b32cddb 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -123,8 +123,7 @@ def generate(model, runtime_config, manifest): model, runtime_config, manifest, None, dbt.context.runtime) -def generate_macro(model, runtime_config, manifest, connection_name): +def generate_macro(model, runtime_config, manifest): return dbt.context.common.generate_execute_macro( - model, runtime_config, manifest, dbt.context.runtime, - connection_name + model, runtime_config, manifest, dbt.context.runtime ) diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index 09e84c7ed73..c0d2d26e76f 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -1,13 +1,10 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.exceptions import NotImplementedException, CompilationException, \ RuntimeException, InternalException, missing_materialization -from dbt.utils import get_nodes_by_tags -from dbt.node_types import NodeType, RunHookType -from dbt.adapters.factory import get_adapter +from dbt.node_types import NodeType from dbt.contracts.results import RunModelResult, collect_timing_info, \ SourceFreshnessResult, PartialResult, RemoteCompileResult, RemoteRunResult from dbt.compilation import compile_node -from dbt.utils import timestring import dbt.clients.jinja import dbt.context.runtime @@ -19,8 +16,6 @@ import dbt.schema import dbt.writer -import six -import sys import threading import time import traceback @@ -131,6 +126,7 @@ def safe_run(self, manifest): result = None try: + self.adapter.acquire_connection(self.node.get('name')) with collect_timing_info('compile') as timing_info: # if we fail here, we still have a compiled node to return # this has the benefit of showing 
a build path for the errant @@ -160,9 +156,10 @@ def safe_run(self, manifest): prefix = 'Internal error executing {}'.format(build_path) error = "{prefix}\n{error}\n\n{note}".format( - prefix=dbt.ui.printer.red(prefix), - error=str(e).strip(), - note=INTERNAL_ERROR_STRING) + prefix=dbt.ui.printer.red(prefix), + error=str(e).strip(), + note=INTERNAL_ERROR_STRING + ) logger.debug(error) error = dbt.compat.to_string(e) @@ -171,11 +168,13 @@ def safe_run(self, manifest): if node_description is None: node_description = self.node.unique_id prefix = "Unhandled error while executing {description}".format( - description=node_description) + description=node_description + ) error = "{prefix}\n{error}".format( - prefix=dbt.ui.printer.red(prefix), - error=str(e).strip()) + prefix=dbt.ui.printer.red(prefix), + error=str(e).strip() + ) logger.error(error) logger.debug('', exc_info=True) @@ -202,13 +201,12 @@ def _safe_release_connection(self): """Try to release a connection. If an exception is hit, log and return the error string. """ - node_name = self.node.name try: - self.adapter.release_connection(node_name) + self.adapter.release_connection() except Exception as exc: logger.debug( 'Error releasing connection for node {}: {!s}\n{}' - .format(node_name, exc, traceback.format_exc()) + .format(self.node.name, exc, traceback.format_exc()) ) return dbt.compat.to_string(exc) @@ -372,7 +370,8 @@ def _calculate_status(self, target_freshness, freshness): continue target = target_freshness[fullkey] - kwargs = {target['period']+'s': target['count']} + kwname = target['period'] + 's' + kwargs = {kwname: target['count']} if freshness > timedelta(**kwargs).total_seconds(): return key return 'pass' @@ -401,12 +400,12 @@ def from_run_result(self, result, start_time, timing_info): def execute(self, compiled_node, manifest): relation = self.adapter.Relation.create_from_source(compiled_node) # given a Source, calculate its fresnhess. 
- freshness = self.adapter.calculate_freshness( - relation, - compiled_node.loaded_at_field, - manifest=manifest, - connection_name=compiled_node.unique_id - ) + with self.adapter.connection_named(compiled_node.unique_id): + freshness = self.adapter.calculate_freshness( + relation, + compiled_node.loaded_at_field, + manifest=manifest + ) status = self._calculate_status( compiled_node.freshness, freshness['age'] @@ -447,7 +446,6 @@ def print_start_line(self): def execute_test(self, test): res, table = self.adapter.execute( test.wrapped_sql, - model_name=test.name, auto_begin=True, fetch=True) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index d21b48cb1ea..903d977864f 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -13,7 +13,6 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.contracts.graph.parsed import ParsedNode from dbt.parser.source_config import SourceConfig -from dbt.node_types import NodeType class BaseParser(object): @@ -88,7 +87,7 @@ def get_schema(_): else: root_context = dbt.context.parser.generate_macro( get_schema_macro, self.root_project_config, - self.macro_manifest, 'generate_schema_name' + self.macro_manifest ) get_schema = get_schema_macro.generator(root_context) @@ -159,10 +158,6 @@ def _render_with_context(self, parsed_node, config): parsed_node.raw_sql, context, parsed_node.to_shallow_dict(), capture_macros=True) - # Clean up any open conns opened by adapter functions that hit the db - db_wrapper = context['adapter'] - db_wrapper.adapter.release_connection(parsed_node.name) - def _update_parsed_node_info(self, parsed_node, config): """Given the SourceConfig used for parsing and the parsed node, generate and set the true values to use, overriding the temporary parse diff --git a/core/dbt/ssh_forward.py b/core/dbt/ssh_forward.py deleted file mode 100644 index 0ff32097998..00000000000 --- a/core/dbt/ssh_forward.py +++ /dev/null @@ -1,10 +0,0 @@ -import logging - -# modules are only imported once -- make sure that we don't have > 1 -# because subsequent tunnels will block waiting to acquire the port - -server = None - - -def get_or_create_tunnel(host, port, user, remote_host, remote_port, timeout): - pass diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 4fe9611beed..a592a941af1 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -5,9 +5,9 @@ from dbt.config import RuntimeConfig, Project from dbt.config.profile import read_profile, PROFILES_DIR -from dbt import flags from dbt import tracking from dbt.logger import GLOBAL_LOGGER as logger +from dbt.utils import to_string import dbt.exceptions diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index 3aae70d67da..1514a5f7ebc 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -2,7 +2,6 @@ from dbt.adapters.factory import get_adapter from dbt.compilation import compile_manifest -from dbt.exceptions import RuntimeException from dbt.loader import load_all_projects, GraphLoader from dbt.node_runners import CompileRunner, RPCCompileRunner from dbt.node_types import NodeType @@ -10,7 +9,7 @@ from dbt.parser.util import ParserUtils import dbt.ui.printer -from dbt.task.runnable import ManifestTask, GraphRunnableTask, RemoteCallable +from dbt.task.runnable import GraphRunnableTask, RemoteCallable class CompileTask(GraphRunnableTask): @@ -70,7 +69,7 @@ def handle_request(self, name, sql): 'name': name, 'root_path': request_path, 'resource_type': NodeType.RPCCall, - 'path': name+'.sql', + 'path': name + '.sql', 
'original_file_path': 'from remote system', 'package_name': self.config.project_name, 'raw_sql': sql, diff --git a/core/dbt/task/generate.py b/core/dbt/task/generate.py index 58b2238a02a..db7c91504f6 100644 --- a/core/dbt/task/generate.py +++ b/core/dbt/task/generate.py @@ -1,4 +1,3 @@ -import json import os import shutil @@ -202,11 +201,12 @@ def run(self): DOCS_INDEX_FILE_PATH, os.path.join(self.config.target_path, 'index.html')) - manifest = self._get_manifest() adapter = get_adapter(self.config) + with adapter.connection_named('generate_catalog'): + manifest = self._get_manifest() - dbt.ui.printer.print_timestamped_line("Building catalog") - results = adapter.get_catalog(manifest) + dbt.ui.printer.print_timestamped_line("Building catalog") + results = adapter.get_catalog(manifest) results = [ dict(zip(results.column_names, row)) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 3ab251aee1c..574519b9e5d 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -29,34 +29,29 @@ def run_hooks(self, adapter, hook_type, extra_context): ordered_hooks = sorted(hooks, key=lambda h: h.get('index', len(hooks))) - for i, hook in enumerate(ordered_hooks): - model_name = hook.get('name') - - # This will clear out an open transaction if there is one. + with adapter.connection_named(hook_type): # on-run-* hooks should run outside of a transaction. This happens # b/c psycopg2 automatically begins a transaction when a connection - # is created. TODO : Move transaction logic out of here, and - # implement a for-loop over these sql statements in jinja-land. - # Also, consider configuring psycopg2 (and other adapters?) to - # ensure that a transaction is only created if dbt initiates it. - adapter.clear_transaction(model_name) - compiled = compile_node(adapter, self.config, hook, self.manifest, - extra_context) - statement = compiled.wrapped_sql + # is created. 
+ adapter.clear_transaction() + + for i, hook in enumerate(ordered_hooks): + compiled = compile_node(adapter, self.config, hook, + self.manifest, extra_context) + statement = compiled.wrapped_sql - hook_index = hook.get('index', len(hooks)) - hook_dict = get_hook_dict(statement, index=hook_index) + hook_index = hook.get('index', len(hooks)) + hook_dict = get_hook_dict(statement, index=hook_index) - if dbt.flags.STRICT_MODE: - Hook(**hook_dict) + if dbt.flags.STRICT_MODE: + Hook(**hook_dict) - sql = hook_dict.get('sql', '') + sql = hook_dict.get('sql', '') - if len(sql.strip()) > 0: - adapter.execute(sql, model_name=model_name, auto_begin=False, - fetch=False) + if len(sql.strip()) > 0: + adapter.execute(sql, auto_begin=False, fetch=False) - adapter.release_connection(model_name) + adapter.release_connection() def safe_run_hooks(self, adapter, hook_type, extra_context): try: @@ -82,8 +77,11 @@ def print_results_line(cls, results, execution_time): .format(stat_line=stat_line, execution=execution)) def before_run(self, adapter, selected_uids): - self.populate_adapter_cache(adapter) + with adapter.connection_named('master'): + self.populate_adapter_cache(adapter) self.safe_run_hooks(adapter, RunHookType.Start, {}) + with adapter.connection_named('master'): + self.populate_adapter_cache(adapter) self.create_schemas(adapter, selected_uids) def after_run(self, adapter, results): diff --git a/core/dbt/utils.py b/core/dbt/utils.py index c5fc8977e98..5139cee9dc6 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -475,8 +475,10 @@ def translate_aliases(kwargs, aliases): key_names = ', '.join("{}".format(k) for k in kwargs if aliases.get(k) == canonical_key) - raise AliasException('Got duplicate keys: ({}) all map to "{}"' - .format(key_names, canonical_key)) + raise dbt.exceptions.AliasException( + 'Got duplicate keys: ({}) all map to "{}"' + .format(key_names, canonical_key) + ) result[canonical_key] = value diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index fbe0448adff..84ed923d257 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -1,4 +1,3 @@ -import abc from contextlib import contextmanager import google.auth @@ -10,7 +9,6 @@ import dbt.clients.agate_helper import dbt.exceptions from dbt.adapters.base import BaseConnectionManager, Credentials -from dbt.compat import abstractclassmethod from dbt.logger import GLOBAL_LOGGER as logger @@ -77,8 +75,11 @@ def handle_error(cls, error, message, sql): raise dbt.exceptions.DatabaseException(error_msg) + def clear_transaction(self): + pass + @contextmanager - def exception_handler(self, sql, connection_name='master'): + def exception_handler(self, sql): try: yield @@ -104,10 +105,10 @@ def close(cls, connection): return connection - def begin(self, name): + def begin(self): pass - def commit(self, connection): + def commit(self): pass @classmethod @@ -178,25 +179,25 @@ def get_table_from_response(cls, resp): rows = [dict(row.items()) for row in resp] return dbt.clients.agate_helper.table_from_data(rows, column_names) - def raw_execute(self, sql, name=None, fetch=False): - conn = self.get(name) + def raw_execute(self, sql, fetch=False): + conn = self.get_thread_connection() client = conn.handle - logger.debug('On %s: %s', name, sql) + logger.debug('On %s: %s', conn.name, sql) job_config = google.cloud.bigquery.QueryJobConfig() job_config.use_legacy_sql = False query_job = client.query(sql, 
job_config) # this blocks until the query has completed - with self.exception_handler(sql, conn.name): + with self.exception_handler(sql): iterator = query_job.result() return query_job, iterator - def execute(self, sql, name=None, auto_begin=False, fetch=None): + def execute(self, sql, auto_begin=False, fetch=None): # auto_begin is ignored on bigquery, and only included for consistency - _, iterator = self.raw_execute(sql, name=name, fetch=fetch) + _, iterator = self.raw_execute(sql, fetch=fetch) if fetch: res = self.get_table_from_response(iterator) @@ -207,32 +208,31 @@ def execute(self, sql, name=None, auto_begin=False, fetch=None): status = 'OK' return status, res - def create_bigquery_table(self, database, schema, table_name, conn_name, - callback, sql): + def create_bigquery_table(self, database, schema, table_name, callback, + sql): """Create a bigquery table. The caller must supply a callback that takes one argument, a `google.cloud.bigquery.Table`, and mutates it. """ - conn = self.get(conn_name) + conn = self.get_thread_connection() client = conn.handle view_ref = self.table_ref(database, schema, table_name, conn) view = google.cloud.bigquery.Table(view_ref) callback(view) - with self.exception_handler(sql, conn.name): + with self.exception_handler(sql): client.create_table(view) - def create_view(self, database, schema, table_name, conn_name, sql): + def create_view(self, database, schema, table_name, sql): def callback(table): table.view_query = sql table.view_use_legacy_sql = False - self.create_bigquery_table(database, schema, table_name, conn_name, - callback, sql) + self.create_bigquery_table(database, schema, table_name, callback, sql) - def create_table(self, database, schema, table_name, conn_name, sql): - conn = self.get(conn_name) + def create_table(self, database, schema, table_name, sql): + conn = self.get_thread_connection() client = conn.handle table_ref = self.table_ref(database, schema, table_name, conn) @@ -243,16 +243,15 @@ def create_table(self, database, schema, table_name, conn_name, sql): query_job = client.query(sql, job_config=job_config) # this waits for the job to complete - with self.exception_handler(sql, conn_name): + with self.exception_handler(sql): query_job.result(timeout=self.get_timeout(conn)) - def create_date_partitioned_table(self, database, schema, table_name, - conn_name): + def create_date_partitioned_table(self, database, schema, table_name): def callback(table): table.partitioning_type = 'DAY' - self.create_bigquery_table(database, schema, table_name, conn_name, - callback, 'CREATE DAY PARTITIONED TABLE') + self.create_bigquery_table(database, schema, table_name, callback, + 'CREATE DAY PARTITIONED TABLE') @staticmethod def dataset(database, schema, conn): @@ -263,24 +262,24 @@ def table_ref(self, database, schema, table_name, conn): dataset = self.dataset(database, schema, conn) return dataset.table(table_name) - def get_bq_table(self, database, schema, identifier, conn_name=None): + def get_bq_table(self, database, schema, identifier): """Get a bigquery table for a schema/model.""" - conn = self.get(conn_name) + conn = self.get_thread_connection() table_ref = self.table_ref(database, schema, identifier, conn) return conn.handle.get_table(table_ref) - def drop_dataset(self, database, schema, conn_name=None): - conn = self.get(conn_name) + def drop_dataset(self, database, schema): + conn = self.get_thread_connection() dataset = self.dataset(database, schema, conn) client = conn.handle - with self.exception_handler('drop dataset', 
conn.name): + with self.exception_handler('drop dataset'): for table in client.list_tables(dataset): client.delete_table(table.reference) client.delete_dataset(dataset) - def create_dataset(self, database, schema, conn_name=None): - conn = self.get(conn_name) + def create_dataset(self, database, schema): + conn = self.get_thread_connection() client = conn.handle dataset = self.dataset(database, schema, conn) @@ -291,5 +290,5 @@ def create_dataset(self, database, schema, conn_name=None): except google.api_core.exceptions.NotFound: pass - with self.exception_handler('create dataset', conn.name): + with self.exception_handler('create dataset'): client.create_dataset(dataset) diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index ad5fbf724bf..bb4c45f7f59 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -65,13 +65,12 @@ def date_function(cls): def is_cancelable(cls): return False - def drop_relation(self, relation, model_name=None): - is_cached = self._schema_is_cached(relation.database, relation.schema, - model_name) + def drop_relation(self, relation): + is_cached = self._schema_is_cached(relation.database, relation.schema) if is_cached: self.cache.drop(relation) - conn = self.connections.get(model_name) + conn = self.connections.get_thread_connection() client = conn.handle dataset = self.connections.dataset(relation.database, relation.schema, @@ -79,32 +78,31 @@ def drop_relation(self, relation, model_name=None): relation_object = dataset.table(relation.identifier) client.delete_table(relation_object) - def truncate_relation(self, relation, model_name=None): + def truncate_relation(self, relation): raise dbt.exceptions.NotImplementedException( '`truncate` is not implemented for this adapter!' ) - def rename_relation(self, from_relation, to_relation, model_name=None): + def rename_relation(self, from_relation, to_relation): raise dbt.exceptions.NotImplementedException( '`rename_relation` is not implemented for this adapter!' 
) - def list_schemas(self, database, model_name=None): - conn = self.connections.get(model_name) + def list_schemas(self, database): + conn = self.connections.get_thread_connection() client = conn.handle - with self.connections.exception_handler('list dataset', conn.name): + with self.connections.exception_handler('list dataset'): all_datasets = client.list_datasets(project=database, include_all=True) return [ds.dataset_id for ds in all_datasets] - def get_columns_in_relation(self, relation, model_name=None): + def get_columns_in_relation(self, relation): try: table = self.connections.get_bq_table( database=relation.database, schema=relation.schema, - identifier=relation.table_name, - conn_name=model_name + identifier=relation.table_name ) return self._get_dbt_columns_from_bq_table(table) @@ -112,13 +110,12 @@ def get_columns_in_relation(self, relation, model_name=None): logger.debug("get_columns_in_relation error: {}".format(e)) return [] - def expand_column_types(self, goal, current, model_name=None): + def expand_column_types(self, goal, current): # This is a no-op on BigQuery pass - def list_relations_without_caching(self, information_schema, schema, - model_name=None): - connection = self.connections.get(model_name) + def list_relations_without_caching(self, information_schema, schema): + connection = self.connections.get_thread_connection() client = connection.handle bigquery_dataset = self.connections.dataset( @@ -144,15 +141,14 @@ def list_relations_without_caching(self, information_schema, schema, except google.api_core.exceptions.NotFound as e: return [] - def get_relation(self, database, schema, identifier, model_name=None): - if self._schema_is_cached(database, schema, model_name): + def get_relation(self, database, schema, identifier): + if self._schema_is_cached(database, schema): # if it's in the cache, use the parent's model of going through # the relations cache and picking out the relation return super(BigQueryAdapter, self).get_relation( database=database, schema=schema, - identifier=identifier, - model_name=model_name + identifier=identifier ) try: @@ -161,16 +157,16 @@ def get_relation(self, database, schema, identifier, model_name=None): table = None return self._bq_table_to_relation(table) - def create_schema(self, database, schema, model_name=None): + def create_schema(self, database, schema): logger.debug('Creating schema "%s.%s".', database, schema) - self.connections.create_dataset(database, schema, model_name) + self.connections.create_dataset(database, schema) - def drop_schema(self, database, schema, model_name=None): + def drop_schema(self, database, schema): logger.debug('Dropping schema "%s.%s".', database, schema) - if not self.check_schema_exists(database, schema, model_name): + if not self.check_schema_exists(database, schema): return - self.connections.drop_dataset(database, schema, model_name) + self.connections.drop_dataset(database, schema) @classmethod def quote(cls, identifier): @@ -232,16 +228,14 @@ def _agate_to_schema(self, agate_table, column_override): def _materialize_as_view(self, model): model_database = model.get('database') model_schema = model.get('schema') - model_name = model.get('name') model_alias = model.get('alias') model_sql = model.get('injected_sql') - logger.debug("Model SQL ({}):\n{}".format(model_name, model_sql)) + logger.debug("Model SQL ({}):\n{}".format(model_alias, model_sql)) self.connections.create_view( database=model_database, schema=model_schema, table_name=model_alias, - conn_name=model_name, sql=model_sql ) return 
"CREATE VIEW" @@ -249,7 +243,6 @@ def _materialize_as_view(self, model): def _materialize_as_table(self, model, model_sql, decorator=None): model_database = model.get('database') model_schema = model.get('schema') - model_name = model.get('name') model_alias = model.get('alias') if decorator is None: @@ -261,7 +254,6 @@ def _materialize_as_table(self, model, model_sql, decorator=None): self.connections.create_table( database=model_database, schema=model_schema, - conn_name=model_name, table_name=table_name, sql=model_sql ) @@ -307,10 +299,10 @@ def warning_on_hooks(hook_type): dbt.ui.printer.COLOR_FG_YELLOW) @available - def add_query(self, sql, model_name=None, auto_begin=True, - bindings=None, abridge_sql_log=False): - if model_name in ['on-run-start', 'on-run-end']: - self.warning_on_hooks(model_name) + def add_query(self, sql, auto_begin=True, bindings=None, + abridge_sql_log=False): + if self.nice_connection_name() in ['on-run-start', 'on-run-end']: + self.warning_on_hooks(self.nice_connection_name()) else: raise dbt.exceptions.NotImplementedException( '`add_query` is not implemented for this adapter!') @@ -319,24 +311,24 @@ def add_query(self, sql, model_name=None, auto_begin=True, # Special bigquery adapter methods ### @available - def make_date_partitioned_table(self, relation, model_name=None): + def make_date_partitioned_table(self, relation): return self.connections.create_date_partitioned_table( database=relation.database, schema=relation.schema, - table_name=relation.identifier, - conn_name=model_name + table_name=relation.identifier ) @available def execute_model(self, model, materialization, sql_override=None, - decorator=None, model_name=None): + decorator=None): if sql_override is None: sql_override = model.get('injected_sql') if flags.STRICT_MODE: - connection = self.connections.get(model.get('name')) + connection = self.connections.get_thread_connection() assert isinstance(connection, Connection) + assert(connection.name == model.get('name')) if materialization == 'view': res = self._materialize_as_view(model) @@ -349,10 +341,10 @@ def execute_model(self, model, materialization, sql_override=None, return res @available - def create_temporary_table(self, sql, model_name=None, **kwargs): + def create_temporary_table(self, sql, **kwargs): # BQ queries always return a temp table with their results - query_job, _ = self.connections.raw_execute(sql, model_name) + query_job, _ = self.connections.raw_execute(sql) bq_table = query_job.destination return self.Relation.create( @@ -366,12 +358,12 @@ def create_temporary_table(self, sql, model_name=None, **kwargs): type=BigQueryRelation.Table) @available - def alter_table_add_columns(self, relation, columns, model_name=None): + def alter_table_add_columns(self, relation, columns): logger.debug('Adding columns ({}) to table {}".'.format( columns, relation)) - conn = self.connections.get(model_name) + conn = self.connections.get_thread_connection() client = conn.handle table_ref = self.connections.table_ref(relation.database, @@ -387,9 +379,9 @@ def alter_table_add_columns(self, relation, columns, model_name=None): @available def load_dataframe(self, database, schema, table_name, agate_table, - column_override, model_name=None): + column_override): bq_schema = self._agate_to_schema(agate_table, column_override) - conn = self.connections.get(model_name) + conn = self.connections.get_thread_connection() client = conn.handle table = self.connections.table_ref(database, schema, table_name, conn) @@ -403,7 +395,7 @@ def load_dataframe(self, 
database, schema, table_name, agate_table, job_config=load_config) timeout = self.connections.get_timeout(conn) - with self.connections.exception_handler("LOAD TABLE", conn.name): + with self.connections.exception_handler("LOAD TABLE"): self.poll_until_job_completes(job, timeout) ### @@ -474,7 +466,7 @@ def _get_stats_columns(cls, table, relation_type): return zip(column_names, column_values) def get_catalog(self, manifest): - connection = self.connections.get('catalog') + connection = self.connections.get_thread_connection() client = connection.handle schemas = manifest.get_used_schemas() diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index 664d79ff541..6ba185ada92 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -61,7 +61,7 @@ class PostgresConnectionManager(SQLConnectionManager): TYPE = 'postgres' @contextmanager - def exception_handler(self, sql, connection_name='master'): + def exception_handler(self, sql): try: yield @@ -70,7 +70,7 @@ def exception_handler(self, sql, connection_name='master'): try: # attempt to release the connection - self.release(connection_name) + self.release() except psycopg2.Error: logger.debug("Failed to release connection!") pass @@ -81,7 +81,7 @@ def exception_handler(self, sql, connection_name='master'): except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - self.release(connection_name) + self.release() raise dbt.exceptions.RuntimeException(e) @classmethod @@ -90,7 +90,6 @@ def open(cls, connection): logger.debug('Connection is already open, skipping open.') return connection - base_credentials = connection.credentials credentials = cls.get_credentials(connection.credentials.incorporate()) kwargs = {} keepalives_idle = credentials.get('keepalives_idle', diff --git a/plugins/postgres/dbt/adapters/postgres/impl.py b/plugins/postgres/dbt/adapters/postgres/impl.py index 87487c7f791..88be130cabb 100644 --- a/plugins/postgres/dbt/adapters/postgres/impl.py +++ b/plugins/postgres/dbt/adapters/postgres/impl.py @@ -1,15 +1,8 @@ -import psycopg2 - -import time - -from dbt.adapters.base.meta import available_raw +from dbt.adapters.base.meta import available from dbt.adapters.sql import SQLAdapter from dbt.adapters.postgres import PostgresConnectionManager import dbt.compat import dbt.exceptions -import agate - -from dbt.logger import GLOBAL_LOGGER as logger # note that this isn't an adapter macro, so just a single underscore @@ -23,7 +16,7 @@ class PostgresAdapter(SQLAdapter): def date_function(cls): return 'now()' - @available_raw + @available def verify_database(self, database): database = database.strip('"') expected = self.config.credentials.database @@ -75,10 +68,7 @@ def _link_cached_relations(self, manifest): self.verify_database(db) schemas.add(schema) - try: - self._link_cached_database_relations(schemas) - finally: - self.release_connection(GET_RELATIONS_MACRO_NAME) + self._link_cached_database_relations(schemas) def _relations_cache_for_schemas(self, manifest): super(PostgresAdapter, self)._relations_cache_for_schemas(manifest) diff --git a/plugins/redshift/dbt/adapters/redshift/connections.py b/plugins/redshift/dbt/adapters/redshift/connections.py index fe85a0e1858..9ba5dcc792a 100644 --- a/plugins/redshift/dbt/adapters/redshift/connections.py +++ b/plugins/redshift/dbt/adapters/redshift/connections.py @@ -95,16 +95,16 @@ def 
fresh_transaction(self, name=None): """ with drop_lock: - connection = self.get(name) + connection = self.get_thread_connection() if connection.transaction_open: - self.commit(connection) + self.commit() - self.begin(connection.name) + self.begin() yield - self.commit(connection) - self.begin(connection.name) + self.commit() + self.begin() @classmethod def fetch_cluster_credentials(cls, db_user, db_name, cluster_id, diff --git a/plugins/redshift/dbt/adapters/redshift/impl.py b/plugins/redshift/dbt/adapters/redshift/impl.py index 08f0dcff0e4..50934fba862 100644 --- a/plugins/redshift/dbt/adapters/redshift/impl.py +++ b/plugins/redshift/dbt/adapters/redshift/impl.py @@ -1,7 +1,6 @@ from dbt.adapters.postgres import PostgresAdapter from dbt.adapters.redshift import RedshiftConnectionManager from dbt.logger import GLOBAL_LOGGER as logger # noqa -import dbt.exceptions class RedshiftAdapter(PostgresAdapter): @@ -13,7 +12,7 @@ class RedshiftAdapter(PostgresAdapter): def date_function(cls): return 'getdate()' - def drop_relation(self, relation, model_name=None): + def drop_relation(self, relation): """ In Redshift, DROP TABLE ... CASCADE should not be used inside a transaction. Redshift doesn't prevent the CASCADE @@ -28,9 +27,9 @@ def drop_relation(self, relation, model_name=None): https://docs.aws.amazon.com/redshift/latest/dg/r_DROP_TABLE.html """ - with self.connections.fresh_transaction(model_name): + with self.connections.fresh_transaction(): parent = super(RedshiftAdapter, self) - return parent.drop_relation(relation, model_name) + return parent.drop_relation(relation) @classmethod def convert_text_type(cls, agate_table, col_idx): diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index c7f117a060c..a6b9a67e2a7 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -72,7 +72,7 @@ class SnowflakeConnectionManager(SQLConnectionManager): TYPE = 'snowflake' @contextmanager - def exception_handler(self, sql, connection_name='master'): + def exception_handler(self, sql): try: yield except snowflake.connector.errors.ProgrammingError as e: @@ -83,7 +83,7 @@ def exception_handler(self, sql, connection_name='master'): if 'Empty SQL statement' in msg: logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: - self.release(connection_name) + self.release() raise dbt.exceptions.FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' @@ -91,12 +91,12 @@ def exception_handler(self, sql, connection_name='master'): 'Please double check your profile and try again.') .format(msg)) else: - self.release(connection_name) + self.release() raise dbt.exceptions.DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - self.release(connection_name) + self.release() raise dbt.exceptions.RuntimeException(e.msg) @classmethod @@ -141,8 +141,6 @@ def open(cls, connection): raise dbt.exceptions.FailedToConnectException(str(e)) - return connection - def cancel(self, connection): handle = connection.handle sid = handle.session_id @@ -193,7 +191,7 @@ def _get_private_key(cls, private_key_path, private_key_passphrase): format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()) - def add_query(self, sql, model_name=None, auto_begin=True, + 
def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): connection = None @@ -219,7 +217,7 @@ def add_query(self, sql, model_name=None, auto_begin=True, parent = super(SnowflakeConnectionManager, self) connection, cursor = parent.add_query( - individual_query, model_name, auto_begin, + individual_query, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log ) @@ -229,11 +227,14 @@ def add_query(self, sql, model_name=None, auto_begin=True, "Tried to run an empty query on model '{}'. If you are " "conditionally running\nsql, eg. in a model hook, make " "sure your `else` clause contains valid sql!\n\n" - "Provided SQL:\n{}".format(model_name, sql)) + "Provided SQL:\n{}" + .format(self.nice_connection_name(), sql) + ) return connection, cursor - def _rollback_handle(self, connection): + @classmethod + def _rollback_handle(cls, connection): """On snowflake, rolling back the handle of an aborted session raises an exception. """ diff --git a/test/integration/007_graph_selection_tests/test_graph_selection.py b/test/integration/007_graph_selection_tests/test_graph_selection.py index 7f6bfa87d73..1e2890acd14 100644 --- a/test/integration/007_graph_selection_tests/test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_graph_selection.py @@ -12,20 +12,19 @@ def models(self): return "test/integration/007_graph_selection_tests/models" def assert_correct_schemas(self): - exists = self.adapter.check_schema_exists( - self.default_database, - self.unique_schema(), - '__test' - ) - self.assertTrue(exists) - - schema = self.unique_schema()+'_and_then' - exists = self.adapter.check_schema_exists( - self.default_database, - schema, - '__test' - ) - self.assertFalse(exists) + with self.test_connection(): + exists = self.adapter.check_schema_exists( + self.default_database, + self.unique_schema() + ) + self.assertTrue(exists) + + schema = self.unique_schema()+'_and_then' + exists = self.adapter.check_schema_exists( + self.default_database, + schema + ) + self.assertFalse(exists) @attr(type='postgres') def test__postgres__specific_model(self): diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index cfb9876a4c7..afb0b710125 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -1,6 +1,7 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest import threading +from dbt.adapters.factory import get_adapter class BaseTestConcurrentTransaction(DBTIntegrationTest): @@ -10,6 +11,10 @@ def reset(self): 'model_1': 'wait', } + def setUp(self): + super(BaseTestConcurrentTransaction, self).setUp() + self.reset() + @property def schema(self): return "concurrent_transaction_032" @@ -26,7 +31,8 @@ def project_config(self): def run_select_and_check(self, rel, sql): connection_name = '__test_{}'.format(id(threading.current_thread())) try: - res = self.run_sql(sql, fetch='one', connection_name=connection_name) + with get_adapter(self.config).connection_named(connection_name) as conn: + res = self.run_sql_common(self.transform_sql(sql), 'one', conn) # The result is the output of f_sleep(), which is True if res[0] == True: @@ -54,7 +60,7 @@ def async_select(self, rel, sleep=10): sleep=sleep, rel=rel) - thread = threading.Thread(target=lambda: self.run_select_and_check(rel, query)) + thread = 
threading.Thread(target=self.run_select_and_check, args=(rel, query)) thread.start() return thread diff --git a/test/integration/037_external_reference_test/test_external_reference.py b/test/integration/037_external_reference_test/test_external_reference.py index bd754ae169c..ba6bf73bdb6 100644 --- a/test/integration/037_external_reference_test/test_external_reference.py +++ b/test/integration/037_external_reference_test/test_external_reference.py @@ -29,8 +29,8 @@ def tearDown(self): # This has to happen before we drop the external schema, because # otherwise postgres hangs forever. self._drop_schemas() - self.adapter.drop_schema(self.default_database, self.external_schema, - model_name='__test') + with self.test_connection(): + self.adapter.drop_schema(self.default_database, self.external_schema) super(TestExternalReference, self).tearDown() @use_profile('postgres') @@ -39,6 +39,7 @@ def test__postgres__external_reference(self): # running it again should succeed self.assertEquals(len(self.run_dbt()), 1) + # The opposite of the test above -- check that external relations that # depend on a dbt model do not create issues with caching class TestExternalDependency(DBTIntegrationTest): @@ -54,8 +55,8 @@ def tearDown(self): # This has to happen before we drop the external schema, because # otherwise postgres hangs forever. self._drop_schemas() - self.adapter.drop_schema(self.default_database, self.external_schema, - model_name='__test') + with self.test_connection(): + self.adapter.drop_schema(self.default_database, self.external_schema) super(TestExternalDependency, self).tearDown() @use_profile('postgres') diff --git a/test/integration/base.py b/test/integration/base.py index 137e04afbae..369bc09c387 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -1,6 +1,7 @@ import unittest import dbt.main as dbt -import os, shutil +import os +import shutil import yaml import random import time @@ -9,13 +10,17 @@ from functools import wraps from nose.plugins.attrib import attr +from mock import patch import dbt.flags as flags from dbt.adapters.factory import get_adapter, reset_adapters from dbt.clients.jinja import template_cache from dbt.config import RuntimeConfig -from dbt.compat import basestring, suppress_warnings +from dbt.compat import basestring +from dbt.context import common + +from contextlib import contextmanager from dbt.logger import GLOBAL_LOGGER as logger import logging @@ -372,7 +377,7 @@ def _get_schema_fqn(self, database, schema): def _create_schema_named(self, database, schema): if self.adapter_type == 'bigquery': - self.adapter.create_schema(database, schema, '__test') + self.adapter.create_schema(database, schema) else: schema_fqn = self._get_schema_fqn(database, schema) self.run_sql(self.CREATE_SCHEMA_STATEMENT.format(schema_fqn)) @@ -381,7 +386,7 @@ def _create_schema_named(self, database, schema): def _drop_schema_named(self, database, schema): if self.adapter_type == 'bigquery' or self.adapter_type == 'presto': self.adapter.drop_schema( - database, schema, '__test' + database, schema ) else: schema_fqn = self._get_schema_fqn(database, schema) @@ -389,9 +394,10 @@ def _drop_schema_named(self, database, schema): def _create_schemas(self): schema = self.unique_schema() - self._create_schema_named(self.default_database, schema) - if self.setup_alternate_db and self.adapter_type == 'snowflake': - self._create_schema_named(self.alternative_database, schema) + with self.adapter.connection_named('__test'): + self._create_schema_named(self.default_database, schema) 
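(A note on the integration-test changes in this file: the explicit '__test' connection-name argument disappears from adapter calls because, under the per-thread model, every nested call looks its connection up from the current thread. A usage sketch, where `adapter`, `database`, `schema`, and `relation` are placeholders for objects the test already holds:)

    # placeholders: adapter, database, schema, relation come from the test case
    with adapter.connection_named('__test'):
        # nested calls reuse the connection bound to this thread, so the old
        # model_name/connection_name arguments are no longer needed
        adapter.create_schema(database, schema)
        columns = adapter.get_columns_in_relation(relation)
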
+ if self.setup_alternate_db and self.adapter_type == 'snowflake': + self._create_schema_named(self.alternative_database, schema) def _drop_schemas_adapter(self): schema = self.unique_schema() @@ -421,10 +427,11 @@ def _drop_schemas_sql(self): self._created_schemas.clear() def _drop_schemas(self): - if self.adapter_type == 'bigquery' or self.adapter_type == 'presto': - self._drop_schemas_adapter() - else: - self._drop_schemas_sql() + with self.adapter.connection_named('__test'): + if self.adapter_type == 'bigquery' or self.adapter_type == 'presto': + self._drop_schemas_adapter() + else: + self._drop_schemas_sql() @property def project_config(self): @@ -497,8 +504,7 @@ def run_sql_bigquery(self, sql, fetch): else: return list(res) - def run_sql_presto(self, sql, fetch, connection_name=None): - conn = self.adapter.acquire_connection(connection_name) + def run_sql_presto(self, sql, fetch, conn): cursor = conn.handle.cursor() try: cursor.execute(sql) @@ -519,6 +525,24 @@ def run_sql_presto(self, sql, fetch, connection_name=None): conn.handle.commit() conn.transaction_open = False + def run_sql_common(self, sql, fetch, conn): + with conn.handle.cursor() as cursor: + try: + cursor.execute(sql) + conn.handle.commit() + if fetch == 'one': + return cursor.fetchone() + elif fetch == 'all': + return cursor.fetchall() + else: + return + except BaseException as e: + conn.handle.rollback() + print(sql) + print(e) + raise e + finally: + conn.transaction_open = False def run_sql(self, query, fetch='None', kwargs=None, connection_name=None): if connection_name is None: @@ -528,30 +552,15 @@ def run_sql(self, query, fetch='None', kwargs=None, connection_name=None): return sql = self.transform_sql(query, kwargs=kwargs) - if self.adapter_type == 'bigquery': - return self.run_sql_bigquery(sql, fetch) - elif self.adapter_type == 'presto': - return self.run_sql_presto(sql, fetch, connection_name) - - conn = self.adapter.acquire_connection(connection_name) - with conn.handle.cursor() as cursor: - logger.debug('test connection "{}" executing: {}'.format(connection_name, sql)) - try: - cursor.execute(sql) - conn.handle.commit() - if fetch == 'one': - return cursor.fetchone() - elif fetch == 'all': - return cursor.fetchall() - else: - return - except BaseException as e: - conn.handle.rollback() - print(query) - print(e) - raise e - finally: - conn.transaction_open = False + + with self.test_connection(connection_name) as conn: + logger.debug('test connection "{}" executing: {}'.format(conn.name, sql)) + if self.adapter_type == 'bigquery': + return self.run_sql_bigquery(sql, fetch) + elif self.adapter_type == 'presto': + return self.run_sql_presto(sql, fetch, conn) + else: + return self.run_sql_common(sql, fetch, conn) def _ilike(self, target, value): # presto has this regex substitution monstrosity instead of 'ilike' @@ -612,11 +621,23 @@ def filter_many_columns(self, column): char_size = 16777216 return (table_name, column_name, data_type, char_size) + @contextmanager + def test_connection(self, name=None): + """Create a test connection context where all executed macros, etc will + get self.adapter as the adapter. 
+ + This allows tests to run normal adapter macros as if reset_adapters() + were not called by handle_and_check (for asserts, etc) + """ + if name is None: + name = '__test' + with patch.object(common, 'get_adapter', return_value=self.adapter): + with self.adapter.connection_named(name) as conn: + yield conn + def get_relation_columns(self, relation): - columns = self.adapter.get_columns_in_relation( - relation, - model_name='__test' - ) + with self.test_connection(): + columns = self.adapter.get_columns_in_relation(relation) return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) @@ -781,7 +802,8 @@ def assertManyRelationsEqual(self, relations, default_schema=None, default_datab specs.append(relation) - column_specs = self.get_many_relation_columns(specs) + with self.test_connection(): + column_specs = self.get_many_relation_columns(specs) # make sure everyone has equal column definitions first_columns = None diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index b667cf38a6f..f982d40833f 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -69,7 +69,7 @@ def get_adapter(self, target): profile=profile, ) adapter = BigQueryAdapter(config) - inject_adapter('bigquery', adapter) + inject_adapter(adapter) return adapter @@ -109,14 +109,14 @@ def test_cancel_open_connections_empty(self): def test_cancel_open_connections_master(self): adapter = self.get_adapter('oauth') - adapter.connections.in_use['master'] = object() + adapter.connections.thread_connections[0] = object() self.assertEqual(adapter.cancel_open_connections(), None) def test_cancel_open_connections_single(self): adapter = self.get_adapter('oauth') - adapter.connections.in_use.update({ - 'master': object(), - 'model': object(), + adapter.connections.thread_connections.update({ + 0: object(), + 1: object(), }) # actually does nothing self.assertEqual(adapter.cancel_open_connections(), None) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 88f3f15e694..37d2f1a7710 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -32,6 +32,7 @@ def tearDown(self): self.load_projects_patcher.stop() self.find_matching_patcher.stop() self.load_file_contents_patcher.stop() + self.get_adapter_patcher.stop() def setUp(self): dbt.flags.STRICT_MODE = True @@ -41,6 +42,8 @@ def setUp(self): self.load_projects_patcher = patch('dbt.loader._load_projects') self.find_matching_patcher = patch('dbt.clients.system.find_matching') self.load_file_contents_patcher = patch('dbt.clients.system.load_file_contents') + self.get_adapter_patcher = patch('dbt.context.parser.get_adapter') + self.factory = self.get_adapter_patcher.start() def mock_write_gpickle(graph, outfile): self.graph_result = graph @@ -52,7 +55,7 @@ def mock_write_gpickle(graph, outfile): 'test': { 'type': 'postgres', 'threads': 4, - 'host': 'database', + 'host': 'thishostshouldnotexist', 'port': 5432, 'user': 'root', 'pass': 'password', diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index a4e23b444c3..98598cebf9f 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -54,7 +54,6 @@ def setUp(self): 'project-root': os.path.abspath('.'), } - self.root_project_config = config_from_parts_or_dicts( project=root_project, profile=profile_data, @@ -76,7 +75,11 @@ def setUp(self): 'root': self.root_project_config, 'snowplow': self.snowplow_project_config } + self.patcher = mock.patch('dbt.context.parser.get_adapter') + self.factory = 
self.patcher.start() + def tearDown(self): + self.patcher.stop() class SourceConfigTest(BaseParserTest): diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 7e76cafa0c1..745f101c46c 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -8,9 +8,10 @@ from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger # noqa from psycopg2 import extensions as psycopg2_extensions +from psycopg2 import DatabaseError, Error import agate -from .utils import config_from_parts_or_dicts, inject_adapter +from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection class TestPostgresAdapter(unittest.TestCase): @@ -29,7 +30,7 @@ def setUp(self): 'type': 'postgres', 'dbname': 'postgres', 'user': 'root', - 'host': 'database', + 'host': 'thishostshouldnotexist', 'pass': 'password', 'port': 5432, 'schema': 'public' @@ -45,7 +46,7 @@ def setUp(self): def adapter(self): if self._adapter is None: self._adapter = PostgresAdapter(self.config) - inject_adapter('postgres', self._adapter) + inject_adapter(self._adapter) return self._adapter @mock.patch('dbt.adapters.postgres.connections.psycopg2') @@ -72,17 +73,18 @@ def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_master(self): - self.adapter.connections.in_use['master'] = mock.MagicMock() + key = self.adapter.connections.get_thread_identifier() + self.adapter.connections.thread_connections[key] = mock_connection('master') self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): - master = mock.MagicMock() - model = mock.MagicMock() + master = mock_connection('master') + model = mock_connection('model') + key = self.adapter.connections.get_thread_identifier() model.handle.get_backend_pid.return_value = 42 - - self.adapter.connections.in_use.update({ - 'master': master, - 'model': model, + self.adapter.connections.thread_connections.update({ + key: master, + 1: model, }) with mock.patch.object(self.adapter.connections, 'add_query') as add_query: query_result = mock.MagicMock() @@ -102,7 +104,7 @@ def test_default_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5432, connect_timeout=10) @@ -117,7 +119,7 @@ def test_changed_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5432, connect_timeout=10, @@ -133,7 +135,7 @@ def test_set_zero_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='postgres', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5432, connect_timeout=10) @@ -172,7 +174,7 @@ def setUp(self): 'type': 'postgres', 'dbname': 'postgres', 'user': 'root', - 'host': 'database', + 'host': 'thishostshouldnotexist', 'pass': 'password', 'port': 5432, 'schema': 'public' @@ -198,10 +200,14 @@ def setUp(self): self.mock_execute = self.cursor.execute self.patcher = mock.patch('dbt.adapters.postgres.connections.psycopg2') self.psycopg2 = self.patcher.start() + # there must be a better way to do this... 
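(Why the two assignments just below are needed: `mock.patch` replaces the whole psycopg2 module with a MagicMock, so an `except psycopg2.Error:` clause in the connection manager would be referencing a MagicMock attribute and would raise a TypeError whenever an error actually reached it. Restoring the real exception classes onto the mock keeps those clauses working. The same idea in isolation:)

    import mock
    from psycopg2 import DatabaseError, Error

    with mock.patch('dbt.adapters.postgres.connections.psycopg2') as fake_psycopg2:
        # the patched-in mock must still expose real exception classes,
        # otherwise except clauses naming psycopg2.Error cannot catch anything
        fake_psycopg2.DatabaseError = DatabaseError
        fake_psycopg2.Error = Error
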
+ self.psycopg2.DatabaseError = DatabaseError + self.psycopg2.Error = Error self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config) - inject_adapter('postgres', self.adapter) + self.adapter.acquire_connection() + inject_adapter(self.adapter) def tearDown(self): # we want a unique self.handle every time. diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index 5611a2a6efc..8d6a5184751 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -9,7 +9,7 @@ from dbt.exceptions import ValidationException, FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa -from .utils import config_from_parts_or_dicts +from .utils import config_from_parts_or_dicts, mock_connection @classmethod @@ -30,7 +30,7 @@ def setUp(self): 'type': 'redshift', 'dbname': 'redshift', 'user': 'root', - 'host': 'database', + 'host': 'thishostshouldnotexist', 'pass': 'password', 'port': 5439, 'schema': 'public' @@ -106,17 +106,19 @@ def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_master(self): - self.adapter.connections.in_use['master'] = mock.MagicMock() + key = self.adapter.connections.get_thread_identifier() + self.adapter.connections.thread_connections[key] = mock_connection('master') self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): - master = mock.MagicMock() - model = mock.MagicMock() + master = mock_connection('master') + model = mock_connection('model') model.handle.get_backend_pid.return_value = 42 - self.adapter.connections.in_use.update({ - 'master': master, - 'model': model, + key = self.adapter.connections.get_thread_identifier() + self.adapter.connections.thread_connections.update({ + key: master, + 1: model, }) with mock.patch.object(self.adapter.connections, 'add_query') as add_query: query_result = mock.MagicMock() @@ -135,7 +137,7 @@ def test_default_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5439, connect_timeout=10, @@ -152,7 +154,7 @@ def test_changed_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5439, connect_timeout=10, @@ -168,7 +170,7 @@ def test_set_zero_keepalive(self, psycopg2): psycopg2.connect.assert_called_once_with( dbname='redshift', user='root', - host='database', + host='thishostshouldnotexist', password='password', port=5439, connect_timeout=10) diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index db5394c3fe3..1e55e09e7b0 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -11,7 +11,7 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa from snowflake import connector as snowflake_connector -from .utils import config_from_parts_or_dicts, inject_adapter +from .utils import config_from_parts_or_dicts, inject_adapter, mock_connection class TestSnowflakeAdapter(unittest.TestCase): @@ -54,8 +54,8 @@ def setUp(self): self.snowflake.return_value = self.handle self.adapter = SnowflakeAdapter(self.config) - # patch our new adapter into the factory so macros behave - inject_adapter('snowflake', self.adapter) + self.adapter.acquire_connection() + 
inject_adapter(self.adapter) def tearDown(self): # we want a unique self.handle every time. @@ -134,17 +134,19 @@ def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_master(self): - self.adapter.connections.in_use['master'] = mock.MagicMock() + key = self.adapter.connections.get_thread_identifier() + self.adapter.connections.thread_connections[key] = mock_connection('master') self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): - master = mock.MagicMock() - model = mock.MagicMock() + master = mock_connection('master') + model = mock_connection('model') model.handle.session_id = 42 - self.adapter.connections.in_use.update({ - 'master': master, - 'model': model, + key = self.adapter.connections.get_thread_identifier() + self.adapter.connections.thread_connections.update({ + key: master, + 1: model, }) with mock.patch.object(self.adapter.connections, 'add_query') as add_query: query_result = mock.MagicMock() @@ -157,7 +159,7 @@ def test_cancel_open_connections_single(self): 'select system$abort_session(42)', 'master') def test_client_session_keep_alive_false_by_default(self): - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( account='test_account', autocommit=False, @@ -170,7 +172,7 @@ def test_client_session_keep_alive_true(self): self.config.credentials = self.config.credentials.incorporate( client_session_keep_alive=True) self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( @@ -184,7 +186,7 @@ def test_user_pass_authentication(self): self.config.credentials = self.config.credentials.incorporate( password='test_password') self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( @@ -198,7 +200,7 @@ def test_authenticator_user_pass_authentication(self): self.config.credentials = self.config.credentials.incorporate( password='test_password', authenticator='test_sso_url') self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( @@ -213,7 +215,7 @@ def test_authenticator_externalbrowser_authentication(self): self.config.credentials = self.config.credentials.incorporate( authenticator='externalbrowser') self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( @@ -231,7 +233,7 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key): private_key_passphrase='p@ssphr@se') self.adapter = SnowflakeAdapter(self.config) - self.adapter.connections.get(name='new_connection_with_new_config') + self.adapter.connections.set_connection_name(name='new_connection_with_new_config') self.snowflake.assert_has_calls([ mock.call( diff 
--git a/test/unit/utils.py b/test/unit/utils.py index fafb89484e7..48a753c2ee1 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -3,11 +3,19 @@ Note that all imports should be inside the functions to avoid import/mocking issues. """ +import mock + class Obj(object): which = 'blah' +def mock_connection(name): + conn = mock.MagicMock() + conn.name = name + return conn + + def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): from dbt.config import Project, Profile, RuntimeConfig from dbt.utils import parse_cli_vars @@ -29,10 +37,12 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): ) -def inject_adapter(key, value): +def inject_adapter(value): """Inject the given adapter into the adapter factory, so your hand-crafted artisanal adapter will be available from get_adapter() as if dbt loaded it. """ from dbt.adapters import factory + from dbt.adapters.base.connections import BaseConnectionManager + key = value.type() factory._ADAPTERS[key] = value factory.ADAPTER_TYPES[key] = type(value)
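(To close, a sketch of how the reworked unit-test helpers fit together; the concrete adapter class and the `config` object are stand-ins for whatever an individual test builds from its fixtures:)

    from dbt.adapters.postgres import PostgresAdapter
    from .utils import inject_adapter, mock_connection

    # 'config' is a placeholder for the RuntimeConfig a test builds with
    # config_from_parts_or_dicts(...)
    adapter = PostgresAdapter(config)
    adapter.acquire_connection()   # bind a connection to the current thread
    inject_adapter(adapter)        # registry key now comes from adapter.type()

    # seed the thread-keyed registry so cancel_open_connections() sees a master
    key = adapter.connections.get_thread_identifier()
    adapter.connections.thread_connections[key] = mock_connection('master')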