Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DatabricksOauthProfileMapping profile #1091

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .clickhouse.user_pass import ClickhouseUserPasswordProfileMapping
from .databricks.oauth import DatabricksOauthProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
Expand All @@ -32,6 +33,7 @@
GoogleCloudServiceAccountDictProfileMapping,
GoogleCloudOauthProfileMapping,
DatabricksTokenProfileMapping,
DatabricksOauthProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
Expand Down Expand Up @@ -73,6 +75,7 @@ def get_automatic_profile_mapping(
"GoogleCloudServiceAccountDictProfileMapping",
"GoogleCloudOauthProfileMapping",
"DatabricksTokenProfileMapping",
"DatabricksOauthProfileMapping",
"DbtProfileConfigVars",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
Expand Down
3 changes: 2 additions & 1 deletion cosmos/profiles/databricks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Databricks Airflow connection -> dbt profile mappings"""

from .oauth import DatabricksOauthProfileMapping
from .token import DatabricksTokenProfileMapping

__all__ = ["DatabricksTokenProfileMapping"]
__all__ = ["DatabricksTokenProfileMapping", "DatabricksOauthProfileMapping"]
48 changes: 48 additions & 0 deletions cosmos/profiles/databricks/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Maps Airflow Databricks connections with the client auth to dbt profiles."""

from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class DatabricksOauthProfileMapping(BaseProfileMapping):
    """
    Maps Airflow Databricks connections that use OAuth client credentials to dbt profiles.

    https://docs.getdbt.com/reference/warehouse-setups/databricks-setup
    https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/connections/databricks.html
    """

    airflow_connection_type: str = "databricks"
    dbt_profile_type: str = "databricks"

    # Every one of these must resolve from the connection (or profile args)
    # before this mapping will claim the connection.
    required_fields = [
        "host",
        "schema",
        "client_secret",
        "client_id",
        "http_path",
    ]

    # Fields that must never be written to profiles.yml in plain text;
    # they are referenced through environment variables instead.
    secret_fields = ["client_secret", "client_id"]

    # dbt profile field -> Airflow connection attribute(s) it is read from.
    # A list means "first attribute that resolves wins".
    airflow_param_mapping = {
        "host": "host",
        "schema": "schema",
        "client_id": ["login", "extra.client_id"],
        "client_secret": ["password", "extra.client_secret"],
        "http_path": "extra.http_path",
    }

    @property
    def profile(self) -> dict[str, Any | None]:
        """Generates profile. The client id and client secret are stored in environment variables."""
        profile_dict = dict(self.mapped_params)
        profile_dict.update(self.profile_args)
        profile_dict["auth_type"] = "oauth"
        # Reference the secrets via env-var placeholders rather than embedding
        # their values directly in the rendered profile.
        profile_dict["client_secret"] = self.get_env_var_format("client_secret")
        profile_dict["client_id"] = self.get_env_var_format("client_id")
        return profile_dict
71 changes: 71 additions & 0 deletions tests/profiles/databricks/test_dbr_oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Tests for the databricks OAuth profile."""

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles.databricks import DatabricksOauthProfileMapping


@pytest.fixture()
def mock_databricks_conn():  # type: ignore
    """
    Builds an Airflow Databricks connection with OAuth-style credentials and
    yields it while BaseHook.get_connection is patched to return it.
    """
    connection = Connection(
        conn_id="my_databricks_connection",
        conn_type="databricks",
        host="https://my_host",
        login="my_client_id",
        password="my_client_secret",
        extra='{"http_path": "my_http_path"}',
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=connection):
        yield connection


def test_connection_claiming() -> None:
    """
    Tests that the Databricks OAuth profile mapping claims the correct connection type.

    The mapping should only claim a connection when:
    - conn_type == databricks
    and all of the following are resolvable:
    - schema
    - host
    - http_path
    - client_id
    - client_secret
    """
    potential_values = {
        "conn_type": "databricks",
        "host": "my_host",
        "login": "my_client_id",
        "password": "my_client_secret",
        "extra": '{"http_path": "my_http_path"}',
    }

    # if we're missing any of the connection values, it shouldn't claim
    for key in potential_values:
        values = potential_values.copy()
        del values[key]
        conn = Connection(**values)  # type: ignore

        with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
            profile_mapping = DatabricksOauthProfileMapping(conn, {"schema": "my_schema"})
            assert not profile_mapping.can_claim_connection()

    # also shouldn't claim when the schema profile arg is missing
    conn = Connection(**potential_values)  # type: ignore
    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        profile_mapping = DatabricksOauthProfileMapping(conn, {})
        assert not profile_mapping.can_claim_connection()

    # if we have them all, it should claim
    conn = Connection(**potential_values)  # type: ignore
    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        profile_mapping = DatabricksOauthProfileMapping(conn, {"schema": "my_schema"})
        assert profile_mapping.can_claim_connection()
Loading