diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index 9034cb382ee3..b46aed03ad7c 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -70,7 +70,7 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) { VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_; snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] { - IterateBucketsFb(cll); + IterateBucketsFb(cll, stream_journal); db_slice_->UnregisterOnChange(snapshot_version_); if (cll->IsCancelled()) { Cancel(); @@ -174,7 +174,7 @@ void SliceSnapshot::Join() { // and survived until it finished. // Serializes all the entries with version less than snapshot_version_. -void SliceSnapshot::IterateBucketsFb(const Cancellation* cll) { +void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) { { auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex()); ThisFiber::SetName(std::move(fiber_name)); @@ -223,8 +223,10 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll) { } // for (dbindex) CHECK(!serialize_bucket_running_); - CHECK(!serializer_->SendFullSyncCut()); - PushSerializedToChannel(true); + if (send_full_sync_cut) { + CHECK(!serializer_->SendFullSyncCut()); + PushSerializedToChannel(true); + } // serialized + side_saved must be equal to the total saved. VLOG(1) << "Exit SnapshotSerializer (loop_serialized/side_saved/cbcalls): " diff --git a/src/server/snapshot.h b/src/server/snapshot.h index 3b72e4e8c3c1..9bfaf1fdf591 100644 --- a/src/server/snapshot.h +++ b/src/server/snapshot.h @@ -88,7 +88,7 @@ class SliceSnapshot { private: // Main fiber that iterates over all buckets in the db slice // and submits them to SerializeBucket. - void IterateBucketsFb(const Cancellation* cll); + void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut); // Called on traversing cursor by IterateBucketsFb. bool BucketSaveCb(PrimeIterator it); diff --git a/tests/dragonfly/conftest.py b/tests/dragonfly/conftest.py index 635ac565eac2..7ce464447847 100644 --- a/tests/dragonfly/conftest.py +++ b/tests/dragonfly/conftest.py @@ -21,7 +21,7 @@ from pathlib import Path from tempfile import TemporaryDirectory -from .instance import DflyInstance, DflyParams, DflyInstanceFactory +from .instance import DflyInstance, DflyParams, DflyInstanceFactory, RedisServer from . import PortPicker, dfly_args from .utility import DflySeederFactory, gen_ca_cert, gen_certificate @@ -366,3 +366,24 @@ def run_before_and_after_test(): yield # this is where the testing happens # Teardown + + +@pytest.fixture(scope="function") +def redis_server(port_picker) -> RedisServer: + s = RedisServer(port_picker.get_available_port()) + try: + s.start() + except FileNotFoundError as e: + pytest.skip("Redis server not found") + return None + time.sleep(1) + yield s + s.stop() + + +@pytest.fixture(scope="function") +def redis_local_server(port_picker) -> RedisServer: + s = RedisServer(port_picker.get_available_port()) + time.sleep(1) + yield s + s.stop() diff --git a/tests/dragonfly/instance.py b/tests/dragonfly/instance.py index 5ed043fc99a2..ed102c0c9c66 100644 --- a/tests/dragonfly/instance.py +++ b/tests/dragonfly/instance.py @@ -364,3 +364,36 @@ def stop_all(self): def __repr__(self) -> str: return f"Factory({self.args})" + + +class RedisServer: + def __init__(self, port): + self.port = port + self.proc = None + + def start(self, **kwargs): + command = [ + "redis-server-6.2.11", + f"--port {self.port}", + "--save ''", + "--appendonly no", + "--protected-mode no", + "--repl-diskless-sync yes", + "--repl-diskless-sync-delay 0", + ] + # Convert kwargs to command-line arguments + for key, value in kwargs.items(): + if value is None: + command.append(f"--{key}") + else: + command.append(f"--{key} {value}") + + self.proc = subprocess.Popen(command) + logging.debug(self.proc.args) + + def stop(self): + self.proc.terminate() + try: + self.proc.wait(timeout=10) + except Exception as e: + pass diff --git a/tests/dragonfly/redis_replication_test.py b/tests/dragonfly/redis_replication_test.py index 4305d41f899e..f0375e017c90 100644 --- a/tests/dragonfly/redis_replication_test.py +++ b/tests/dragonfly/redis_replication_test.py @@ -8,33 +8,6 @@ from .proxy import Proxy -class RedisServer: - def __init__(self, port): - self.port = port - self.proc = None - - def start(self): - self.proc = subprocess.Popen( - [ - "redis-server-6.2.11", - f"--port {self.port}", - "--save ''", - "--appendonly no", - "--protected-mode no", - "--repl-diskless-sync yes", - "--repl-diskless-sync-delay 0", - ] - ) - logging.debug(self.proc.args) - - def stop(self): - self.proc.terminate() - try: - self.proc.wait(timeout=10) - except Exception as e: - pass - - # Checks that master redis and dragonfly replica are synced by writing a random key to master # and waiting for it to exist in replica. Foreach db in 0..dbcount-1. async def await_synced(c_master: aioredis.Redis, c_replica: aioredis.Redis, dbcount=1): @@ -71,19 +44,6 @@ async def check_data(seeder, replicas, c_replicas): assert await seeder.compare(capture, port=replica.port) -@pytest.fixture(scope="function") -def redis_server(port_picker) -> RedisServer: - s = RedisServer(port_picker.get_available_port()) - try: - s.start() - except FileNotFoundError as e: - pytest.skip("Redis server not found") - return None - time.sleep(1) - yield s - s.stop() - - full_sync_replication_specs = [ ([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])), ([1], dict(keys=5000, dbcount=2, unsupported_types=[ValueType.JSON])), diff --git a/tests/dragonfly/seeder/__init__.py b/tests/dragonfly/seeder/__init__.py index d2361949e781..dbfbf820f860 100644 --- a/tests/dragonfly/seeder/__init__.py +++ b/tests/dragonfly/seeder/__init__.py @@ -18,20 +18,24 @@ class SeederBase: UID_COUNTER = 1 # multiple generators should not conflict on keys CACHED_SCRIPTS = {} - TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"] + DEFAULT_TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"] - def __init__(self): + def __init__(self, types: typing.Optional[typing.List[str]] = None): self.uid = SeederBase.UID_COUNTER SeederBase.UID_COUNTER += 1 + self.types = types if types is not None else SeederBase.DEFAULT_TYPES @classmethod - async def capture(clz, client: aioredis.Redis) -> typing.Tuple[int]: + async def capture( + clz, client: aioredis.Redis, types: typing.Optional[typing.List[str]] = None + ) -> typing.Tuple[int]: """Generate hash capture for all data stored in instance pointed by client""" sha = await client.script_load(clz._load_script("hash")) + types_to_capture = types if types is not None else clz.DEFAULT_TYPES return tuple( await asyncio.gather( - *(clz._run_capture(client, sha, data_type) for data_type in clz.TYPES) + *(clz._run_capture(client, sha, data_type) for data_type in types_to_capture) ) ) @@ -69,8 +73,15 @@ def _load_script(clz, fname): class StaticSeeder(SeederBase): """Wrapper around DEBUG POPULATE with fuzzy key sizes and a balanced type mix""" - def __init__(self, key_target=10_000, data_size=100, variance=5, samples=10): - SeederBase.__init__(self) + def __init__( + self, + key_target=10_000, + data_size=100, + variance=5, + samples=10, + types: typing.Optional[typing.List[str]] = None, + ): + SeederBase.__init__(self, types) self.key_target = key_target self.data_size = data_size self.variance = variance @@ -79,7 +90,7 @@ def __init__(self, key_target=10_000, data_size=100, variance=5, samples=10): async def run(self, client: aioredis.Redis): """Run with specified options until key_target is met""" samples = [ - (dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.TYPES * self.samples) + (dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.types * self.samples) ] # Handle samples in chuncks of 24 to not overload client pool and instance @@ -89,7 +100,7 @@ async def run(self, client: aioredis.Redis): ) async def _run_unit(self, client: aioredis.Redis, dtype: str, prefix: str): - key_target = self.key_target // (self.samples * len(self.TYPES)) + key_target = self.key_target // (self.samples * len(self.types)) if dtype == "STRING": dsize = random.uniform(self.data_size / self.variance, self.data_size * self.variance) csize = 1 @@ -120,7 +131,7 @@ def __init__(self, units=10, key_target=10_000, data_size=100): self.units = [ Seeder.Unit( prefix=f"k-s{self.uid}u{i}-", - type=Seeder.TYPES[i % len(Seeder.TYPES)], + type=Seeder.DEFAULT_TYPES[i % len(Seeder.DEFAULT_TYPES)], counter=0, stop_key=f"_s{self.uid}u{i}-stop", ) diff --git a/tests/dragonfly/seeder_test.py b/tests/dragonfly/seeder_test.py index d4bb9379f291..61eba28cedb0 100644 --- a/tests/dragonfly/seeder_test.py +++ b/tests/dragonfly/seeder_test.py @@ -17,7 +17,7 @@ async def test_static_seeder(async_client: aioredis.Redis): @dfly_args({"proactor_threads": 4}) async def test_seeder_key_target(async_client: aioredis.Redis): """Ensure seeder reaches its key targets""" - s = Seeder(units=len(Seeder.TYPES) * 2, key_target=5000) + s = Seeder(units=len(Seeder.DEFAULT_TYPES) * 2, key_target=5000) # Ensure tests are not reasonably slow async with async_timeout.timeout(1 + 4): diff --git a/tests/dragonfly/snapshot_test.py b/tests/dragonfly/snapshot_test.py index d775ea429dd6..e972bf10c6f0 100644 --- a/tests/dragonfly/snapshot_test.py +++ b/tests/dragonfly/snapshot_test.py @@ -1,4 +1,5 @@ import pytest +import logging import os import glob import asyncio @@ -7,6 +8,7 @@ from redis import asyncio as aioredis from pathlib import Path import boto3 +from .instance import RedisServer from . import dfly_args from .utility import wait_available_async, chunked, is_saving @@ -124,6 +126,32 @@ async def test_dbfilenames( assert await StaticSeeder.capture(client) == start_capture +@pytest.mark.asyncio +@dfly_args({**BASIC_ARGS, "proactor_threads": 4, "dbfilename": "test-redis-load-rdb"}) +async def test_redis_load_snapshot( + async_client: aioredis.Redis, df_server, redis_local_server: RedisServer, tmp_dir: Path +): + """ + Test redis server loading dragonfly snapshot rdb format + """ + await StaticSeeder( + **LIGHTWEIGHT_SEEDER_ARGS, types=["STRING", "LIST", "SET", "HASH", "ZSET"] + ).run(async_client) + + await async_client.execute_command("SAVE", "rdb") + dbsize = await async_client.dbsize() + + await async_client.connection_pool.disconnect() + df_server.stop() + + redis_local_server.start(dir=tmp_dir, dbfilename="test-redis-load-rdb.rdb") + await asyncio.sleep(1) + c_master = aioredis.Redis(port=redis_local_server.port) + await c_master.ping() + + assert await c_master.dbsize() == dbsize + + @pytest.mark.slow @dfly_args({**BASIC_ARGS, "dbfilename": "test-cron", "snapshot_cron": "* * * * *"}) async def test_cron_snapshot(tmp_dir: Path, async_client: aioredis.Redis):