Skip to content

Commit

Permalink
Clean DatabricksHookAsync (#16)
Browse files Browse the repository at this point in the history
Since apache/airflow#18339 was merged and released in 2.0.2 of Databricks provider - https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/commits.html#id1

We don't need the old code anymore.

* Fix `ExternalTaskSensorAsync` test
  • Loading branch information
kaxil authored Nov 16, 2021
1 parent 90b1497 commit 32a6ae8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 41 deletions.
34 changes: 5 additions & 29 deletions astronomer_operators/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@


class DatabricksHookAsync(DatabricksHook):
def __init__(
self,
databricks_conn_id: str = DEFAULT_CONN_NAME,
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
) -> None:
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = None # To be set asynchronously in create_hook()
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError("Retry limit must be greater than equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay

async def get_run_state_async(self, run_id: str) -> RunState:
"""
Retrieves run state of the run using an asyncronous api call.
Expand Down Expand Up @@ -64,6 +49,11 @@ async def _do_api_call_async(self, endpoint_info, json):
headers = USER_AGENT_HEADER
attempt_num = 1

if not self.databricks_conn:
self.databricks_conn = await sync_to_async(self.get_connection)(
self.databricks_conn_id
)

if "token" in self.databricks_conn.extra_dejson:
self.log.info("Using token auth. ")
auth = self.databricks_conn.extra_dejson["token"]
Expand Down Expand Up @@ -141,17 +131,3 @@ def _retryable_error_async(self, exception) -> bool:
:rtype: bool
"""
return exception.status >= 500


async def create_hook():
"""
Initializes a new DatabricksHookAsync then sets its databricks_conn
field asynchronously.
:return: a new async Databricks hook
:rtype: DataBricksHookAsync()
"""
self = DatabricksHookAsync()
self.databricks_conn = await sync_to_async(self.get_connection)(
self.databricks_conn_id
)
return self
36 changes: 27 additions & 9 deletions tests/astronomer_operators/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
GET_RUN_ENDPOINT,
SUBMIT_RUN_ENDPOINT,
)
from asgiref.sync import sync_to_async

from astronomer_operators.hooks.databricks import create_hook
from astronomer_operators.hooks.databricks import DatabricksHookAsync

TASK_ID = "databricks_check"
CONN_ID = "unit_test_conn_id"
Expand All @@ -29,7 +30,7 @@ async def test_databricks_hook_get_run_state(mocked_response):
is in a PENDING state (i.e. "RUNNING") and after it reaches a TERMINATED
state (i.e. "SUCCESS").
"""
hook = await create_hook()
hook = DatabricksHookAsync()
# Mock response while job is running
mocked_response.return_value = {
"state": {
Expand Down Expand Up @@ -71,7 +72,10 @@ async def test_do_api_call_async_get_basic_auth(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -97,7 +101,10 @@ async def test_do_api_call_async_get_auth_token(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.extra = json.dumps({"token": "test_token"})
params = {"run_id": RUN_ID}

Expand All @@ -118,8 +125,10 @@ async def test_do_api_call_async_non_retryable_error(aioresponse):
Asserts that the Databricks hook will throw an exception
when a non-retryable error is returned by the API.
"""
hook = await create_hook()

hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -141,7 +150,10 @@ async def test_do_api_call_async_retryable_error(aioresponse):
Asserts that the Databricks hook will attempt another API call as many
times as the retry_limit when a retryable error is returned by the API.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -166,7 +178,10 @@ async def test_do_api_call_async_post(aioresponse):
"""
Asserts that the Databricks hook makes a POST call as expected.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand All @@ -191,7 +206,10 @@ async def test_do_api_call_async_unknown_method():
Asserts that the Databricks hook throws an exception when it attempts to
make an API call using a non-existent method.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn = await sync_to_async(hook.get_connection)(
hook.databricks_conn_id
)
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand Down
4 changes: 1 addition & 3 deletions tests/astronomer_operators/test_external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ async def test_dag_state_trigger(session, dag):
reaches an allowed state (i.e. SUCCESS).
"""
dag_run = DagRun(
dag.dag_id,
run_type="manual",
execution_date=DEFAULT_DATE,
dag.dag_id, run_type="manual", execution_date=DEFAULT_DATE, run_id=TEST_RUN_ID
)

session.add(dag_run)
Expand Down

0 comments on commit 32a6ae8

Please sign in to comment.