Skip to content

Commit

Permalink
Fixes a few small issues in the databricks trigger / hook: (#12)
Browse files Browse the repository at this point in the history
- Enables trigger logging
- Updates HTTPException to ClientResponseError in databricks hook
- Sets databricks connection asynchronously in hook
  • Loading branch information
OlympuJupiter committed Sep 16, 2021
1 parent 3cd3900 commit 89c4fa1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 29 deletions.
1 change: 0 additions & 1 deletion astronomer_operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ async def run(self):
)
raise AirflowException(error_message)
else:
# TODO: Figure out logging for trigger
self.log.info("%s in run state: %s", self.task_id, run_state)
self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
await asyncio.sleep(self.polling_period_seconds)
Expand Down
42 changes: 38 additions & 4 deletions astronomer_operators/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,35 @@
import base64

import aiohttp
from aiohttp.web_exceptions import HTTPException
from aiohttp import ClientResponseError
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
GET_RUN_ENDPOINT,
USER_AGENT_HEADER,
DatabricksHook,
RunState,
)
from asgiref.sync import sync_to_async

DEFAULT_CONN_NAME = "databricks_default"


class DatabricksHookAsync(DatabricksHook):
def __init__(
    self,
    databricks_conn_id: str = DEFAULT_CONN_NAME,
    timeout_seconds: int = 180,
    retry_limit: int = 3,
    retry_delay: float = 1.0,
) -> None:
    """
    Set up the async Databricks hook.

    Unlike the parent hook, the Databricks connection is NOT fetched here:
    ``databricks_conn`` starts as ``None`` and is populated asynchronously
    by ``create_hook()``, so constructing the hook never blocks the event
    loop.

    :param databricks_conn_id: Airflow connection id to use.
    :param timeout_seconds: per-request timeout, in seconds.
    :param retry_limit: how many times to attempt an API call (must be >= 1).
    :param retry_delay: seconds to wait between attempts.
    :raises ValueError: if ``retry_limit`` is below 1.
    """
    # Guard clause: reject a nonsensical retry budget up front.
    if retry_limit < 1:
        raise ValueError("Retry limit must be greater than equal to 1")
    self.databricks_conn_id = databricks_conn_id
    # Filled in later via create_hook(); kept unset so __init__ stays sync.
    self.databricks_conn = None
    self.timeout_seconds = timeout_seconds
    self.retry_limit = retry_limit
    self.retry_delay = retry_delay

async def get_run_state_async(self, run_id: str) -> RunState:
"""
Retrieves run state of the run using an asynchronous api call.
Expand Down Expand Up @@ -88,12 +106,12 @@ async def _do_api_call_async(self, endpoint_info, json):
)
response.raise_for_status()
return await response.json()
except HTTPException as e:
except ClientResponseError as e:
if not self._retryable_error_async(e):
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(
f"Response: {e}, Status Code: {e.status_code}"
f"Response: {e.message}, Status Code: {e.status}"
)
self._log_request_error(attempt_num, e)

Expand All @@ -119,5 +137,21 @@ def _retryable_error_async(self, exception) -> bool:
- anything with a status code >= 500
Most retryable errors are covered by status code >= 500.
:return: if the status is retryable
:rtype: bool
"""
return exception.status_code >= 500
return exception.status >= 500


async def create_hook():
    """
    Initialize a new DatabricksHookAsync, then set its ``databricks_conn``
    field asynchronously.

    ``get_connection`` is a synchronous Airflow call, so it is wrapped with
    ``sync_to_async`` here to avoid blocking the running event loop while
    the connection is looked up.

    :return: a new async Databricks hook with its connection populated
    :rtype: DatabricksHookAsync
    """
    # NOTE: the local was previously named ``self`` even though this is a
    # module-level factory, not a method; ``hook`` avoids that confusion.
    hook = DatabricksHookAsync()
    hook.databricks_conn = await sync_to_async(hook.get_connection)(
        hook.databricks_conn_id
    )
    return hook
31 changes: 10 additions & 21 deletions tests/astronomer_operators/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from unittest import mock

import pytest
from aiohttp.web_exceptions import HTTPBadRequest, HTTPInternalServerError
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import (
GET_RUN_ENDPOINT,
SUBMIT_RUN_ENDPOINT,
)

from astronomer_operators.hooks.databricks import DatabricksHookAsync
from astronomer_operators.hooks.databricks import create_hook

TASK_ID = "databricks_check"
CONN_ID = "unit_test_conn_id"
Expand All @@ -30,7 +29,7 @@ async def test_databricks_hook_get_run_state(mocked_response):
is in a PENDING state (i.e. "RUNNING") and after it reaches a TERMINATED
state (i.e. "SUCCESS").
"""
hook = DatabricksHookAsync()
hook = await create_hook()
# Mock response while job is running
mocked_response.return_value = {
"state": {
Expand Down Expand Up @@ -72,7 +71,7 @@ async def test_do_api_call_async_get_basic_auth(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = DatabricksHookAsync()
hook = await create_hook()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -98,7 +97,7 @@ async def test_do_api_call_async_get_auth_token(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = DatabricksHookAsync()
hook = await create_hook()
hook.databricks_conn.extra = json.dumps({"token": "test_token"})
params = {"run_id": RUN_ID}

Expand All @@ -119,15 +118,15 @@ async def test_do_api_call_async_non_retryable_error(aioresponse):
Asserts that the Databricks hook will throw an exception
when a non-retryable error is returned by the API.
"""
hook = DatabricksHookAsync()
hook = await create_hook()

hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}

aioresponse.get(
"https://localhost/api/2.0/jobs/runs/get?run_id=unit_test_run_id",
exception=HTTPBadRequest(),
status=400,
)

with pytest.raises(AirflowException) as exc:
Expand All @@ -142,14 +141,14 @@ async def test_do_api_call_async_retryable_error(aioresponse):
Asserts that the Databricks hook will attempt another API call as many
times as the retry_limit when a retryable error is returned by the API.
"""
hook = DatabricksHookAsync()
hook = await create_hook()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}

aioresponse.get(
"https://localhost/api/2.0/jobs/runs/get?run_id=unit_test_run_id",
exception=HTTPInternalServerError(),
status=500,
repeat=True,
)

Expand All @@ -167,7 +166,7 @@ async def test_do_api_call_async_post(aioresponse):
"""
Asserts that the Databricks hook makes a POST call as expected.
"""
hook = DatabricksHookAsync()
hook = await create_hook()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand All @@ -192,7 +191,7 @@ async def test_do_api_call_async_unknown_method():
Asserts that the Databricks hook throws an exception when it attempts to
make an API call using a non-existent method.
"""
hook = DatabricksHookAsync()
hook = await create_hook()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand All @@ -205,13 +204,3 @@ async def test_do_api_call_async_unknown_method():
await hook._do_api_call_async(("NOPE", "api/2.0/jobs/runs/submit"), json)

assert str(exc.value) == "Unexpected HTTP Method: NOPE"


def test_retryable_error_async():
"""
Asserts that HTTP errors are properly identified as retryable
when the status code >= 500.
"""
hook = DatabricksHookAsync()
assert hook._retryable_error_async(HTTPInternalServerError)
assert not hook._retryable_error_async(HTTPBadRequest)
10 changes: 7 additions & 3 deletions tests/astronomer_operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,16 @@ async def test_databricks_trigger_running(run_state, caplog):
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)

# TriggerEvent was not returned
assert task.done() is False

# # TODO: uncomment these after logging is implemented in Trigger
# assert f"{TASK_ID} in run state: RUNNING" in caplog.text
# assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text
assert (
f"{TASK_ID} in run state: {{'life_cycle_state': 'RUNNING', 'result_state': '', 'state_message': 'In run'}}"
in caplog.text
)
assert f"Sleeping for {POLLING_PERIOD_SECONDS} seconds." in caplog.text

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
Expand Down
8 changes: 8 additions & 0 deletions tests/astronomer_operators/test_external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
TEST_TASK_ID = "external_task_sensor_check"
TEST_RUN_ID = "unit_test_dag_run_id"
TEST_EXT_DAG_ID = "wait_for_me_dag" # DAG the external task sensor is waiting on
TEST_EXT_TASK_ID = "wait_for_me_task" # Task the external task sensor is waiting on
TEST_STATES = ["success", "fail"]
Expand Down Expand Up @@ -224,6 +225,13 @@ async def test_task_state_trigger(session, dag):
Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
reaches an allowed state (i.e. SUCCESS).
"""
dag_run = DagRun(
dag.dag_id, run_type="manual", execution_date=DEFAULT_DATE, run_id=TEST_RUN_ID
)

session.add(dag_run)
session.commit()

external_task = DummyOperator(task_id=TEST_TASK_ID, dag=dag)
instance = TaskInstance(external_task, DEFAULT_DATE)
session.add(instance)
Expand Down

0 comments on commit 89c4fa1

Please sign in to comment.