diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 5ce42443..af4c3394 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -24,7 +24,9 @@ Tuple, cast, Union, + Hashable, ) +from numbers import Number from agate import Table @@ -35,6 +37,7 @@ from dbt.clients import agate_helper from dbt.contracts.connection import ( AdapterResponse, + AdapterRequiredConfig, Connection, ConnectionState, DEFAULT_QUERY_COMMENT, @@ -44,7 +47,10 @@ from dbt.events.types import ( NewConnection, ConnectionReused, + ConnectionLeftOpenInCleanup, + ConnectionClosedInCleanup, ) + from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import ResultNode from dbt.events import AdapterLogger @@ -105,6 +111,13 @@ def emit(self, record: logging.LogRecord) -> None: CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] +# toggle for session managements that minimizes the number of sessions opened/closed +USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + +# Number of idle seconds before a connection is automatically closed. Only applicable if +# USE_LONG_SESSIONS is true. +DEFAULT_MAX_IDLE_TIME = 600 + @dataclass class DatabricksCredentials(Credentials): @@ -126,6 +139,7 @@ class DatabricksCredentials(Credentials): connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False + connect_max_idle: Optional[int] = None _credentials_provider: Optional[Dict[str, Any]] = None _lock = threading.Lock() # to avoid concurrent auth @@ -714,10 +728,73 @@ class DatabricksAdapterResponse(AdapterResponse): query_id: str = "" +@dataclass(init=False) +class DatabricksDBTConnection(Connection): + last_used_time: Optional[float] = None + acquire_release_count: int = 0 + compute_name: str = "" + http_path: str = "" + thread_identifier: Tuple[int, int] = (0, 0) + max_idle_time: int = DEFAULT_MAX_IDLE_TIME + + def _acquire(self, node: Optional[ResultNode]) -> None: + """Indicate that this connection is in use.""" + logger.debug(f"DatabricksDBTConnection._acquire: {self._get_conn_info_str()}") + self._log_usage(node) + self.acquire_release_count += 1 + + def _release(self) -> None: + """Indicate that this connection is not in use.""" + logger.debug(f"DatabricksDBTConnection._release: {self._get_conn_info_str()}") + # Need to check for > 0 because in some situations the dbt code will make an extra + # release call on a connection. + if self.acquire_release_count > 0: + self.acquire_release_count -= 1 + + if self.acquire_release_count == 0: + self.last_used_time = time.time() + + def _get_idle_time(self) -> float: + return 0 if self.last_used_time is None else time.time() - self.last_used_time + + def _idle_too_long(self) -> bool: + return self.max_idle_time > 0 and self._get_idle_time() > self.max_idle_time + + def _get_conn_info_str(self) -> str: + """Generate a string describing this connection.""" + return ( + f"name: {self.name}, thread: {self.thread_identifier}, " + f"compute: `{self.compute_name}`, acquire_release_count: {self.acquire_release_count}," + f" idle time: {self._get_idle_time()}s" + ) + + def _log_usage(self, node: Optional[ResultNode]) -> None: + if node: + if not self.compute_name: + logger.debug( + f"On thread {self.thread_identifier}: {node.relation_name} " + "using default compute resource." + ) + else: + logger.debug( + f"On thread {self.thread_identifier}: {node.relation_name} " + "using compute resource '{self.compute_name}'." + ) + else: + logger.debug(f"Thread {self.thread_identifier} using default compute resource.") + + class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" credentials_provider: CredentialsProvider = None + def __init__(self, profile: AdapterRequiredConfig) -> None: + super().__init__(profile) + if USE_LONG_SESSIONS: + self.threads_compute_connections: Dict[ + Hashable, Dict[Hashable, DatabricksDBTConnection] + ] = {} + def compare_dbr_version(self, major: int, minor: int) -> int: version = (major, minor) @@ -762,12 +839,15 @@ def set_connection_name( Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" + if USE_LONG_SESSIONS: + return self._get_compute_connection(name, node) + conn_name: str = "master" if name is None else name # Get a connection for this thread conn = self.get_if_exists() - if conn and conn.name == conn_name and conn.state == "open": + if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: # Found a connection and nothing to do, so just return it return conn @@ -788,7 +868,7 @@ def set_connection_name( NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) ) else: # existing connection either wasn't open or didn't have the right name - if conn.state != "open": + if conn.state != ConnectionState.OPEN: conn.handle = LazyHandle(self.get_open_for_model(node)) if conn.name != conn_name: orig_conn_name: str = conn.name or "" @@ -797,6 +877,199 @@ def set_connection_name( return conn + # override + def release(self) -> None: + if not USE_LONG_SESSIONS: + return super().release() + + with self.lock: + conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if conn is None: + return + + conn._release() + + # override + def cleanup_all(self) -> None: + if not USE_LONG_SESSIONS: + return super().cleanup_all() + + with self.lock: + for thread_connections in self.threads_compute_connections.values(): + for connection in thread_connections.values(): + if connection.acquire_release_count > 0: + fire_event( + ConnectionLeftOpenInCleanup(conn_name=cast_to_str(connection.name)) + ) + else: + fire_event( + ConnectionClosedInCleanup(conn_name=cast_to_str(connection.name)) + ) + self.close(connection) + + # garbage collect these connections + self.thread_connections.clear() + self.threads_compute_connections.clear() + + def _get_compute_connection( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + """Called by 'set_connection_name' in DatabricksConnectionManager. + Creates a connection for this thread/node if one doesn't already + exist, and will rename an existing connection.""" + + assert ( + USE_LONG_SESSIONS + ), "This path, '_get_compute_connection', should only be reachable with USE_LONG_SESSIONS" + + self._cleanup_idle_connections() + + conn_name: str = "master" if name is None else name + + # Get a connection for this thread + conn = self._get_if_exists_compute_connection(_get_compute_name(node) or "") + + if conn is None: + conn = self._create_compute_connection(conn_name, node) + else: # existing connection either wasn't open or didn't have the right name + conn = self._update_compute_connection(conn, conn_name, node) + + conn._acquire(node) + + return conn + + def _update_compute_connection( + self, + conn: DatabricksDBTConnection, + new_name: str, + node: Optional[ResultNode] = None, + ) -> DatabricksDBTConnection: + """Update a connection that is being re-used with a new name, handle, etc.""" + assert USE_LONG_SESSIONS, ( + "This path, '_update_compute_connection', should only be " + "reachable with USE_LONG_SESSIONS" + ) + + if conn.name == new_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn + + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self._open2) + if conn.name != new_name: + orig_conn_name: str = conn.name or "" + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) + + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) + + logger.debug(f"Reusing DatabricksDBTConnection. {conn._get_conn_info_str()}") + + return conn + + def _create_compute_connection( + self, conn_name: str, node: Optional[ResultNode] = None + ) -> DatabricksDBTConnection: + """Create anew connection for the combination of current thread and compute associated + with the given node.""" + assert USE_LONG_SESSIONS, ( + "This path, '_create_compute_connection', should only be reachable " + "with USE_LONG_SESSIONS" + ) + + # Create a new connection + compute_name = _get_compute_name(node=node) or "" + logger.debug( + f"Creating DatabricksDBTConnection. name: {conn_name}, " + f"thread: {self.get_thread_identifier()}, compute: `{compute_name}`" + ) + conn = DatabricksDBTConnection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.compute_name = compute_name + creds = cast(DatabricksCredentials, self.profile.credentials) + conn.http_path = _get_http_path(node=node, creds=creds) or "" + conn.thread_identifier = cast(Tuple[int, int], self.get_thread_identifier()) + conn.max_idle_time = _get_max_idle_time(node=node, creds=creds) + + conn.handle = LazyHandle(self._open2) + # Add this connection to the thread/compute connection pool. + self._add_compute_connection(conn) + # Remove the connection currently in use by this thread from the thread connection pool. + self.clear_thread_connection() + # Add the connection to thread connection pool. + self.set_thread_connection(conn) + + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) + + return conn + + def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: + """Add a new connection to the map of connection per thread per compute.""" + assert ( + USE_LONG_SESSIONS + ), "This path, '_add_compute_connection', should only be reachable with USE_LONG_SESSIONS" + + with self.lock: + thread_map = self._get_compute_connections() + if conn.compute_name in thread_map: + raise dbt.exceptions.DbtInternalError( + f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" + ) + thread_map[conn.compute_name] = conn + + def _get_compute_connections( + self, + ) -> Dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" + assert ( + USE_LONG_SESSIONS + ), "This path, '_get_compute_connections', should only be reachable with USE_LONG_SESSIONS" + + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map + + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" + assert USE_LONG_SESSIONS, ( + "This path, '_get_if_exists_compute_connection', should only be reachable " + "with USE_LONG_SESSIONS" + ) + + with self.lock: + threads_map = self._get_compute_connections() + return threads_map.get(compute_name) + + def _cleanup_idle_connections(self) -> None: + assert ( + USE_LONG_SESSIONS + ), "This path, '_cleanup_idle_connections', should only be reachable with USE_LONG_SESSIONS" + + with self.lock: + for thread_conns in self.threads_compute_connections.values(): + for conn in thread_conns.values(): + if conn.acquire_release_count == 0 and conn._idle_too_long(): + logger.debug(f"closing idle connection: {conn._get_conn_info_str()}") + self.close(conn) + conn.handle = LazyHandle(self._open2) + def add_query( self, sql: str, @@ -825,6 +1098,7 @@ def add_query( node_info=get_node_info(), ) ) + pre = time.time() cursor = cast(DatabricksSQLConnectionWrapper, connection.handle).cursor() @@ -884,6 +1158,7 @@ def _execute_cursor( node_info=get_node_info(), ) ) + pre = time.time() handle: DatabricksSQLConnectionWrapper = connection.handle @@ -1007,6 +1282,81 @@ def exponential_backoff(attempt: int) -> int: retry_timeout=(timeout if timeout is not None else exponential_backoff), ) + @classmethod + def _open2(cls, connection: Connection) -> Connection: + # Once long session management is no longer under the USE_LONG_SESSIONS toggle + # this should be renamed and replace the _open class method. + assert ( + USE_LONG_SESSIONS + ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" + + if connection.state == ConnectionState.OPEN: + logger.debug("Connection is already open, skipping open.") + return connection + + creds: DatabricksCredentials = connection.credentials + timeout = creds.connect_timeout + + # gotta keep this so we don't prompt users many times + cls.credentials_provider = creds.authenticate(cls.credentials_provider) + + user_agent_entry = f"dbt-databricks/{__version__}" + + invocation_env = creds.get_invocation_env() + if invocation_env: + user_agent_entry = f"{user_agent_entry}; {invocation_env}" + + connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + + http_headers: List[Tuple[str, str]] = list( + creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() + ) + + # If a model specifies a compute resource the http path + # may be different than the http_path property of creds. + http_path = cast(DatabricksDBTConnection, connection).http_path + + def connect() -> DatabricksSQLConnectionWrapper: + try: + # TODO: what is the error when a user specifies a catalog they don't have access to + conn: DatabricksSQLConnection = dbsql.connect( + server_hostname=creds.host, + http_path=http_path, + credentials_provider=cls.credentials_provider, + http_headers=http_headers if http_headers else None, + session_configuration=creds.session_properties, + catalog=creds.database, + # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. + _user_agent_entry=user_agent_entry, + **connection_parameters, + ) + return DatabricksSQLConnectionWrapper( + conn, + is_cluster=creds.cluster_id is not None, + creds=creds, + user_agent=user_agent_entry, + ) + except Error as exc: + _log_dbsql_errors(exc) + raise + + def exponential_backoff(attempt: int) -> int: + return attempt * attempt + + retryable_exceptions = [] + # this option is for backwards compatibility + if creds.retry_all: + retryable_exceptions = [Error] + + return cls.retry_connection( + connection, + connect=connect, + logger=logger, + retryable_exceptions=retryable_exceptions, + retry_limit=creds.connect_retries, + retry_timeout=(timeout if timeout is not None else exponential_backoff), + ) + @classmethod def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: _query_id = getattr(cursor, "hex_query_id", None) @@ -1121,18 +1471,25 @@ def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]: def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: + """Get the http_path for the compute specified for the node. + If none is specified default will be used.""" + thread_id = (os.getpid(), get_ident()) # If there is no node we return the http_path for the default compute. if not node: - logger.debug(f"Thread {thread_id}: using default compute resource.") + if not USE_LONG_SESSIONS: + logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path # Get the name of the compute resource specified in the node's config. # If none is specified return the http_path for the default compute. compute_name = _get_compute_name(node) if not compute_name: - logger.debug(f"On thread {thread_id}: {node.relation_name} using default compute resource.") + if not USE_LONG_SESSIONS: + logger.debug( + f"On thread {thread_id}: {node.relation_name} using default compute resource." + ) return creds.http_path # Get the http_path for the named compute. @@ -1147,8 +1504,36 @@ def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> f"does not specify http_path, relation: {node.relation_name}" ) - logger.debug( - f"On thread {thread_id}: {node.relation_name} using compute resource '{compute_name}'." - ) + if not USE_LONG_SESSIONS: + logger.debug( + f"On thread {thread_id}: {node.relation_name} using compute resource '{compute_name}'." + ) return http_path + + +def _get_max_idle_time(node: Optional[ResultNode], creds: DatabricksCredentials) -> int: + """Get the http_path for the compute specified for the node. + If none is specified default will be used.""" + + max_idle_time = ( + DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle + ) + + if node: + compute_name = _get_compute_name(node) + if compute_name and creds.compute: + max_idle_time = creds.compute.get(compute_name, {}).get( + "connect_max_idle", max_idle_time + ) + + if not isinstance(max_idle_time, Number): + if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric(): + return int(max_idle_time.strip()) + else: + raise dbt.exceptions.DbtRuntimeError( + f"{max_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds." + ) + + return max_idle_time diff --git a/tests/functional/adapter/long_sessions/fixtures.py b/tests/functional/adapter/long_sessions/fixtures.py new file mode 100644 index 00000000..f1332e87 --- /dev/null +++ b/tests/functional/adapter/long_sessions/fixtures.py @@ -0,0 +1,46 @@ +source = """id,name,date +1,Alice,2022-01-01 +2,Bob,2022-01-02 +""" + +target = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" + +target2 = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('source') }} +""" + +targetseq1 = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('source') }} +""" + +targetseq2 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq1') }} +""" + +targetseq3 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq2') }} +""" + +targetseq4 = """ +{{config(materialized='table')}} + +select * from {{ ref('targetseq3') }} +""" + +targetseq5 = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('targetseq4') }} +""" diff --git a/tests/functional/adapter/long_sessions/test_long_sessions.py b/tests/functional/adapter/long_sessions/test_long_sessions.py new file mode 100644 index 00000000..d32b5397 --- /dev/null +++ b/tests/functional/adapter/long_sessions/test_long_sessions.py @@ -0,0 +1,116 @@ +import pytest +import os +from unittest import mock +from dbt.tests import util +from tests.functional.adapter.long_sessions import fixtures + +with mock.patch.dict(os.environ, {"DBT_DATABRICKS_LONG_SESSIONS": "true"}): + import dbt.adapters.databricks.connections # noqa + + +class TestLongSessionsBase: + args_formatter = "" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "source.csv": fixtures.source, + } + + @pytest.fixture(scope="class") + def models(self): + m = {} + for i in range(5): + m[f"target{i}.sql"] = fixtures.target + + return m + + def test_long_sessions(self, project): + _, log = util.run_dbt_and_capture(["--debug", "seed"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 2 + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 2 + + +class TestLongSessionsMultipleThreads(TestLongSessionsBase): + def test_long_sessions(self, project): + util.run_dbt_and_capture(["seed"]) + + for n_threads in [1, 2, 3]: + _, log = util.run_dbt_and_capture(["--debug", "run", "--threads", f"{n_threads}"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == (n_threads + 1) + + +class TestLongSessionsMultipleCompute: + args_formatter = "" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "source.csv": fixtures.source, + } + + @pytest.fixture(scope="class") + def models(self): + m = {} + for i in range(2): + m[f"target{i}.sql"] = fixtures.target + + m["target_alt.sql"] = fixtures.target2 + + return m + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + outputs = {"default": dbt_profile_target} + outputs["default"]["compute"] = { + "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}, + } + + return {"test": {"outputs": outputs, "target": "default"}} + + def test_long_sessions(self, project): + util.run_dbt_and_capture(["--debug", "seed"]) + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + open_count = log.count("Sending request: OpenSession") / 2 + assert open_count == 3 + + +class TestLongSessionsIdleCleanup(TestLongSessionsMultipleCompute): + args_formatter = "" + + @pytest.fixture(scope="class") + def models(self): + m = { + "targetseq1.sql": fixtures.targetseq1, + "targetseq2.sql": fixtures.targetseq2, + "targetseq3.sql": fixtures.targetseq3, + "targetseq4.sql": fixtures.targetseq4, + "targetseq5.sql": fixtures.targetseq5, + } + return m + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + outputs = {"default": dbt_profile_target} + outputs["default"]["connect_max_idle"] = 1 + outputs["default"]["compute"] = { + "alternate_warehouse": { + "http_path": dbt_profile_target["http_path"], + "connect_max_idle": 1, + }, + } + + return {"test": {"outputs": outputs, "target": "default"}} + + def test_long_sessions(self, project): + util.run_dbt(["--debug", "seed"]) + + _, log = util.run_dbt_and_capture(["--debug", "run"]) + idle_count = log.count("closing idle connection") / 2 + assert idle_count > 0 diff --git a/tests/unit/test_idle_config.py b/tests/unit/test_idle_config.py new file mode 100644 index 00000000..62a4da40 --- /dev/null +++ b/tests/unit/test_idle_config.py @@ -0,0 +1,238 @@ +import unittest +import dbt.exceptions +from dbt.contracts.graph import nodes, model_config +from dbt.adapters.databricks import connections + + +class TestDatabricksConnectionMaxIdleTime(unittest.TestCase): + """Test the various cases for determining a specified warehouse.""" + + errMsg = ( + "Compute resource foo does not exist or does not specify http_path, " "relation: a_relation" + ) + + def test_get_max_idle_default(self): + creds = connections.DatabricksCredentials() + + # No node and nothing specified in creds + time = connections._get_max_idle_time(None, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + node = nodes.ModelNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + # node has no configuration so should get back default + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # empty configuration should return default + node.config = model_config.ModelConfig() + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # node with no extras in configuration should return default + node.config._extra = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # node that specifies a compute with no corresponding definition should return default + node.config._extra["databricks_compute"] = "foo" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + creds.compute = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + + # if alternate compute doesn't specify a max time should return default + creds.compute = {"foo": {}} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(connections.DEFAULT_MAX_IDLE_TIME, time) + # with self.assertRaisesRegex( + # dbt.exceptions.DbtRuntimeError, + # self.errMsg, + # ): + # connections._get_http_path(node, creds) + + # creds.compute = {"foo": {"http_path": "alternate_path"}} + # path = connections._get_http_path(node, creds) + # self.assertEqual("alternate_path", path) + + def test_get_max_idle_creds(self): + creds_idle_time = 77 + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + + # No node so value should come from creds + time = connections._get_max_idle_time(None, creds) + self.assertEqual(creds_idle_time, time) + + node = nodes.ModelNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + # node has no configuration so should get value from creds + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # empty configuration should get value from creds + node.config = model_config.ModelConfig() + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # node with no extras in configuration should get value from creds + node.config._extra = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # node that specifies a compute with no corresponding definition should get value from creds + node.config._extra["databricks_compute"] = "foo" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + creds.compute = {} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + # if alternate compute doesn't specify a max time should get value from creds + creds.compute = {"foo": {}} + time = connections._get_max_idle_time(node, creds) + self.assertEqual(creds_idle_time, time) + + def test_get_max_idle_compute(self): + creds_idle_time = 88 + compute_idle_time = 77 + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"foo": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + node.config._extra = {"databricks_compute": "foo"} + + time = connections._get_max_idle_time(node, creds) + self.assertEqual(compute_idle_time, time) + + def test_get_max_idle_invalid(self): + creds_idle_time = "foo" + compute_idle_time = "bar" + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + f"{creds_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + node.config._extra["databricks_compute"] = "alternate_compute" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + f"{compute_idle_time} is not a valid value for connect_max_idle. " + "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + creds.compute["alternate_compute"]["connect_max_idle"] = "1.2.3" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "1.2.3 is not a valid value for connect_max_idle. " "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + creds.compute["alternate_compute"]["connect_max_idle"] = "1,002.3" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "1,002.3 is not a valid value for connect_max_idle. " "Must be a number of seconds.", + ): + connections._get_max_idle_time(node, creds) + + def test_get_max_idle_simple_string_conversion(self): + creds_idle_time = "12" + compute_idle_time = "34" + creds = connections.DatabricksCredentials(connect_max_idle=creds_idle_time) + creds.compute = {"alternate_compute": {"connect_max_idle": compute_idle_time}} + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + + time = connections._get_max_idle_time(node, creds) + self.assertEqual(float(creds_idle_time), time) + + node.config._extra["databricks_compute"] = "alternate_compute" + time = connections._get_max_idle_time(node, creds) + self.assertEqual(float(compute_idle_time), time) + + creds.compute["alternate_compute"]["connect_max_idle"] = " 56 " + time = connections._get_max_idle_time(node, creds) + self.assertEqual(56, time)