Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Jan 29, 2024
1 parent a972568 commit 525b689
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
10 changes: 4 additions & 6 deletions dbt/adapters/databricks/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, Dict, Optional
from databricks.sdk.oauth import ClientCredentials, Token, TokenSource
from databricks.sdk.oauth import ClientCredentials, Token
from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider


Expand All @@ -16,12 +16,12 @@ def as_dict(self) -> dict:
return {"token": self._token}

@staticmethod
def from_dict(raw: Optional[dict]) -> CredentialsProvider:
def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]:
if not raw:
return None
return token_auth(raw["token"])

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
static_credentials = {"Authorization": f"Bearer {self._token}"}

def inner() -> Dict[str, str]:
Expand All @@ -31,8 +31,6 @@ def inner() -> Dict[str, str]:


class m2m_auth(CredentialsProvider):
_token_source: TokenSource = None

def __init__(self, host: str, client_id: str, client_secret: str) -> None:
@credentials_provider("noop", [])
def noop_credentials(_: Any): # type: ignore
Expand Down Expand Up @@ -70,7 +68,7 @@ def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> Crede
c._token_source._token = Token.from_dict(raw["token"])
return c

def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory:
def __call__(self, _: Optional[Config] = None) -> HeaderFactory:
def inner() -> Dict[str, str]:
token = self._token_source.token()
return {"Authorization": f"{token.token_type} {token.access_token}"}
Expand Down
27 changes: 17 additions & 10 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from dbt.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus
from dbt.utils import DECIMALS, cast_to_str

from databricks import sql as dbsql
import databricks.sql as dbsql
from databricks.sql.client import (
Connection as DatabricksSQLConnection,
Cursor as DatabricksSQLCursor,
Expand All @@ -82,6 +82,9 @@
logger = AdapterLogger("Databricks")


TCredentialProvider = Union[CredentialsProvider, SessionCredentials]


class DbtCoreHandler(logging.Handler):
def __init__(self, level: Union[str, int], dbt_logger: AdapterLogger):
super().__init__(level=level)
Expand Down Expand Up @@ -344,15 +347,17 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]:
def cluster_id(self) -> Optional[str]:
return self.extract_cluster_id(self.http_path) # type: ignore[arg-type]

def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentialProvider:
self.validate_creds()
host: str = self.host or ""
if self._credentials_provider:
return self._provider_from_dict()
return self._provider_from_dict() # type: ignore
if in_provider:
self._credentials_provider = in_provider.as_dict()
if isinstance(in_provider, m2m_auth) or isinstance(in_provider, token_auth):
self._credentials_provider = in_provider.as_dict()
return in_provider

provider: TCredentialProvider
# dbt will spin up multiple threads. This has to be sync. So lock here
self._lock.acquire()
try:
Expand All @@ -373,7 +378,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
oauth_client = OAuthClient(
host=host,
client_id=self.client_id if self.client_id else CLIENT_ID,
client_secret=None,
client_secret="",
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)
Expand Down Expand Up @@ -416,7 +421,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider:
finally:
self._lock.release()

def _provider_from_dict(self) -> CredentialsProvider:
def _provider_from_dict(self) -> Optional[TCredentialProvider]:
if self.token:
return token_auth.from_dict(self._credentials_provider)

Expand All @@ -429,14 +434,16 @@ def _provider_from_dict(self) -> CredentialsProvider:
)

oauth_client = OAuthClient(
host=self.host,
host=self.host or "",
client_id=CLIENT_ID,
client_secret=None,
client_secret="",
redirect_url=REDIRECT_URL,
scopes=SCOPES,
)

return SessionCredentials.from_dict(client=oauth_client, raw=self._credentials_provider)
return SessionCredentials.from_dict(
client=oauth_client, raw=self._credentials_provider or {"token": {}}
)


class DatabricksSQLConnectionWrapper:
Expand Down Expand Up @@ -844,7 +851,7 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:

class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_provider: CredentialsProvider = None
credentials_provider: Optional[TCredentialProvider] = None

def __init__(self, profile: AdapterRequiredConfig) -> None:
super().__init__(profile)
Expand Down
8 changes: 4 additions & 4 deletions dbt/adapters/databricks/python_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from requests import Session

from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.connections import DatabricksCredentials
from dbt.adapters.databricks.connections import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks import utils

import base64
Expand All @@ -16,7 +16,7 @@
import dbt.exceptions
from dbt.adapters.base import PythonJobHelper

from databricks.sdk.core import CredentialsProvider
from databricks.sdk.core import Config
from requests.adapters import HTTPAdapter
from dbt.adapters.databricks.connections import BearerAuth

Expand Down Expand Up @@ -442,7 +442,7 @@ def submit(self, compiled_code: str) -> None:

class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper):
credentials: DatabricksCredentials # type: ignore[assignment]
_credentials_provider: CredentialsProvider = None
_credentials_provider: Optional[TCredentialProvider] = None

def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
super().__init__(
Expand All @@ -463,7 +463,7 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No
connection_parameters.pop("http_headers", {})
)
self._credentials_provider = credentials.authenticate(self._credentials_provider)
header_factory = self._credentials_provider()
header_factory = self._credentials_provider(Config())
self.session.auth = BearerAuth(header_factory)

self.extra_headers.update({"User-Agent": user_agent, **http_headers})
Expand Down

0 comments on commit 525b689

Please sign in to comment.