Skip to content

Commit

Permalink
Fix SQLThresholdCheckOperator error on falsey vals (#37150)
Browse files Browse the repository at this point in the history
* Fix SQLThresholdCheckOperator error on falsey vals

If the user's query returned a "falsey" value e.g. 0 an exception would be falsely raised

* fixup! Fix SQLThresholdCheckOperator error on falsey vals
  • Loading branch information
Don-Burns authored Feb 15, 2024
1 parent 0be6430 commit b6ca847
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,8 +1058,13 @@ def __init__(

def execute(self, context: Context):
hook = self.get_db_hook()
result = hook.get_first(self.sql)[0]
if not result:
result = hook.get_first(self.sql)

# if the query returns 0 rows result will be None so cannot be indexed into
# also covers indexing out of bounds on empty list, tuple etc. if returned
try:
result = result[0]
except (TypeError, IndexError):
self._raise_exception(f"The following query returned zero rows: {self.sql}")

min_threshold = _convert_to_float_if_possible(self.min_threshold)
Expand Down
25 changes: 23 additions & 2 deletions tests/providers/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,21 @@ def test_fail_min_sql_max_sql(self, mock_get_db_hook):
operator.execute(context=MagicMock())

@mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
def test_pass_min_value_max_sql(self, mock_get_db_hook):
@pytest.mark.parametrize(
("sql", "min_threshold", "max_threshold"),
(
("Select 75", 45, "Select 100"),
# check corner-case if result of query is "falsey" does not raise error
("Select 0", 0, 1),
("Select 1", 0, 1),
),
)
def test_pass_min_value_max_sql(self, mock_get_db_hook, sql, min_threshold, max_threshold):
mock_hook = mock.Mock()
mock_hook.get_first.side_effect = lambda x: (int(x.split()[1]),)
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator("Select 75", 45, "Select 100")
operator = self._construct_operator(sql, min_threshold, max_threshold)

operator.execute(context=MagicMock())

Expand All @@ -921,6 +930,18 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook):
with pytest.raises(AirflowException, match="155.*45.*100.0"):
operator.execute(context=MagicMock())

@mock.patch.object(SQLThresholdCheckOperator, "get_db_hook")
def test_fail_if_query_returns_no_rows(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.return_value = None
mock_get_db_hook.return_value = mock_hook

sql = "Select val from table1 where val = 'val not in table'"
operator = self._construct_operator(sql, 20, 100)

with pytest.raises(AirflowException, match=f"The following query returned zero rows: {sql}"):
operator.execute(context=MagicMock())


@pytest.mark.db_test
class TestSqlBranch:
Expand Down

0 comments on commit b6ca847

Please sign in to comment.