Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow models to execute on different warehouses #488

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,15 @@
Connection,
ConnectionState,
DEFAULT_QUERY_COMMENT,
Identifier,
LazyHandle,
)
from dbt.events.types import (
NewConnection,
ConnectionReused,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.events import AdapterLogger
from dbt.events.contextvars import get_node_info
from dbt.events.functions import fire_event
Expand Down Expand Up @@ -111,6 +118,10 @@ class DatabricksCredentials(Credentials):
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

# Named compute resources specified in the profile. Used for
# creating a connection when a model specifies a compute resource.
compute: Optional[Dict[str, Any]] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False
Expand Down Expand Up @@ -741,6 +752,50 @@ def exception_handler(self, sql: str) -> Iterator[None]:
else:
raise dbt.exceptions.DbtRuntimeError(str(exc)) from exc

# override/overload
def set_connection_name(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
"""Called by 'acquire_connection' in DatabricksAdapter, which is called by
'connection_named', called by 'connection_for(node)'.
Creates a connection for this thread if one doesn't already
exist, and will rename an existing connection."""

conn_name: str = "master" if name is None else name

# Get a connection for this thread
conn = self.get_if_exists()

if conn and conn.name == conn_name and conn.state == "open":
# Found a connection and nothing to do, so just return it
return conn

if conn is None:
# Create a new connection
conn = Connection(
type=Identifier(self.TYPE),
name=conn_name,
state=ConnectionState.INIT,
transaction_open=False,
handle=None,
credentials=self.profile.credentials,
)
conn.handle = LazyHandle(self.get_open_for_model(node))
# Add the connection to thread_connections for this thread
self.set_thread_connection(conn)
fire_event(
NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info())
)
else: # existing connection either wasn't open or didn't have the right name
if conn.state != "open":
conn.handle = LazyHandle(self.get_open_for_model(node))
if conn.name != conn_name:
orig_conn_name: str = conn.name or ""
conn.name = conn_name
fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name))

return conn

def add_query(
self,
sql: str,
Expand Down Expand Up @@ -861,8 +916,29 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No
),
)

@classmethod
def get_open_for_model(
cls, node: Optional[ResultNode] = None
) -> Callable[[Connection], Connection]:
# If there is no node we can simply return the exsting class method open.
# If there is a node create a closure that will call cls._open with the node.
if not node:
return cls.open

def open_for_model(connection: Connection) -> Connection:
return cls._open(connection, node)

return open_for_model

@classmethod
def open(cls, connection: Connection) -> Connection:
# Simply call _open with no ResultNode argument.
# Because this is an overridden method we can't just add
# a ResultNode parameter to open.
return cls._open(connection)

@classmethod
def _open(cls, connection: Connection, node: Optional[ResultNode] = None) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection
Expand All @@ -885,12 +961,16 @@ def open(cls, connection: Connection) -> Connection:
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

# If a model specifies a compute resource the http path
# may be different than the http_path property of creds.
http_path = _get_http_path(node, creds)

def connect() -> DatabricksSQLConnectionWrapper:
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn: DatabricksSQLConnection = dbsql.connect(
server_hostname=creds.host,
http_path=creds.http_path,
http_path=http_path,
credentials_provider=cls.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
Expand Down Expand Up @@ -1028,3 +1108,31 @@ def _get_update_error_msg(host: str, headers: dict, pipeline_id: str, update_id:
msg = error_events[0].get("message", "")

return msg


def _get_compute_name(node: Optional[ResultNode]) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
compute_name = None
if node and node.config:
compute_name = node.config.get("databricks_compute", None)
return compute_name


def _get_http_path(node: Optional[ResultNode], creds: DatabricksCredentials) -> Optional[str]:
# Get the http path of the compute resource specified in the node's config.
# If none is specified return the default path from creds.
compute_name = _get_compute_name(node)
if not node or not compute_name:
return creds.http_path

http_path = None
if creds.compute:
http_path = creds.compute.get(compute_name, {}).get("http_path", None)

if not http_path:
raise dbt.exceptions.DbtRuntimeError(
f"Compute resource {compute_name} does not exist, relation: {node.relation_name}"
)

return http_path
20 changes: 20 additions & 0 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER, empty_table
from dbt.contracts.connection import AdapterResponse, Connection
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.relation import RelationType
import dbt.exceptions
from dbt.events import AdapterLogger
Expand Down Expand Up @@ -118,6 +119,25 @@ class DatabricksAdapter(SparkAdapter):
}
)

# override/overload
def acquire_connection(
self, name: Optional[str] = None, node: Optional[ResultNode] = None
) -> Connection:
return self.connections.set_connection_name(name, node)

# override
@contextmanager
def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]:
try:
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
self.acquire_connection(name, node)
yield
finally:
self.release_connection()
if self.connections.query_header is not None:
self.connections.query_header.reset()

@available.parse(lambda *a, **k: 0)
def compare_dbr_version(self, major: int, minor: int) -> int:
"""
Expand Down
50 changes: 50 additions & 0 deletions tests/functional/adapter/warehouse_per_model/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
source = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""

target = """
{{config(materialized='table', databricks_compute='alternate_warehouse')}}

select * from {{ ref('source') }}
"""

target2 = """
{{config(materialized='table')}}

select * from {{ ref('source') }}
"""

target3 = """
{{config(materialized='table')}}

select * from {{ ref('source') }}
"""

model_schema = """
version: 2

models:
- name: target
columns:
- name: id
- name: name
- name: date
- name: target2
config:
databricks_compute: alternate_warehouse
columns:
- name: id
- name: name
- name: date
- name: target3
columns:
- name: id
- name: name
- name: date
"""

expected_target = """id,name,date
1,Alice,2022-01-01
2,Bob,2022-01-02
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
from dbt.tests import util
from tests.functional.adapter.warehouse_per_model import fixtures


class BaseWarehousePerModel:
args_formatter = ""

@pytest.fixture(scope="class")
def seeds(self):
return {
"source.csv": fixtures.source,
}

@pytest.fixture(scope="class")
def models(self):
d = dict()
d["target4.sql"] = fixtures.target3
return {
"target.sql": fixtures.target,
"target2.sql": fixtures.target2,
"target3.sql": fixtures.target3,
"schema.yml": fixtures.model_schema,
"special": d,
}


class BaseSpecifyingCompute(BaseWarehousePerModel):
"""Base class for testing various ways to specify a warehouse."""

def test_wpm(self, project):
util.run_dbt(["seed"])
models = project.test_config.get("model_names")
for model_name in models:
# Since the profile doesn't define a compute resource named 'alternate_warehouse'
# we should fail with an error if the warehouse specified for the model is
# correctly handled.
res = util.run_dbt(["run", "--select", model_name], expect_pass=False)
msg = res.results[0].message
assert "Compute resource alternate_warehouse does not exist" in msg
assert model_name in msg


class TestSpecifyingInConfigBlock(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target"]}


class TestSpecifyingInSchemaYml(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target2"]}


class TestSpecifyingForProjectModels(BaseSpecifyingCompute):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+databricks_compute": "alternate_warehouse",
}
}

@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target3"]}


class TestSpecifyingForProjectModelsInFolder(BaseSpecifyingCompute):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know if we can specify a compute to use with models of a particular tag? This came up in a customer call where they would want to tag certain models as heavy_compute for example.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading the dbt docs on tags, I don't think that would work, which is probably fine. I think having the named compute approach gets us 95% of the way to what it would be if they could target compute to tags.

@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"test": {
"special": {
"+databricks_compute": "alternate_warehouse",
},
},
}
}

@pytest.fixture(scope="class")
def test_config(self):
return {"model_names": ["target4"]}


class TestWarehousePerModel(BaseWarehousePerModel):
@pytest.fixture(scope="class")
def profiles_config_update(self, dbt_profile_target):
outputs = {"default": dbt_profile_target}
outputs["default"]["compute"] = {
"alternate_warehouse": {"http_path": dbt_profile_target["http_path"]}
}
return {"test": {"outputs": outputs, "target": "default"}}

def test_wpm(self, project):
util.run_dbt(["seed"])
util.run_dbt(["run", "--select", "target"])
util.check_relations_equal(project.adapter, ["target", "source"])
Loading
Loading