Skip to content

Commit

Permalink
Rollback for all retry exceptions (apache#40882)
Browse files Browse the repository at this point in the history
In apache#19856, we added `DBAPIError` besides `OperationalError` to the retry exception types, but did not change the `retry_db_transaction` decorator to rollback transaction after failures and before a retry.

In certain cases (see apache#40882), this is needed as otherwise all retries will fail when the current session/transaction was "poisened" by the initial error.
  • Loading branch information
jmaicher committed Jul 19, 2024
1 parent afde88a commit 62e1fb2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
8 changes: 4 additions & 4 deletions airflow/utils/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from inspect import signature
from typing import Callable, TypeVar, overload

from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import DBAPIError

from airflow.configuration import conf

Expand All @@ -36,7 +36,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logge

# Default kwargs
retry_kwargs = dict(
retry=tenacity.retry_if_exception_type(exception_types=(OperationalError, DBAPIError)),
retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(max_retries),
reraise=True,
Expand All @@ -58,7 +58,7 @@ def retry_db_transaction(_func: F) -> F: ...

def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
"""
Retry functions in case of ``OperationalError`` from DB.
Retry functions in case of ``DBAPIError`` from DB.
It should not be used with ``@provide_session``.
"""
Expand Down Expand Up @@ -96,7 +96,7 @@ def wrapped_function(*args, **kwargs):
)
try:
return func(*args, **kwargs)
except OperationalError:
except DBAPIError:
session.rollback()
raise

Expand Down
15 changes: 10 additions & 5 deletions tests/utils/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from unittest import mock

import pytest
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import InternalError, OperationalError

from airflow.utils.retries import retry_db_transaction

if TYPE_CHECKING:
from sqlalchemy.exc import DBAPIError


class TestRetries:
def test_retry_db_transaction_with_passing_retries(self):
Expand All @@ -45,23 +49,24 @@ def test_function(session):
assert mock_obj.call_count == 2

@pytest.mark.db_test
def test_retry_db_transaction_with_default_retries(self, caplog):
@pytest.mark.parametrize("excection_type", [OperationalError, InternalError])
def test_retry_db_transaction_with_default_retries(self, caplog, excection_type: type[DBAPIError]):
"""Test that by default 3 retries will be carried out"""
mock_obj = mock.MagicMock()
mock_session = mock.MagicMock()
mock_rollback = mock.MagicMock()
mock_session.rollback = mock_rollback
op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
db_error = excection_type(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)

@retry_db_transaction
def test_function(session):
session.execute("select 1")
mock_obj(2)
raise op_error
raise db_error

caplog.set_level(logging.DEBUG, logger=self.__module__)
caplog.clear()
with pytest.raises(OperationalError):
with pytest.raises(excection_type):
test_function(session=mock_session)

for try_no in range(1, 4):
Expand Down

0 comments on commit 62e1fb2

Please sign in to comment.