Skip to content

Commit

Permalink
Add hook_params in SqlSensor using the latest changes from PR apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kazanzhy authored and Dillon Johnson committed Dec 1, 2021
1 parent f734dcc commit 1cf3291
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
17 changes: 15 additions & 2 deletions airflow/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class SqlSensor(BaseSensorOperator):
:type failure: Optional<Callable[[Any], bool]>
:param fail_on_empty: Explicitly fail on no rows returned.
:type fail_on_empty: bool
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:type hook_params: dict
"""

template_fields: Iterable[str] = ('sql',)
Expand All @@ -58,14 +61,24 @@ class SqlSensor(BaseSensorOperator):
ui_color = '#7c7287'

def __init__(
self, *, conn_id, sql, parameters=None, success=None, failure=None, fail_on_empty=False, **kwargs
self,
*,
conn_id,
sql,
parameters=None,
success=None,
failure=None,
fail_on_empty=False,
hook_params=None,
**kwargs,
):
self.conn_id = conn_id
self.sql = sql
self.parameters = parameters
self.success = success
self.failure = failure
self.fail_on_empty = fail_on_empty
self.hook_params = hook_params
super().__init__(**kwargs)

def _get_hook(self):
Expand All @@ -90,7 +103,7 @@ def _get_hook(self):
f"Connection type ({conn.conn_type}) is not supported by SqlSensor. "
+ f"Supported connection types: {list(allowed_conn_type)}"
)
return conn.get_hook()
return conn.get_hook(hook_kwargs=self.hook_params)

def poke(self, context):
hook = self._get_hook()
Expand Down
12 changes: 12 additions & 0 deletions tests/sensors/test_sql_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,15 @@ def test_sql_sensor_presto(self):
dag=self.dag,
)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

def test_sql_sensor_hook_params(self):
op = SqlSensor(
task_id='sql_sensor_hook_params',
conn_id='google_cloud_default',
sql="SELECT 1",
hook_params={
'delegate_to': 'me',
},
)
hook = op._get_hook()
assert hook.delegate_to == 'me'

0 comments on commit 1cf3291

Please sign in to comment.