Skip to content

Commit

Permalink
Add hook_params in BaseSqlOperator (#18718)
Browse files Browse the repository at this point in the history
  • Loading branch information
denimalpaca authored Nov 15, 2021
1 parent 6ef44b6 commit ccb9d04
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
8 changes: 5 additions & 3 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def rotate_fernet_key(self):
if self._extra and self.is_extra_encrypted:
self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()

def get_hook(self):
"""Return hook based on conn_type."""
def get_hook(self, *, hook_kwargs=None):
"""Return hook based on conn_type"""
(
hook_class_name,
conn_id_param,
Expand All @@ -304,7 +304,9 @@ def get_hook(self):
"Could not import %s when discovering %s %s", hook_class_name, hook_name, package_name
)
raise
return hook_class(**{conn_id_param: self.conn_id})
if hook_kwargs is None:
hook_kwargs = {}
return hook_class(**{conn_id_param: self.conn_id}, **hook_kwargs)

def __repr__(self):
return self.conn_id
Expand Down
12 changes: 10 additions & 2 deletions airflow/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,26 @@ class BaseSQLOperator(BaseOperator):
You can custom the behavior by overriding the .get_db_hook() method.
"""

def __init__(self, *, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs):
def __init__(
self,
*,
conn_id: Optional[str] = None,
database: Optional[str] = None,
hook_params: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
self.conn_id = conn_id
self.database = database
self.hook_params = {} if hook_params is None else hook_params

@cached_property
def _hook(self):
"""Get DB Hook based on connection type"""
self.log.debug("Get connection for %s", self.conn_id)
conn = BaseHook.get_connection(self.conn_id)

hook = conn.get_hook()
hook = conn.get_hook(hook_kwargs=self.hook_params)
if not isinstance(hook, DbApiHook):
raise AirflowException(
f'The connection type is not supported by {self.__class__.__name__}. '
Expand Down
23 changes: 23 additions & 0 deletions tests/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ def test_not_allowed_conn_type(self, mock_get_conn):
with pytest.raises(AirflowException, match=r"The connection type is not supported"):
self._operator._hook

def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
self._operator.hook_params = {
'warehouse': 'warehouse',
'database': 'database',
'role': 'role',
'schema': 'schema',
}
assert self._operator._hook.conn_type == 'snowflake'
assert self._operator._hook.warehouse == 'warehouse'
assert self._operator._hook.database == 'database'
assert self._operator._hook.role == 'role'
assert self._operator._hook.schema == 'schema'

def test_sql_operator_hook_params_biguery(self, mock_get_conn):
mock_get_conn.return_value = Connection(
conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
)
self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
assert self._operator._hook.conn_type == 'gcpbigquery'
assert self._operator._hook.use_legacy_sql
assert self._operator._hook.location == 'us-east1'


class TestCheckOperator(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit ccb9d04

Please sign in to comment.