From b6ca84701e278667bd62c829f7b1f781d27555fe Mon Sep 17 00:00:00 2001 From: Donal Burns <56016914+Don-Burns@users.noreply.github.com> Date: Thu, 15 Feb 2024 15:09:20 +0000 Subject: [PATCH] Fix SQLThresholdCheckOperator error on falsey vals (#37150) * Fix SQLThresholdCheckOperator error on falsey vals If the user's query returned a "falsey" value e.g. 0 an exception would be falsely raised * fixup! Fix SQLThresholdCheckOperator error on falsey vals --- airflow/providers/common/sql/operators/sql.py | 9 +++++-- .../common/sql/operators/test_sql.py | 25 +++++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index 78f29e3464435..d3f5b5b1f0683 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -1058,8 +1058,13 @@ def __init__( def execute(self, context: Context): hook = self.get_db_hook() - result = hook.get_first(self.sql)[0] - if not result: + result = hook.get_first(self.sql) + + # if the query returns 0 rows result will be None so cannot be indexed into + # also covers indexing out of bounds on empty list, tuple etc. if returned + try: + result = result[0] + except (TypeError, IndexError): self._raise_exception(f"The following query returned zero rows: {self.sql}") min_threshold = _convert_to_float_if_possible(self.min_threshold) diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py index 3bd12fd7d627a..97ede82079426 100644 --- a/tests/providers/common/sql/operators/test_sql.py +++ b/tests/providers/common/sql/operators/test_sql.py @@ -901,12 +901,21 @@ def test_fail_min_sql_max_sql(self, mock_get_db_hook): operator.execute(context=MagicMock()) @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook") - def test_pass_min_value_max_sql(self, mock_get_db_hook): + @pytest.mark.parametrize( + ("sql", "min_threshold", "max_threshold"), + ( + ("Select 75", 45, "Select 100"), + # check corner-case if result of query is "falsey" does not raise error + ("Select 0", 0, 1), + ("Select 1", 0, 1), + ), + ) + def test_pass_min_value_max_sql(self, mock_get_db_hook, sql, min_threshold, max_threshold): mock_hook = mock.Mock() mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),) mock_get_db_hook.return_value = mock_hook - operator = self._construct_operator("Select 75", 45, "Select 100") + operator = self._construct_operator(sql, min_threshold, max_threshold) operator.execute(context=MagicMock()) @@ -921,6 +930,18 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook): with pytest.raises(AirflowException, match="155.*45.*100.0"): operator.execute(context=MagicMock()) + @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook") + def test_fail_if_query_returns_no_rows(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = None + mock_get_db_hook.return_value = mock_hook + + sql = "Select val from table1 where val = 'val not in table'" + operator = self._construct_operator(sql, 20, 100) + + with pytest.raises(AirflowException, match=f"The following query returned zero rows: {sql}"): + operator.execute(context=MagicMock()) + @pytest.mark.db_test class TestSqlBranch: