From 183897c135a4bde9f46f33fc268e5452655baf62 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Thu, 26 Oct 2023 16:45:34 -0600 Subject: [PATCH 1/3] Allow models to execute on different warehouses Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 102 ++++++++++++++++++++++++- dbt/adapters/databricks/impl.py | 20 +++++ 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 0e39af14..12fe3ea8 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -37,8 +37,15 @@ Connection, ConnectionState, DEFAULT_QUERY_COMMENT, + Identifier, + LazyHandle, +) +from dbt.events.types import ( + NewConnection, + ConnectionReused, ) from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import ResultNode from dbt.events import AdapterLogger from dbt.events.contextvars import get_node_info from dbt.events.functions import fire_event @@ -111,6 +118,10 @@ class DatabricksCredentials(Credentials): connection_parameters: Optional[Dict[str, Any]] = None auth_type: Optional[str] = None + # Named compute resources specified in the profile. Used for + # creating a connection when a model specifies a compute resource. + compute: Optional[Dict[str, Any]] = None + connect_retries: int = 1 connect_timeout: Optional[int] = None retry_all: bool = False @@ -741,6 +752,50 @@ def exception_handler(self, sql: str) -> Iterator[None]: else: raise dbt.exceptions.DbtRuntimeError(str(exc)) from exc + # override/overload + def set_connection_name( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + """Called by 'acquire_connection' in DatabricksAdapter, which is called by + 'connection_named', called by 'connection_for(node)'. 
+ Creates a connection for this thread if one doesn't already + exist, and will rename an existing connection.""" + + conn_name: str = "master" if name is None else name + + # Get a connection for this thread + conn = self.get_if_exists() + + if conn and conn.name == conn_name and conn.state == "open": + # Found a connection and nothing to do, so just return it + return conn + + if conn is None: + # Create a new connection + conn = Connection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.handle = LazyHandle(self.get_open_for_model(node)) + # Add the connection to thread_connections for this thread + self.set_thread_connection(conn) + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) + else: # existing connection either wasn't open or didn't have the right name + if conn.state != "open": + conn.handle = LazyHandle(self.get_open_for_model(node)) + if conn.name != conn_name: + orig_conn_name: str = conn.name or "" + conn.name = conn_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) + + return conn + def add_query( self, sql: str, @@ -861,8 +916,29 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No ), ) + @classmethod + def get_open_for_model( + cls, node: Optional[ResultNode] = None + ) -> Callable[[Connection], Connection]: + # If there is no node we can simply return the existing class method open. + # If there is a node create a closure that will call cls._open with the node. + if not node: + return cls.open + + def _open(connection: Connection) -> Connection: + return cls._open(connection, node) + + return _open + @classmethod def open(cls, connection: Connection) -> Connection: + # Simply call _open with no ResultNode argument. + # Because this is an overridden method we can't just add + # a ResultNode parameter to open. 
+ return cls._open(connection) + + @classmethod + def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Connection: if connection.state == ConnectionState.OPEN: logger.debug("Connection is already open, skipping open.") return connection @@ -885,12 +961,16 @@ def open(cls, connection: Connection) -> Connection: creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() ) + # If a model specifies a compute resource to use the http path + # may be different than the http_path property of creds. + http_path = get_http_path(node, creds) + def connect() -> DatabricksSQLConnectionWrapper: try: # TODO: what is the error when a user specifies a catalog they don't have access to conn: DatabricksSQLConnection = dbsql.connect( server_hostname=creds.host, - http_path=creds.http_path, + http_path=http_path, credentials_provider=cls.credentials_provider, http_headers=http_headers if http_headers else None, session_configuration=creds.session_properties, @@ -1028,3 +1108,23 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: msg = error_events[0].get("message", "") return msg + + +def get_compute_name(node: Optional[ResultNode]) -> Optional[str]: + # Get the name of the specified compute resource from the node's + # config. + compute_name = None + if node and node.config and node.config.extra: + compute_name = node.config.extra.get("databricks_compute", None) + return compute_name + + +def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: + # Get the http path of the compute resource specified in the node's config. + # If none is specified return the default path from creds. 
+ compute_name = get_compute_name(node) + http_path = creds.http_path + if compute_name and creds.compute: + http_path = creds.compute.get(compute_name, {}).get("http_path", creds.http_path) + + return http_path diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index b41c8395..5f4730db 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -36,6 +36,7 @@ from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table from dbt.contracts.connection import AdapterResponse, Connection from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import ResultNode from dbt.contracts.relation import RelationType import dbt.exceptions from dbt.events import AdapterLogger @@ -118,6 +119,25 @@ class DatabricksAdapter(SparkAdapter): } ) + # override/overload + def acquire_connection( + self, name: Optional[str] = None, node: Optional[ResultNode] = None + ) -> Connection: + return self.connections.set_connection_name(name, node) + + # override + @contextmanager + def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]: + try: + if self.connections.query_header is not None: + self.connections.query_header.set(name, node) + self.acquire_connection(name, node) + yield + finally: + self.release_connection() + if self.connections.query_header is not None: + self.connections.query_header.reset() + @available.parse(lambda *a, **k: 0) def compare_dbr_version(self, major: int, minor: int) -> int: """ From 81fa7ba9a482a1fe594010be9ea4a7a8dab7d763 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Mon, 30 Oct 2023 14:20:55 -0600 Subject: [PATCH 2/3] Raise exception on missing compute resource Raise an exception if a model specifies a compute resource that is not defined in the profile. 
Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 12fe3ea8..3985d543 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1123,8 +1123,16 @@ def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> O # Get the http path of the compute resource specified in the node's config. # If none is specified return the default path from creds. compute_name = get_compute_name(node) - http_path = creds.http_path - if compute_name and creds.compute: - http_path = creds.compute.get(compute_name, {}).get("http_path", creds.http_path) + if not node or not compute_name: + return creds.http_path + + http_path = None + if creds.compute: + http_path = creds.compute.get(compute_name, {}).get("http_path", None) + + if not http_path: + raise dbt.exceptions.DbtRuntimeError( + f"Compute resource {compute_name} does not exist, relation: {node.relation_name}" + ) return http_path From e1f89c01a7b5edec48f275d599c8306a87911013 Mon Sep 17 00:00:00 2001 From: Raymond Cypher Date: Wed, 8 Nov 2023 16:47:14 -0700 Subject: [PATCH 3/3] Tests for warehouse-per-model Signed-off-by: Raymond Cypher --- dbt/adapters/databricks/connections.py | 18 +- .../adapter/warehouse_per_model/fixtures.py | 50 +++++ .../test_warehouse_per_model.py | 100 ++++++++++ tests/unit/test_compute_config.py | 178 ++++++++++++++++++ 4 files changed, 337 insertions(+), 9 deletions(-) create mode 100644 tests/functional/adapter/warehouse_per_model/fixtures.py create mode 100644 tests/functional/adapter/warehouse_per_model/test_warehouse_per_model.py create mode 100644 tests/unit/test_compute_config.py diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 3985d543..46e2dfbd 100644 --- a/dbt/adapters/databricks/connections.py +++ 
b/dbt/adapters/databricks/connections.py @@ -925,10 +925,10 @@ def get_open_for_model( if not node: return cls.open - def _open(connection: Connection) -> Connection: + def open_for_model(connection: Connection) -> Connection: return cls._open(connection, node) - return _open + return open_for_model @classmethod def open(cls, connection: Connection) -> Connection: @@ -961,9 +961,9 @@ def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Con creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() ) - # If a model specifies a compute resource to use the http path + # If a model specifies a compute resource the http path # may be different than the http_path property of creds. - http_path = get_http_path(node, creds) + http_path = _get_http_path(node, creds) def connect() -> DatabricksSQLConnectionWrapper: try: @@ -1110,19 +1110,19 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id: return msg -def get_compute_name(node: Optional[ResultNode]) -> Optional[str]: +def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]: # Get the name of the specified compute resource from the node's # config. compute_name = None - if node and node.config and node.config.extra: - compute_name = node.config.extra.get("databricks_compute", None) + if node and node.config: + compute_name = node.config.get("databricks_compute", None) return compute_name -def get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: +def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]: # Get the http path of the compute resource specified in the node's config. # If none is specified return the default path from creds. 
- compute_name = get_compute_name(node) + compute_name = _get_compute_name(node) if not node or not compute_name: return creds.http_path diff --git a/tests/functional/adapter/warehouse_per_model/fixtures.py b/tests/functional/adapter/warehouse_per_model/fixtures.py new file mode 100644 index 00000000..7a5f9fa1 --- /dev/null +++ b/tests/functional/adapter/warehouse_per_model/fixtures.py @@ -0,0 +1,50 @@ +source = """id,name,date +1,Alice,2022-01-01 +2,Bob,2022-01-02 +""" + +target = """ +{{config(materialized='table', databricks_compute='alternate_warehouse')}} + +select * from {{ ref('source') }} +""" + +target2 = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" + +target3 = """ +{{config(materialized='table')}} + +select * from {{ ref('source') }} +""" + +model_schema = """ +version: 2 + +models: + - name: target + columns: + - name: id + - name: name + - name: date + - name: target2 + config: + databricks_compute: alternate_warehouse + columns: + - name: id + - name: name + - name: date + - name: target3 + columns: + - name: id + - name: name + - name: date +""" + +expected_target = """id,name,date +1,Alice,2022-01-01 +2,Bob,2022-01-02 +""" diff --git a/tests/functional/adapter/warehouse_per_model/test_warehouse_per_model.py b/tests/functional/adapter/warehouse_per_model/test_warehouse_per_model.py new file mode 100644 index 00000000..608d58b8 --- /dev/null +++ b/tests/functional/adapter/warehouse_per_model/test_warehouse_per_model.py @@ -0,0 +1,100 @@ +import pytest +from dbt.tests import util +from tests.functional.adapter.warehouse_per_model import fixtures + + +class BaseWarehousePerModel: + args_formatter = "" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "source.csv": fixtures.source, + } + + @pytest.fixture(scope="class") + def models(self): + d = dict() + d["target4.sql"] = fixtures.target3 + return { + "target.sql": fixtures.target, + "target2.sql": fixtures.target2, + "target3.sql": fixtures.target3, + 
"schema.yml": fixtures.model_schema, + "special": d, + } + + +class BaseSpecifyingCompute(BaseWarehousePerModel): + """Base class for testing various ways to specify a warehouse.""" + + def test_wpm(self, project): + util.run_dbt(["seed"]) + models = project.test_config.get("model_names") + for model_name in models: + # Since the profile doesn't define a compute resource named 'alternate_warehouse' + # we should fail with an error if the warehouse specified for the model is + # correctly handled. + res = util.run_dbt(["run", "--select", model_name], expect_pass=False) + msg = res.results[0].message + assert "Compute resource alternate_warehouse does not exist" in msg + assert model_name in msg + + +class TestSpecifyingInConfigBlock(BaseSpecifyingCompute): + @pytest.fixture(scope="class") + def test_config(self): + return {"model_names": ["target"]} + + +class TestSpecifyingInSchemaYml(BaseSpecifyingCompute): + @pytest.fixture(scope="class") + def test_config(self): + return {"model_names": ["target2"]} + + +class TestSpecifyingForProjectModels(BaseSpecifyingCompute): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+databricks_compute": "alternate_warehouse", + } + } + + @pytest.fixture(scope="class") + def test_config(self): + return {"model_names": ["target3"]} + + +class TestSpecifyingForProjectModelsInFolder(BaseSpecifyingCompute): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "test": { + "special": { + "+databricks_compute": "alternate_warehouse", + }, + }, + } + } + + @pytest.fixture(scope="class") + def test_config(self): + return {"model_names": ["target4"]} + + +class TestWarehousePerModel(BaseWarehousePerModel): + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + outputs = {"default": dbt_profile_target} + outputs["default"]["compute"] = { + "alternate_warehouse": {"http_path": dbt_profile_target["http_path"]} + } + return 
{"test": {"outputs": outputs, "target": "default"}} + + def test_wpm(self, project): + util.run_dbt(["seed"]) + util.run_dbt(["run", "--select", "target"]) + util.check_relations_equal(project.adapter, ["target", "source"]) diff --git a/tests/unit/test_compute_config.py b/tests/unit/test_compute_config.py new file mode 100644 index 00000000..6af95b75 --- /dev/null +++ b/tests/unit/test_compute_config.py @@ -0,0 +1,178 @@ +import unittest +import dbt.exceptions +from dbt.contracts.graph import nodes, model_config +from dbt.adapters.databricks import connections + + +class TestDatabricksConnectionHTTPPath(unittest.TestCase): + """Test the various cases for determining a specified warehouse.""" + + def test_get_http_path_model(self): + default_path = "my_http_path" + creds = connections.DatabricksCredentials(http_path=default_path) + + path = connections._get_http_path(None, creds) + self.assertEqual(default_path, path) + + node = nodes.ModelNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config = model_config.ModelConfig() + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra = {} + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra["databricks_compute"] = "foo" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + 
creds.compute = {"foo": {}} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {"foo": {"http_path": "alternate_path"}} + path = connections._get_http_path(node, creds) + self.assertEqual("alternate_path", path) + + def test_get_http_path_seed(self): + default_path = "my_http_path" + creds = connections.DatabricksCredentials(http_path=default_path) + + path = connections._get_http_path(None, creds) + self.assertEqual(default_path, path) + + node = nodes.SeedNode( + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config = model_config.SeedConfig() + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra = {} + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra["databricks_compute"] = "foo" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {"foo": {}} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {"foo": {"http_path": "alternate_path"}} + path = connections._get_http_path(node, creds) + self.assertEqual("alternate_path", path) + 
+ def test_get_http_path_snapshot(self): + default_path = "my_http_path" + creds = connections.DatabricksCredentials(http_path=default_path) + + path = connections._get_http_path(None, creds) + self.assertEqual(default_path, path) + + node = nodes.SnapshotNode( + config=None, + relation_name="a_relation", + database="database", + schema="schema", + name="node_name", + resource_type="model", + package_name="package", + path="path", + original_file_path="orig_path", + unique_id="uniqueID", + fqn=[], + alias="alias", + checksum=None, + ) + + node.config = model_config.SnapshotConfig() + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra = {} + path = connections._get_http_path(node, creds) + self.assertEqual(default_path, path) + + node.config._extra["databricks_compute"] = "foo" + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {"foo": {}} + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, + "Compute resource foo does not exist, relation: a_relation", + ): + connections._get_http_path(node, creds) + + creds.compute = {"foo": {"http_path": "alternate_path"}} + path = connections._get_http_path(node, creds) + self.assertEqual("alternate_path", path)