Skip to content

Commit

Permalink
Clean DatabricksHookAsync
Browse files Browse the repository at this point in the history
Since apache/airflow#18339 was merged and released in 2.0.2 of Databricks provider - https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/commits.html#id1

We don't need the old code anymore.
  • Loading branch information
kaxil committed Nov 16, 2021
1 parent 90b1497 commit e41bf96
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 37 deletions.
33 changes: 4 additions & 29 deletions astronomer_operators/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@


class DatabricksHookAsync(DatabricksHook):
def __init__(
self,
databricks_conn_id: str = DEFAULT_CONN_NAME,
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
) -> None:
self.databricks_conn_id = databricks_conn_id
self.databricks_conn = None # To be set asynchronously in create_hook()
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
raise ValueError("Retry limit must be greater than equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay

async def get_run_state_async(self, run_id: str) -> RunState:
"""
Retrieves run state of the run using an asyncronous api call.
Expand Down Expand Up @@ -64,6 +49,10 @@ async def _do_api_call_async(self, endpoint_info, json):
headers = USER_AGENT_HEADER
attempt_num = 1

self.databricks_conn = await sync_to_async(self.get_connection)(
self.databricks_conn_id
)

if "token" in self.databricks_conn.extra_dejson:
self.log.info("Using token auth. ")
auth = self.databricks_conn.extra_dejson["token"]
Expand Down Expand Up @@ -141,17 +130,3 @@ def _retryable_error_async(self, exception) -> bool:
:rtype: bool
"""
return exception.status >= 500


async def create_hook():
"""
Initializes a new DatabricksHookAsync then sets its databricks_conn
field asynchronously.
:return: a new async Databricks hook
:rtype: DataBricksHookAsync()
"""
self = DatabricksHookAsync()
self.databricks_conn = await sync_to_async(self.get_connection)(
self.databricks_conn_id
)
return self
16 changes: 8 additions & 8 deletions tests/astronomer_operators/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
SUBMIT_RUN_ENDPOINT,
)

from astronomer_operators.hooks.databricks import create_hook
from astronomer_operators.hooks.databricks import DatabricksHookAsync

TASK_ID = "databricks_check"
CONN_ID = "unit_test_conn_id"
Expand All @@ -29,7 +29,7 @@ async def test_databricks_hook_get_run_state(mocked_response):
is in a PENDING state (i.e. "RUNNING") and after it reaches a TERMINATED
state (i.e. "SUCCESS").
"""
hook = await create_hook()
hook = DatabricksHookAsync()
# Mock response while job is running
mocked_response.return_value = {
"state": {
Expand Down Expand Up @@ -71,7 +71,7 @@ async def test_do_api_call_async_get_basic_auth(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -97,7 +97,7 @@ async def test_do_api_call_async_get_auth_token(caplog, aioresponse):
and basic auth.
"""
caplog.set_level(logging.INFO)
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn.extra = json.dumps({"token": "test_token"})
params = {"run_id": RUN_ID}

Expand All @@ -118,7 +118,7 @@ async def test_do_api_call_async_non_retryable_error(aioresponse):
Asserts that the Databricks hook will throw an exception
when a non-retryable error is returned by the API.
"""
hook = await create_hook()
hook = DatabricksHookAsync()

hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
Expand All @@ -141,7 +141,7 @@ async def test_do_api_call_async_retryable_error(aioresponse):
Asserts that the Databricks hook will attempt another API call as many
times as the retry_limit when a retryable error is returned by the API.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
params = {"run_id": RUN_ID}
Expand All @@ -166,7 +166,7 @@ async def test_do_api_call_async_post(aioresponse):
"""
Asserts that the Databricks hook makes a POST call as expected.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand All @@ -191,7 +191,7 @@ async def test_do_api_call_async_unknown_method():
Asserts that the Databricks hook throws an exception when it attempts to
make an API call using a non-existent method.
"""
hook = await create_hook()
hook = DatabricksHookAsync()
hook.databricks_conn.login = LOGIN
hook.databricks_conn.password = PASSWORD
json = {
Expand Down

0 comments on commit e41bf96

Please sign in to comment.