
Add Snowflake DQ Operators #17741

Merged · 12 commits · Sep 9, 2021
284 changes: 283 additions & 1 deletion airflow/providers/snowflake/operators/snowflake.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional
from typing import Any, Optional, SupportsAbs

from airflow.models import BaseOperator
from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook


@@ -125,3 +126,284 @@ def execute(self, context: Any) -> None:

if self.do_xcom_push:
return execution_info


class _SnowflakeDbHookMixin:
Contributor:

Why a class and not just a method inside the SnowflakeCheckOperator?

Member:

Good point

Contributor (author):

Trying to be more DRY. Originally I tried making the Check classes inherit from just the SnowflakeOperator, but it resulted in an error in the constructors.

Contributor:

DRY philosophy is great. However, my point was that you don't need a class at all, and you can get rid of the class's extra layer of abstraction. My comment is inspired by https://www.youtube.com/watch?v=o9pEzgHorH0

Using the same class in the SnowflakeOperator too has been quite an improvement, but I think you could still have just a function in the file and call it from the operators. Anyway, it's not a big deal and the PR is already approved, so up to you.

Member:

I watched the presentation :). Entertaining, and I mostly agree with it (not everything :)

But yeah, in this case using a mixin is definitely over the top.

Member:

Would you mind changing that @denimalpaca ?

Contributor (author):

Will have this fixed Tuesday. It looks like, with the way BaseSQLOperator works, even the function without a class might not be necessary.

def get_db_hook(self) -> SnowflakeHook:
"""
Create and return SnowflakeHook.
:return: a SnowflakeHook instance.
mik-laj marked this conversation as resolved.
Show resolved Hide resolved
:rtype: SnowflakeHook
"""
return SnowflakeHook(
snowflake_conn_id=self.snowflake_conn_id,
warehouse=self.warehouse,
database=self.database,
role=self.role,
schema=self.schema,
authenticator=self.authenticator,
session_parameters=self.session_parameters,
)
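The alternative the reviewers discuss above, a plain module-level function instead of a mixin, could look roughly like the sketch below. The names `_get_snowflake_hook` and the stub `SnowflakeHook` are hypothetical; the stub stands in for the real `airflow.providers.snowflake.hooks.snowflake.SnowflakeHook` so the sketch is self-contained.

```python
from types import SimpleNamespace


class SnowflakeHook:
    """Minimal stand-in for Airflow's SnowflakeHook, for illustration only."""

    def __init__(self, **conn_kwargs):
        self.conn_kwargs = conn_kwargs


def _get_snowflake_hook(operator):
    """Build a SnowflakeHook from any operator exposing the connection attributes."""
    return SnowflakeHook(
        snowflake_conn_id=operator.snowflake_conn_id,
        warehouse=operator.warehouse,
        database=operator.database,
        role=operator.role,
        schema=operator.schema,
        authenticator=operator.authenticator,
        session_parameters=operator.session_parameters,
    )


# Any object with the right attributes works; no extra base class is needed.
op = SimpleNamespace(
    snowflake_conn_id="snowflake_default",
    warehouse=None,
    database=None,
    role=None,
    schema=None,
    authenticator=None,
    session_parameters=None,
)
hook = _get_snowflake_hook(op)
```

Duck typing does the work the mixin did: each operator would just call this function from its own `get_db_hook`.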


class SnowflakeCheckOperator(_SnowflakeDbHookMixin, SQLCheckOperator):
"""
Performs a check against Snowflake. The ``SnowflakeCheckOperator`` expects
a SQL query that will return a single row. Each value in that
first row is evaluated using Python ``bool`` casting. If any of the
values is ``False``, the check fails and errors out.

Note that Python bool casting evaluates the following as ``False``:

* ``False``
* ``0``
* Empty string (``""``)
* Empty list (``[]``)
* Empty dictionary or set (``{}``)

Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
the count ``== 0``. You can craft a much more complex query that could,
for instance, check that the table has the same number of rows as
the source table upstream, or that the count of today's partition is
greater than yesterday's partition, or that a set of metrics are less
than 3 standard deviations from the 7-day average.

This operator can be used as a data quality check in your pipeline, and
depending on where you put it in your DAG, you can choose to
stop the critical path, preventing the publication of dubious data, or
run it on the side and receive email alerts without stopping the
progress of the DAG.

:param sql: the sql code to be executed. (templated)
:type sql: Can receive a str representing a sql statement,
a list of str (sql statements), or a reference to a template file.
Template references are recognized by str ending in '.sql'
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
:type snowflake_conn_id: str
:param autocommit: if True, each command is automatically committed.
(default value: True)
:type autocommit: bool
:param parameters: (optional) the parameters to render the SQL query with.
:type parameters: dict or iterable
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:type warehouse: str
:param database: name of database (will overwrite database defined
in connection)
:type database: str
:param schema: name of schema (will overwrite schema defined in
connection)
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:type role: str
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identity provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: dict
"""

template_fields = ('sql',)
template_ext = ('.sql',)
ui_color = '#ededed'

def __init__(
self,
*,
sql: Any,
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(sql=sql, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []
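The truthiness rule described in the docstring above can be sketched in plain Python. This is an illustration only, not Airflow code; `check_row` is a hypothetical helper mirroring the documented behavior.

```python
def check_row(first_row):
    """Fail unless every value in the single returned row is truthy."""
    if not all(bool(value) for value in first_row):
        raise ValueError(f"Test failed.\nQuery returned: {first_row}")
    return True


check_row([42, "ok", [1]])  # passes: all values are truthy
# check_row([42, 0, "ok"]) would raise ValueError, because 0 is falsy
```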


class SnowflakeValueCheckOperator(_SnowflakeDbHookMixin, SQLValueCheckOperator):
"""
Performs a simple value check using sql code.

:param sql: the sql to be executed
:type sql: str
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
:type snowflake_conn_id: str
:param autocommit: if True, each command is automatically committed.
(default value: True)
:type autocommit: bool
:param parameters: (optional) the parameters to render the SQL query with.
:type parameters: dict or iterable
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:type warehouse: str
:param database: name of database (will overwrite database defined
in connection)
:type database: str
:param schema: name of schema (will overwrite schema defined in
connection)
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:type role: str
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identity provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: dict
"""

def __init__(
self,
*,
sql: Any,
pass_value: Any,
tolerance: Any = None,
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(sql=sql, pass_value=pass_value, tolerance=tolerance, **kwargs)
self.snowflake_conn_id = snowflake_conn_id
self.sql = sql
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []
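The ``pass_value``/``tolerance`` semantics this operator inherits from ``SQLValueCheckOperator`` can be sketched as follows. The `value_check` helper is hypothetical and simplified; the real operator also handles string pass values and multi-column rows.

```python
def value_check(result, pass_value, tolerance=None):
    """With a tolerance t, a numeric result r passes when
    pass_value * (1 - t) <= r <= pass_value * (1 + t);
    without a tolerance, r must equal pass_value exactly."""
    if tolerance is not None:
        return pass_value * (1 - tolerance) <= result <= pass_value * (1 + tolerance)
    return result == pass_value


value_check(95, 95)        # True: exact match
value_check(98, 95, 0.05)  # True: within 5% of 95 (90.25..99.75)
value_check(110, 95, 0.05) # False: outside the tolerance band
```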


class SnowflakeIntervalCheckOperator(_SnowflakeDbHookMixin, SQLIntervalCheckOperator):
"""
Checks that the values of metrics given as SQL expressions are within
a certain tolerance of the ones from days_back before.

This operator constructs a query like so::

SELECT {metrics_threshold_dict_key} FROM {table}
WHERE {date_filter_column}=<date>

:param table: the table name
:type table: str
:param days_back: number of days between ds and the ds we want to check
against. Defaults to 7 days
:type days_back: int
:param metrics_thresholds: a dictionary of ratios indexed by metrics, for
example 'COUNT(*)': 1.5 would require a 50 percent or less difference
between the current day and the prior days_back.
:type metrics_thresholds: dict
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
:type snowflake_conn_id: str
:param autocommit: if True, each command is automatically committed.
(default value: True)
:type autocommit: bool
:param parameters: (optional) the parameters to render the SQL query with.
:type parameters: dict or iterable
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
:type warehouse: str
:param database: name of database (will overwrite database defined
in connection)
:type database: str
:param schema: name of schema (will overwrite schema defined in
connection)
:type schema: str
:param role: name of role (will overwrite any role defined in
connection's extra JSON)
:type role: str
:param authenticator: authenticator for Snowflake.
'snowflake' (default) to use the internal Snowflake authenticator
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identify provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
through native Okta.
:type authenticator: str
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:type session_parameters: dict
"""

def __init__(
self,
*,
table: str,
metrics_thresholds: dict,
date_filter_column: str = 'ds',
days_back: SupportsAbs[int] = -7,
snowflake_conn_id: str = 'snowflake_default',
parameters: Optional[dict] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
schema: Optional[str] = None,
authenticator: Optional[str] = None,
session_parameters: Optional[dict] = None,
**kwargs,
) -> None:
super().__init__(
table=table,
metrics_thresholds=metrics_thresholds,
date_filter_column=date_filter_column,
days_back=days_back,
**kwargs,
)
self.snowflake_conn_id = snowflake_conn_id
self.autocommit = autocommit
self.do_xcom_push = do_xcom_push
self.parameters = parameters
self.warehouse = warehouse
self.database = database
self.role = role
self.schema = schema
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids = []
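The interval check compares each metric's value today against the value from ``days_back`` ago and fails when the ratio exceeds the threshold. A rough sketch of the default 'max_over_min' ratio rule, using a hypothetical and simplified `interval_check` helper (the real operator also guards against NULL and zero reference values):

```python
def interval_check(current, reference, threshold):
    """Pass when max(current, reference) / min(current, reference) < threshold."""
    if current == reference:
        return True
    ratio = max(current, reference) / min(current, reference)
    return ratio < threshold


interval_check(100, 80, 1.5)  # True: ratio 1.25 is under the 1.5 threshold
interval_check(200, 80, 1.5)  # False: ratio 2.5 breaches the threshold
```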
30 changes: 29 additions & 1 deletion tests/providers/snowflake/operators/test_snowflake.py
@@ -19,8 +19,15 @@
import unittest
from unittest import mock

import pytest

from airflow.models.dag import DAG
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
from airflow.providers.snowflake.operators.snowflake import (
SnowflakeCheckOperator,
SnowflakeIntervalCheckOperator,
SnowflakeOperator,
SnowflakeValueCheckOperator,
)
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2015, 1, 1)
@@ -48,3 +55,24 @@ def test_snowflake_operator(self, mock_get_hook):
operator = SnowflakeOperator(task_id='basic_snowflake', sql=sql, dag=self.dag, do_xcom_push=False)
# do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works)
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


@pytest.mark.parametrize(
"operator_class, kwargs",
[
(SnowflakeCheckOperator, dict(sql='Select * from test_table')),
(SnowflakeValueCheckOperator, dict(sql='Select * from test_table', pass_value=95)),
(SnowflakeIntervalCheckOperator, dict(table='test-table-id', metrics_thresholds={'COUNT(*)': 1.5})),
],
)
class TestSnowflakeCheckOperators:
@mock.patch("airflow.providers.snowflake.operators.snowflake._SnowflakeDbHookMixin.get_db_hook")
def test_get_db_hook(
self,
mock_get_db_hook,
operator_class,
kwargs,
):
operator = operator_class(task_id='snowflake_check', snowflake_conn_id='snowflake_default', **kwargs)
operator.get_db_hook()
mock_get_db_hook.assert_called_once()