Skip to content

Commit

Permalink
Apply connection retry refactor, add defaults with exponential backoff (
Browse files Browse the repository at this point in the history
#137)

### Description

Applies connection retry refactor, add defaults with exponential backoff as per an item from #127.

- Refactors retry logic to use `retry_connection` from core.
- Adds consistent defaults with other adpaters:
  - if `connect_retries` and `connect_timeout` are not specified: 1 retry after 1s
  - if `connect_retries` is specified but `connect_timeout` is not, it will use exponential backoff
  • Loading branch information
ueshin authored Dec 21, 2022
1 parent c0cedc9 commit 23725ae
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 58 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Features
- Support Python 3.11 ([#233](https://github.com/databricks/dbt-databricks/pull/233))
- Support `incremental_predicates` ([#161](https://github.com/databricks/dbt-databricks/pull/161))
- Apply connection retry refactor, add defaults with exponential backoff ([#137](https://github.com/databricks/dbt-databricks/pull/137))

## dbt-databricks 1.3.3 (Release TBD)

Expand Down
94 changes: 36 additions & 58 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import dbt.exceptions
from dbt.adapters.base import Credentials
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.spark.connections import SparkConnectionManager
from dbt.clients import agate_helper
from dbt.contracts.connection import (
AdapterResponse,
Expand All @@ -39,14 +39,14 @@
from dbt.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus
from dbt.utils import DECIMALS, cast_to_str

from dbt.adapters.spark.connections import SparkConnectionManager, _is_retryable_error

from databricks import sql as dbsql
from databricks.sql.client import (
Connection as DatabricksSQLConnection,
Cursor as DatabricksSQLCursor,
)
from databricks.sql.exc import Error as DBSQLError
from databricks.sql.exc import Error

from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.utils import redact_credentials

logger = AdapterLogger("Databricks")
Expand All @@ -68,8 +68,8 @@ class DatabricksCredentials(Credentials):
session_properties: Optional[Dict[str, Any]] = None
connection_parameters: Optional[Dict[str, Any]] = None

connect_retries: int = 0
connect_timeout: int = 10
connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False

_ALIASES = {
Expand Down Expand Up @@ -255,14 +255,14 @@ def cancel(self) -> None:
for cursor in cursors:
try:
cursor.cancel()
except DBSQLError as exc:
except Error as exc:
logger.debug("Exception while cancelling query: {}".format(exc))
_log_dbsql_errors(exc)

def close(self) -> None:
try:
self._conn.close()
except DBSQLError as exc:
except Error as exc:
logger.debug("Exception while closing connection: {}".format(exc))
_log_dbsql_errors(exc)

Expand Down Expand Up @@ -305,14 +305,14 @@ def __init__(self, cursor: DatabricksSQLCursor):
def cancel(self) -> None:
try:
self._cursor.cancel()
except DBSQLError as exc:
except Error as exc:
logger.debug("Exception while cancelling query: {}".format(exc))
_log_dbsql_errors(exc)

def close(self) -> None:
try:
self._cursor.close()
except DBSQLError as exc:
except Error as exc:
logger.debug("Exception while closing cursor: {}".format(exc))
_log_dbsql_errors(exc)

Expand Down Expand Up @@ -421,7 +421,7 @@ def exception_handler(self, sql: str) -> Iterator[None]:
try:
yield

except DBSQLError as exc:
except Error as exc:
logger.debug(f"Error while running:\n{log_sql}")
_log_dbsql_errors(exc)
raise dbt.exceptions.RuntimeException(str(exc)) from exc
Expand Down Expand Up @@ -473,7 +473,7 @@ def add_query(
)

return connection, cursor
except DBSQLError:
except Error:
if cursor is not None:
cursor.close()
cursor = None
Expand Down Expand Up @@ -546,6 +546,8 @@ def open(cls, connection: Connection) -> Connection:
return connection

creds: DatabricksCredentials = connection.credentials
timeout = creds.connect_timeout

creds.validate_creds()

user_agent_entry = f"dbt-databricks/{__version__}"
Expand All @@ -560,9 +562,7 @@ def open(cls, connection: Connection) -> Connection:
creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
)

exc: Optional[Exception] = None

for i in range(1 + creds.connect_retries):
def connect() -> DatabricksSQLConnectionWrapper:
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn: DatabricksSQLConnection = dbsql.connect(
Expand All @@ -576,53 +576,31 @@ def open(cls, connection: Connection) -> Connection:
_user_agent_entry=user_agent_entry,
**connection_parameters,
)
handle = DatabricksSQLConnectionWrapper(
conn, is_cluster=creds.cluster_id is not None
)
break
except Exception as e:
exc = e
if isinstance(e, EOFError):
# The user almost certainly has invalid credentials.
# Perhaps a token expired, or something
msg = "Failed to connect"
if creds.token is not None:
msg += ", is your token valid?"
raise dbt.exceptions.FailedToConnectException(msg) from e
retryable_message = _is_retryable_error(e)
if retryable_message and creds.connect_retries > 0:
msg = (
f"Warning: {retryable_message}\n\tRetrying in "
f"{creds.connect_timeout} seconds "
f"({i} of {creds.connect_retries})"
)
logger.warning(msg)
time.sleep(creds.connect_timeout)
elif creds.retry_all and creds.connect_retries > 0:
msg = (
f"Warning: {getattr(exc, 'message', 'No message')}, "
f"retrying due to 'retry_all' configuration "
f"set to true.\n\tRetrying in "
f"{creds.connect_timeout} seconds "
f"({i} of {creds.connect_retries})"
)
logger.warning(msg)
time.sleep(creds.connect_timeout)
else:
logger.debug(f"failed to connect: {exc}")
_log_dbsql_errors(exc)
raise dbt.exceptions.FailedToConnectException("failed to connect") from e
else:
assert exc is not None
raise exc
return DatabricksSQLConnectionWrapper(conn, is_cluster=creds.cluster_id is not None)
except Error as exc:
_log_dbsql_errors(exc)
raise

connection.handle = handle
connection.state = ConnectionState.OPEN
return connection
def exponential_backoff(attempt: int) -> int:
return attempt * attempt

retryable_exceptions = []
# this option is for backwards compatibility
if creds.retry_all:
retryable_exceptions = [Error]

return cls.retry_connection(
connection,
connect=connect,
logger=logger,
retryable_exceptions=retryable_exceptions,
retry_limit=creds.connect_retries,
retry_timeout=(timeout if timeout is not None else exponential_backoff),
)


def _log_dbsql_errors(exc: Exception) -> None:
if isinstance(exc, DBSQLError):
if isinstance(exc, Error):
logger.debug(f"{type(exc)}: {exc}")
for key, value in sorted(exc.context.items()):
logger.debug(f"{key}: {value}")

0 comments on commit 23725ae

Please sign in to comment.