Allow Databricks SQL hook to cancel timed out queries (apache#42668)
R7L208 authored and ellisms committed Nov 13, 2024
1 parent 3e40a24 commit 2018c33
Showing 6 changed files with 287 additions and 56 deletions.
32 changes: 32 additions & 0 deletions providers/src/airflow/providers/databricks/exceptions.py
@@ -0,0 +1,32 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state
"""Exceptions used by Databricks Provider."""

from __future__ import annotations

from airflow.exceptions import AirflowException


class DatabricksSqlExecutionError(AirflowException):
    """Raised when there is an error in sql execution."""


class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError):
    """Raised when a sql execution times out."""
41 changes: 39 additions & 2 deletions providers/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,10 +16,12 @@
# under the License.
from __future__ import annotations

import threading
import warnings
from collections import namedtuple
from contextlib import closing
from copy import copy
from datetime import timedelta
from typing import (
    TYPE_CHECKING,
    Any,
@@ -35,8 +37,12 @@

from databricks import sql # type: ignore[attr-defined]

from airflow.exceptions import (
    AirflowException,
    AirflowProviderDeprecationWarning,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

if TYPE_CHECKING:
@@ -49,6 +55,16 @@
T = TypeVar("T")


def create_timeout_thread(cur, execution_timeout: timedelta | None) -> threading.Timer | None:
    if execution_timeout is not None:
        seconds_to_timeout = execution_timeout.total_seconds()
        t = threading.Timer(seconds_to_timeout, cur.connection.cancel)
        t.start()
    else:
        t = None

    return t


class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
    """
    Hook to interact with Databricks SQL.
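
create_timeout_thread arms a threading.Timer that invokes cur.connection.cancel once the timeout elapses, and returns None when no timeout is configured. A standalone sketch of the same pattern, with a stub standing in for the real Databricks connection:

import threading
import time
from datetime import timedelta


class StubConnection:
    """Stand-in for the databricks.sql connection used by the hook."""

    def cancel(self) -> None:
        print("cancel() fired: the in-flight query would be cancelled here")


timeout = timedelta(seconds=1)
timer = threading.Timer(timeout.total_seconds(), StubConnection().cancel)
timer.start()

time.sleep(2)   # simulate a statement that outlives its timeout
timer.cancel()  # harmless here; the timer has already fired
print("timer alive after firing:", timer.is_alive())  # False once the callback ran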
@@ -184,6 +200,7 @@ def run(
        handler: None = ...,
        split_statements: bool = ...,
        return_last: bool = ...,
        execution_timeout: timedelta | None = None,
    ) -> None: ...

    @overload
@@ -195,6 +212,7 @@ def run(
        handler: Callable[[Any], T] = ...,
        split_statements: bool = ...,
        return_last: bool = ...,
        execution_timeout: timedelta | None = None,
    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...

    def run(
@@ -205,6 +223,7 @@ def run(
        handler: Callable[[Any], T] | None = None,
        split_statements: bool = True,
        return_last: bool = True,
        execution_timeout: timedelta | None = None,
    ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
        """
        Run a command or a list of commands.
@@ -224,6 +243,8 @@ def run(
        :param return_last: Whether to return result for only last statement or for all after split
        :param execution_timeout: max time allowed for the execution of this task instance; if the
            statement runs longer, the query is cancelled and ``DatabricksSqlExecutionTimeout`` is raised.
        :return: return only result of the LAST SQL expression if handler was provided unless return_last
            is set to False.
        """
        self.descriptions = []
        if isinstance(sql, str):
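
Note that the cursor, and with it the timer, is created once per statement (the surrounding per-statement loop is elided in this hunk), so with split_statements=True the timeout appears to apply to each statement individually rather than to the batch as a whole. An illustrative call (connection id and SQL are assumptions):

from datetime import timedelta

from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook(databricks_conn_id="databricks_default")  # assumed connection id

# Each statement gets its own 30-second budget, so the call as a whole
# may take up to roughly a minute before both finish or are cancelled.
hook.run(
    "SELECT 1; SELECT 2",
    split_statements=True,
    execution_timeout=timedelta(seconds=30),
)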
@@ -248,7 +269,23 @@ def run(
                self.set_autocommit(conn, autocommit)

                with closing(conn.cursor()) as cur:
                    t = create_timeout_thread(cur, execution_timeout)

                    # TODO: adjust this to make testing easier
                    try:
                        self._run_command(cur, sql_statement, parameters)  # type: ignore[attr-defined]
                    except Exception as e:
                        if t is None or t.is_alive():
                            # The timer never fired (or no timeout was set), so this
                            # is a genuine execution error.
                            raise DatabricksSqlExecutionError(
                                f"Error running SQL statement: {sql_statement}. {str(e)}"
                            )
                        # The timer already fired: the cancellation it triggered is what
                        # aborted the statement.
                        raise DatabricksSqlExecutionTimeout(
                            f"Timeout threshold exceeded for SQL statement: {sql_statement}; the statement was cancelled."
                        )
                    finally:
                        if t is not None:
                            t.cancel()

                    if handler is not None:
                        raw_result = handler(cur)
                        if self.return_tuple:
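
The except branch uses the timer's liveness to tell the two failure modes apart: a timer that is still alive (or absent) means the timeout never fired, so the exception is a genuine execution error; a timer that has already fired means the cancellation it triggered is what aborted the statement. The same pattern in isolation, with hypothetical callables standing in for the hook's internals:

import threading
from typing import Any, Callable


def run_with_timeout(execute: Callable[[], Any], cancel: Callable[[], None], seconds: float | None) -> Any:
    """Run `execute`, cancelling via `cancel` if it takes longer than `seconds`."""
    t = threading.Timer(seconds, cancel) if seconds is not None else None
    if t is not None:
        t.start()
    try:
        return execute()
    except Exception as e:
        if t is None or t.is_alive():
            # Timer still pending (or no timeout set): a real execution error.
            raise RuntimeError(f"execution error: {e}")
        # Timer already fired: the failure came from our own cancellation.
        raise TimeoutError("statement exceeded the timeout and was cancelled")
    finally:
        if t is not None:
            t.cancel()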
@@ -353,3 +353,8 @@ def execute(self, context: Context) -> Any:
        self.log.info("Executing: %s", sql)
        hook = self._get_hook()
        hook.run(sql)

    def on_kill(self) -> None:
        # NB: on_kill isn't required for this operator since query cancelling gets
        # handled in the `DatabricksSqlHook.run()` method which is called in `execute()`
        ...
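
For contrast, an operator whose hook did not cancel inside run() would have to keep a reference to whatever it needs to cancel and do that work in on_kill itself. A purely hypothetical sketch of the bookkeeping this operator avoids (nothing here is Airflow or Databricks API):

class HypotheticalOperator:
    """Illustration only: an operator that must cancel in on_kill itself."""

    def __init__(self, hook) -> None:
        self._hook = hook

    def execute(self) -> None:
        self._hook.run("SELECT 1")

    def on_kill(self) -> None:
        # Without hook-level cancellation, the operator would cancel here.
        self._hook.cancel()  # hypothetical cancel API on the hook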