Skip to content

Commit

Permalink
Add on_kill() to kill Trino query if the task is killed (#24559)
Browse files Browse the repository at this point in the history
This PR is a follow up to PR #24415.
It adds an on_kill method to the TrinoOperator that kills the running Trino query if the Airflow task is killed.
  • Loading branch information
phanikumv authored Jun 20, 2022
1 parent 05c542d commit 4f4f37c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
2 changes: 2 additions & 0 deletions airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class TrinoHook(DbApiHook):
default_conn_name = 'trino_default'
conn_type = 'trino'
hook_name = 'Trino'
query_id = ''

def get_conn(self) -> Connection:
"""Returns a connection object"""
Expand Down Expand Up @@ -301,6 +302,7 @@ def run(
results = []
for sql_statement in sql:
self._run_command(cur, self._strip_sql(sql_statement), parameters)
self.query_id = cur.stats["queryId"]
if handler is not None:
result = handler(cur)
results.append(result)
Expand Down
17 changes: 17 additions & 0 deletions airflow/providers/trino/operators/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union

from trino.exceptions import TrinoQueryError

from airflow.models import BaseOperator
from airflow.providers.trino.hooks.trino import TrinoHook

Expand Down Expand Up @@ -79,3 +81,18 @@ def execute(self, context: 'Context') -> None:
self.hook.run(
sql=self.sql, autocommit=self.autocommit, parameters=self.parameters, handler=self.handler
)

def on_kill(self) -> None:
    """Ask Trino to cancel this task's query when the task is killed.

    Reads the query id stored on ``self.hook.query_id`` and calls the
    ``system.runtime.kill_query`` procedure for it through the hook.
    The cancellation is best-effort: a ``TrinoQueryError`` raised by the
    kill call is logged and not re-raised.
    """
    hook = self.hook
    # Nothing to cancel unless a Trino hook is attached to the operator.
    if hook is None or not isinstance(hook, TrinoHook):
        return

    # Snapshot the id before calling hook.run(): the hook assigns query_id
    # after each statement it runs, so the kill call itself may replace it.
    query_id = hook.query_id
    if not query_id:
        # The id is still empty, so no statement has been submitted yet;
        # calling kill_query with an empty id would only produce an error.
        self.log.info("No Trino query id recorded; nothing to kill")
        return

    quoted_id = "'" + query_id + "'"
    self.log.info("Stopping query run with queryId - %s", query_id)
    try:
        hook.run(
            sql=f"CALL system.runtime.kill_query(query_id => {quoted_id},message => 'Job "
            f"killed by "
            f"user');",
            handler=list,
        )
    except TrinoQueryError as e:
        # Best-effort: the task is already being torn down, so report the
        # failure instead of raising from on_kill.
        self.log.warning("Failed to kill Trino query (%s): %s", quoted_id, e)
    else:
        # Only claim termination when the kill call actually succeeded.
        self.log.info("Trino query (%s) terminated", quoted_id)
6 changes: 3 additions & 3 deletions tests/system/providers/trino/example_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
trino_create_table = TrinoOperator(
task_id="trino_create_table",
sql=f"""CREATE TABLE {SCHEMA}.{TABLE}(
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}(
cityid bigint,
cityname varchar
)""",
Expand All @@ -60,9 +60,9 @@

trino_multiple_queries = TrinoOperator(
task_id="trino_multiple_queries",
sql=f"""CREATE TABLE {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar);
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose');
CREATE TABLE {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar);
CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""",
handler=list,
)
Expand Down

0 comments on commit 4f4f37c

Please sign in to comment.