
Fix bug where a new writer advances their token too quickly #16473

Merged · 10 commits · Oct 23, 2023
43 changes: 29 additions & 14 deletions synapse/storage/util/id_generators.py
@@ -420,6 +420,11 @@ def __init__(
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1

# The maximum position of the local instance. This can be higher than
# the corresponding position in `current_positions` table when there are
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id

Comment on lines +423 to +427

Contributor:

Err, _max_position_of_local_instance isn't read before it gets rewritten on 455. Does that make this block redundant? (Apart from the comment)

Member Author:

Yeah, potentially. We could move the definition further down, but I'm a bit cautious about a constructor calling methods on itself before all of its fields have been declared 🤷
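
A minimal sketch of the pattern being defended here (illustrative only, not the real MultiWriterIdGenerator): give every attribute a safe default before __init__ goes on to call methods that may read or overwrite it.

import threading


class _SketchIdGen:
    """Illustrative stand-in for the constructor ordering discussed above."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._max_seen_allocated_stream_id = 1

        # Pre-declare with a safe default even though the load step below
        # overwrites it before anyone reads it: methods called later in
        # __init__ can then rely on the attribute always existing.
        self._max_position_of_local_instance = self._max_seen_allocated_stream_id

        self._load_positions_from_db()

    def _load_positions_from_db(self) -> None:
        # Hypothetical stand-in for loading current positions from the DB.
        self._max_position_of_local_instance = self._max_seen_allocated_stream_id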

self._sequence_gen = PostgresSequenceGenerator(sequence_name)

# We check that the table and sequence haven't diverged.
@@ -439,6 +444,16 @@ def __init__(
self._current_positions.values(), default=1
)

# For the case where `stream_positions` is not up to date,
# `_persisted_upto_position` may be higher.
self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, self._persisted_upto_position
)

# Bump our local maximum position now that we've loaded things from the
# DB.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id

if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
@@ -708,6 +723,7 @@ def _mark_id_as_finished(self, next_id: int) -> None:
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
self._max_position_of_local_instance = max(curr, new_cur)

Contributor:

To check: this function is handling a stream entry that

  • we (this worker) are responsible for
  • we have just finished committing/persisting?

Member Author:

Yup, indeed
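
To make the "just finished persisting" point concrete, this is the usage pattern from the writer's side (a sketch: get_next() is the generator's real async API, but the persistence call and names are hypothetical).

async def persist_one_row(id_gen, store) -> None:
    # Allocate a stream ID that this worker is responsible for.
    async with id_gen.get_next() as stream_id:
        # Hypothetical persistence step using the allocated ID.
        await store.write_row(stream_id)
    # Leaving the context manager marks the ID as finished, which is when
    # _mark_id_as_finished() runs and may advance this instance's position.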


self._add_persisted_position(next_id)

@@ -722,6 +738,9 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
if self._instance_name == instance_name:
return self._return_factor * self._max_position_of_local_instance
Comment on lines +743 to +744

Contributor:

Do we still need the lock here? I guess the lock guards _max_position_of_local_instance... but I'm not sure what that buys us?

Member Author:

We take the lock in all functions, as they can be called from DB threads. It's probably safe to move this particular read out of the lock, but only because it's a trivial expression (only _max_position_of_local_instance can change). With locks I'd rather keep things consistent and have everything inside them.
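
A sketch of that locking convention (illustrative names, not Synapse's classes): every accessor takes the lock, even for reads that look atomic, so the object stays uniformly thread-safe when called from DB threads.

import threading


class _TokenHolder:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._return_factor = 1
        self._max_position_of_local_instance = 1

    def advance_local(self, new_pos: int) -> None:
        # Writers bump the cached maximum under the lock.
        with self._lock:
            self._max_position_of_local_instance = max(
                self._max_position_of_local_instance, new_pos
            )

    def get_current_token_for_local_writer(self) -> int:
        # Readers take the same lock, keeping the convention consistent.
        with self._lock:
            return self._return_factor * self._max_position_of_local_instance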


pos = self._current_positions.get(
instance_name, self._persisted_upto_position
)
@@ -731,20 +750,6 @@ def get_current_token_for_writer(self, instance_name: str) -> int:
# possible.
pos = max(pos, self._persisted_upto_position)

if (
self._instance_name == instance_name
and not self._in_flight_fetches
and not self._unfinished_ids
):
# For our own instance when there's nothing in flight, it's safe
# to advance to the maximum persisted position we've seen (as we
# know that any new tokens we request will be greater).
max_pos_of_all_writers = max(
self._current_positions.values(),
default=self._persisted_upto_position,
)
pos = max(pos, max_pos_of_all_writers)

return self._return_factor * pos
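
Taken together with the early return added above (and with the deleted block that used to jump to the maximum of all writers' positions when idle), the per-writer token computation after this change is roughly the following restatement (a sketch with explicit parameters, not the method itself).

def current_token_for_writer_sketch(
    requesting_instance: str,
    local_instance: str,
    max_position_of_local_instance: int,
    current_positions: dict,
    persisted_upto_position: int,
    return_factor: int = 1,
) -> int:
    if requesting_instance == local_instance:
        # The local writer reports its own cached maximum position, which
        # only jumps ahead when it has nothing in flight.
        return return_factor * max_position_of_local_instance

    # Other writers: their last known position, floored at the point we know
    # everything has been persisted up to.
    pos = current_positions.get(requesting_instance, persisted_upto_position)
    pos = max(pos, persisted_upto_position)
    return return_factor * pos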

def get_minimal_local_current_token(self) -> int:
@@ -821,6 +826,16 @@ def _add_persisted_position(self, new_id: int) -> None:

self._persisted_upto_position = max(min_curr, self._persisted_upto_position)

# Advance our local max position.
self._max_position_of_local_instance = max(
self._max_position_of_local_instance, self._persisted_upto_position
)

if not self._unfinished_ids and not self._in_flight_fetches:
# If we don't have anything in flight, it's safe to advance to the
# max seen stream ID.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id

# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if it's exactly one greater.
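
The advancement rule added in this hunk boils down to a small pure function (a sketch for illustration, with hypothetical parameter names): always keep up with the persisted-up-to point, and only jump to the maximum stream ID seen anywhere once nothing is in flight locally.

def advance_local_max_sketch(
    max_position_of_local_instance: int,
    persisted_upto_position: int,
    max_seen_allocated_stream_id: int,
    have_unfinished_or_in_flight: bool,
) -> int:
    # Never fall behind what we know has been persisted.
    new_max = max(max_position_of_local_instance, persisted_upto_position)
    if not have_unfinished_or_in_flight:
        # Nothing in flight locally, so it's safe to advance to the maximum
        # stream ID seen across all writers.
        new_max = max_seen_allocated_stream_id
    return new_max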
51 changes: 51 additions & 0 deletions tests/storage/test_id_generators.py
@@ -661,6 +661,57 @@ def test_minimal_local_token(self) -> None:
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)

def test_current_token_gap(self) -> None:
"""Test that getting the current token for a writer returns the maximal
token when there are no writes.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)

first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)

self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)

# Check that the first ID gen advancing causes the second ID gen to
# advance (as it has nothing in flight).

Contributor (@DMRobertson, Oct 23, 2023):

Err, what mechanism ensures that the second ID gen sees the new facts from this? Is there a Redis running behind the scenes?

EDIT: oh, we call advance manually.
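
For reference, the mechanism the test replaces: in a running deployment the second worker would learn about the first writer's new positions over replication, which ultimately results in the same advance() call the test makes by hand (the handler below is a hypothetical illustration, not Synapse's replication code).

def on_remote_position_sketch(id_gen, writer_instance: str, new_id: int) -> None:
    # When we hear that `writer_instance` has written up to `new_id`, feed
    # that into the ID generator, exactly as the test does directly with
    # second_id_gen.advance("first", 9).
    id_gen.advance(writer_instance, new_id)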

Contributor:

    # advance (as it has nothing in flight).

Ambiguous: s/it/the second/ please! Ditto below on 697.


async def _get_next_async() -> None:
async with first_id_gen.get_next_mult(2):
pass

self.get_success(_get_next_async())
second_id_gen.advance("first", 9)

self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)

# Check that the first ID gen advancing doesn't advance the second ID
# gen when it has stuff in flight.
self.get_success(_get_next_async())

ctxmgr = second_id_gen.get_next()
self.get_success(ctxmgr.__aenter__())

second_id_gen.advance("first", 11)

self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)

self.get_success(ctxmgr.__aexit__(None, None, None))

self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
self.assertEqual(second_id_gen.get_current_token(), 7)


class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""