Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add a batching queue implementation. #10017

Merged
merged 10 commits into from
May 21, 2021
Merged
1 change: 1 addition & 0 deletions changelog.d/10017.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a batching queue implementation.
137 changes: 137 additions & 0 deletions synapse/util/batching_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Hashable,
List,
Set,
Tuple,
TypeVar,
)

from twisted.internet import defer

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock

logger = logging.getLogger(__name__)


V = TypeVar("V")
R = TypeVar("R")


class BatchingQueue(Generic[V, R]):
"""A queue that batches up work, calling the provided processing function
with all pending work (for a given key).

The provided processing function will only be called once at a time for each
key.

Note that the return value of `add_to_queue` will be the return value of the
processing function that processed the given item. This means that the
returned value will likely include data for other items that were in the
batch.
"""
clokep marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self, name: str, clock: Clock, process_items: Callable[[List[V]], Awaitable[R]]
):
self._name = name
self._clock = clock

# The set of keys currently being processed.
self._processing_keys = set() # type: Set[Hashable]

# The currently pending batch of values by key, with a Deferred to call
# with the result of the corresponding `process_items` call.
self._next_values = {} # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]

# The function to call with batches of values.
self.process_items = process_items
clokep marked this conversation as resolved.
Show resolved Hide resolved

LaterGauge(
"synapse_util_batching_queue_number_queued",
"The number of items waiting in the queue across all keys",
labelnames=("name",),
caller=lambda: sum(len(v) for v in self._next_values.values()),
)

LaterGauge(
"synapse_util_batching_queue_number_of_keys",
"The number of distinct keys that have items queued",
labelnames=("name",),
caller=lambda: len(self._next_values),
)

async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Adds the value to the queue with the given key, returning the result
of the processing function for the batch that included the given value.
"""

# First we create a defer and add it and the value to the list of
# pending items.
d = defer.Deferred()
self._next_values.setdefault(key, []).append((value, d))

# If we're not currently processing the key fire off a background
# process to start processing.
if key not in self._processing_keys:
run_as_background_process(self._name, self._process_queue, key)

return await make_deferred_yieldable(d)

async def _process_queue(self, key: Hashable) -> None:
"""A background task to repeatedly pull things off the queue for the
given key and call the `self.process_items` with the values.
"""

try:
if key in self._processing_keys:
return

self._processing_keys.add(key)

while self._next_values:
clokep marked this conversation as resolved.
Show resolved Hide resolved
# We purposefully wait a reactor tick to allow us to batch
# together requests that we're about to receive. A common
# pattern is to call `add_to_queue` multiple times at once, and
# deferring to the next reactor tick allows us to batch all of
# those up.
await self._clock.sleep(0)

next_values = self._next_values.pop(key, [])

try:
values = [value for value, _ in next_values]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fun trick you can do is values, deferreds = zip(*next_values), doesn't really save anything here though unfortunately since you have to iterate all the Deferreds anyway.

results = await self.process_items(values)

for _, deferred in next_values:
with PreserveLoggingContext():
deferred.callback(results)

except Exception as e:
for _, deferred in next_values:
with PreserveLoggingContext():
deferred.errback(e)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this cause a Deferred to have both callback and errback called on it if e.g. the second deferred.callback causes an error to be raised?

(I wonder if we should be guarding with if not deferred.called?)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think deferred.callback should ever explode, but may as well add the guard.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would hope not. Another option would be to move the callback calls to an else so that only calling the processing function is in the try clause?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I prefer it as it is, as then we really should always call either .errback or .callback for each?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, isn't that what my suggestion guarantees though?

try:
    result = processing_callback(values)
except Exception as e:
    for d in deferreds:
        d.errback(e)
else:
    for d in deferreds:
        d.callback(result)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if d.callback never throws then that's the same as

try:
    result = processing_callback(values)

    for d in deferreds:
        d.callback(result)
except Exception as e:
    for d in deferreds:
        d.errback(e)

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. 😄

If it does throw it has the benefit of causing a stack trace and letting us know about it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


finally:
self._processing_keys.discard(key)
150 changes: 150 additions & 0 deletions tests/util/test_batching_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.batching_queue import BatchingQueue

from tests.server import get_clock
from tests.unittest import TestCase


class BatchingQueueTestCase(TestCase):
def setUp(self):
self.clock, hs_clock = get_clock()

self._pending_calls = []
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)

async def _process_queue(self, values):
d = defer.Deferred()
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)

def test_simple(self):
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
"""

self.assertFalse(self._pending_calls)

queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))

# The queue should wait a reactor tick before calling the processing
# function.
self.assertFalse(self._pending_calls)
self.assertFalse(queue_d.called)

# We should see a call to `_process_queue` after a reactor tick.
self.clock.pump([0])

self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo"])
self.assertFalse(queue_d.called)

# Return value of the `_process_queue` should be propagated back.
self._pending_calls.pop()[1].callback("bar")

self.assertEqual(self.successResultOf(queue_d), "bar")

def test_batching(self):
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
"""

self.assertFalse(self._pending_calls)

queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))

self.clock.pump([0])

# We should see only *one* call to `_process_queue`
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)

# Return value of the `_process_queue` should be propagated back to both.
self._pending_calls.pop()[1].callback("bar")

self.assertEqual(self.successResultOf(queue_d1), "bar")
self.assertEqual(self.successResultOf(queue_d2), "bar")

def test_queuing(self):
"""Test that we queue up requests while a `_process_queue` is being
called.
"""

self.assertFalse(self._pending_calls)

queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
self.clock.pump([0])

queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))

# We should see only *one* call to `_process_queue`
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo1"])
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)

# Return value of the `_process_queue` should be propagated back to the
# first.
self._pending_calls.pop()[1].callback("bar1")

self.assertEqual(self.successResultOf(queue_d1), "bar1")
self.assertFalse(queue_d2.called)

# We should now see a second call to `_process_queue`
self.clock.pump([0])
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo2"])
self.assertFalse(queue_d2.called)

# Return value of the `_process_queue` should be propagated back to the
# second.
self._pending_calls.pop()[1].callback("bar2")

self.assertEqual(self.successResultOf(queue_d2), "bar2")

def test_different_keys(self):
"""Test that calls to different keys get processed in parallel."""

self.assertFalse(self._pending_calls)

queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
self.clock.pump([0])
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
self.clock.pump([0])

# We should see two calls to `_process_queue`
self.assertEqual(len(self._pending_calls), 2)
self.assertEqual(self._pending_calls[0][0], ["foo1"])
self.assertEqual(self._pending_calls[1][0], ["foo2"])
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)

# Return value of the `_process_queue` should be propagated back to the
# first.
self._pending_calls.pop(0)[1].callback("bar1")

self.assertEqual(self.successResultOf(queue_d1), "bar1")
self.assertFalse(queue_d2.called)

# Return value of the `_process_queue` should be propagated back to the
# second.
self._pending_calls.pop()[1].callback("bar2")

self.assertEqual(self.successResultOf(queue_d2), "bar2")