From 525b6899c8eea5bf3da207e268c8c92b9359d8d9 Mon Sep 17 00:00:00 2001 From: Ben Cassell Date: Mon, 29 Jan 2024 15:26:28 -0800 Subject: [PATCH] mypy fixes --- dbt/adapters/databricks/auth.py | 10 +++---- dbt/adapters/databricks/connections.py | 27 ++++++++++++------- dbt/adapters/databricks/python_submissions.py | 8 +++--- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/dbt/adapters/databricks/auth.py b/dbt/adapters/databricks/auth.py index cb8f6dcd..c1975ceb 100644 --- a/dbt/adapters/databricks/auth.py +++ b/dbt/adapters/databricks/auth.py @@ -1,5 +1,5 @@ from typing import Any, Dict, Optional -from databricks.sdk.oauth import ClientCredentials, Token, TokenSource +from databricks.sdk.oauth import ClientCredentials, Token from databricks.sdk.core import CredentialsProvider, HeaderFactory, Config, credentials_provider @@ -16,12 +16,12 @@ def as_dict(self) -> dict: return {"token": self._token} @staticmethod - def from_dict(raw: Optional[dict]) -> CredentialsProvider: + def from_dict(raw: Optional[dict]) -> Optional[CredentialsProvider]: if not raw: return None return token_auth(raw["token"]) - def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory: + def __call__(self, _: Optional[Config] = None) -> HeaderFactory: static_credentials = {"Authorization": f"Bearer {self._token}"} def inner() -> Dict[str, str]: @@ -31,8 +31,6 @@ def inner() -> Dict[str, str]: class m2m_auth(CredentialsProvider): - _token_source: TokenSource = None - def __init__(self, host: str, client_id: str, client_secret: str) -> None: @credentials_provider("noop", []) def noop_credentials(_: Any): # type: ignore @@ -70,7 +68,7 @@ def from_dict(host: str, client_id: str, client_secret: str, raw: dict) -> Crede c._token_source._token = Token.from_dict(raw["token"]) return c - def __call__(self, *args: tuple, **kwargs: Dict[str, Any]) -> HeaderFactory: + def __call__(self, _: Optional[Config] = None) -> HeaderFactory: def inner() -> Dict[str, str]: token = self._token_source.token() return {"Authorization": f"{token.token_type} {token.access_token}"} diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 94a59b7c..4f1a961a 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -58,7 +58,7 @@ from dbt.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus from dbt.utils import DECIMALS, cast_to_str -from databricks import sql as dbsql +import databricks.sql as dbsql from databricks.sql.client import ( Connection as DatabricksSQLConnection, Cursor as DatabricksSQLCursor, @@ -82,6 +82,9 @@ logger = AdapterLogger("Databricks") +TCredentialProvider = Union[CredentialsProvider, SessionCredentials] + + class DbtCoreHandler(logging.Handler): def __init__(self, level: Union[str, int], dbt_logger: AdapterLogger): super().__init__(level=level) @@ -344,15 +347,17 @@ def extract_cluster_id(cls, http_path: str) -> Optional[str]: def cluster_id(self) -> Optional[str]: return self.extract_cluster_id(self.http_path) # type: ignore[arg-type] - def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider: + def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentialProvider: self.validate_creds() host: str = self.host or "" if self._credentials_provider: - return self._provider_from_dict() + return self._provider_from_dict() # type: ignore if in_provider: - self._credentials_provider = in_provider.as_dict() + if isinstance(in_provider, m2m_auth) or isinstance(in_provider, token_auth): + self._credentials_provider = in_provider.as_dict() return in_provider + provider: TCredentialProvider # dbt will spin up multiple threads. This has to be sync. So lock here self._lock.acquire() try: @@ -373,7 +378,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider: oauth_client = OAuthClient( host=host, client_id=self.client_id if self.client_id else CLIENT_ID, - client_secret=None, + client_secret="", redirect_url=REDIRECT_URL, scopes=SCOPES, ) @@ -416,7 +421,7 @@ def authenticate(self, in_provider: CredentialsProvider) -> CredentialsProvider: finally: self._lock.release() - def _provider_from_dict(self) -> CredentialsProvider: + def _provider_from_dict(self) -> Optional[TCredentialProvider]: if self.token: return token_auth.from_dict(self._credentials_provider) @@ -429,14 +434,16 @@ def _provider_from_dict(self) -> CredentialsProvider: ) oauth_client = OAuthClient( - host=self.host, + host=self.host or "", client_id=CLIENT_ID, - client_secret=None, + client_secret="", redirect_url=REDIRECT_URL, scopes=SCOPES, ) - return SessionCredentials.from_dict(client=oauth_client, raw=self._credentials_provider) + return SessionCredentials.from_dict( + client=oauth_client, raw=self._credentials_provider or {"token": {}} + ) class DatabricksSQLConnectionWrapper: @@ -844,7 +851,7 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None: class DatabricksConnectionManager(SparkConnectionManager): TYPE: str = "databricks" - credentials_provider: CredentialsProvider = None + credentials_provider: Optional[TCredentialProvider] = None def __init__(self, profile: AdapterRequiredConfig) -> None: super().__init__(profile) diff --git a/dbt/adapters/databricks/python_submissions.py b/dbt/adapters/databricks/python_submissions.py index c553b182..3281542b 100644 --- a/dbt/adapters/databricks/python_submissions.py +++ b/dbt/adapters/databricks/python_submissions.py @@ -3,7 +3,7 @@ from requests import Session from dbt.adapters.databricks.__version__ import version -from dbt.adapters.databricks.connections import DatabricksCredentials +from dbt.adapters.databricks.connections import DatabricksCredentials, TCredentialProvider from dbt.adapters.databricks import utils import base64 @@ -16,7 +16,7 @@ import dbt.exceptions from dbt.adapters.base import PythonJobHelper -from databricks.sdk.core import CredentialsProvider +from databricks.sdk.core import Config from requests.adapters import HTTPAdapter from dbt.adapters.databricks.connections import BearerAuth @@ -442,7 +442,7 @@ def submit(self, compiled_code: str) -> None: class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper): credentials: DatabricksCredentials # type: ignore[assignment] - _credentials_provider: CredentialsProvider = None + _credentials_provider: Optional[TCredentialProvider] = None def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: super().__init__( @@ -463,7 +463,7 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No connection_parameters.pop("http_headers", {}) ) self._credentials_provider = credentials.authenticate(self._credentials_provider) - header_factory = self._credentials_provider() + header_factory = self._credentials_provider(Config()) self.session.auth = BearerAuth(header_factory) self.extra_headers.update({"User-Agent": user_agent, **http_headers})