Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Handle cancellation in DatabasePool.runInteraction() #12199

Merged
merged 6 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12199.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle cancellation in `DatabasePool.runInteraction()`.
61 changes: 37 additions & 24 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from typing_extensions import Literal

from twisted.enterprise import adbapi
from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
Expand All @@ -55,6 +56,7 @@
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -732,34 +734,45 @@ async def runInteraction(
Returns:
The result of func
"""
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []

if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []

try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)

for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)

return cast(R, result)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)

return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise

# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(defer.ensureDeferred(_runInteraction()))

async def runWithConnection(
self,
Expand Down
58 changes: 58 additions & 0 deletions tests/storage/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Callable, Tuple
from unittest.mock import Mock, call

from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
Expand Down Expand Up @@ -124,3 +126,59 @@ def test_successful_retry(self) -> None:
)
self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called()


class CancellationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool

def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()

def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()

d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)

after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()

def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()

def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
# Simulate a retryable failure on every attempt.
raise self.db_pool.engine.module.OperationalError()

d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)

after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls