fix(server): fix compatibility with rdb snapshot #3121

Merged · 3 commits · Jun 4, 2024
10 changes: 6 additions & 4 deletions src/server/snapshot.cc
@@ -70,7 +70,7 @@ void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll) {
VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;

snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] {
-    IterateBucketsFb(cll);
+    IterateBucketsFb(cll, stream_journal);
db_slice_->UnregisterOnChange(snapshot_version_);
if (cll->IsCancelled()) {
Cancel();
@@ -174,7 +174,7 @@ void SliceSnapshot::Join() {
// and survived until it finished.

// Serializes all the entries with version less than snapshot_version_.
-void SliceSnapshot::IterateBucketsFb(const Cancellation* cll) {
+void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) {
{
auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex());
ThisFiber::SetName(std::move(fiber_name));
@@ -223,8 +223,10 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll) {
} // for (dbindex)

CHECK(!serialize_bucket_running_);
-  CHECK(!serializer_->SendFullSyncCut());
-  PushSerializedToChannel(true);
+  if (send_full_sync_cut) {
+    CHECK(!serializer_->SendFullSyncCut());
+    PushSerializedToChannel(true);
+  }

// serialized + side_saved must be equal to the total saved.
VLOG(1) << "Exit SnapshotSerializer (loop_serialized/side_saved/cbcalls): "
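The hunks above gate the Dragonfly-specific full-sync-cut record behind the new flag: it is emitted only when the snapshot feeds a replication stream (`stream_journal`), never when writing an RDB file that a vanilla Redis server may later load. A minimal, self-contained sketch of that decision in Python; the names and marker value are illustrative assumptions, not Dragonfly's actual serialization code:

```python
# Illustrative sketch only: the marker value and function names are assumed.
FULL_SYNC_CUT = "DFLY_FULL_SYNC_CUT"  # Dragonfly-specific record, unknown to Redis

def serialize_slice(entries, send_full_sync_cut: bool) -> list:
    """Serialize entries; append the cut marker only for replication streams."""
    out = list(entries)
    if send_full_sync_cut:
        # Replication stream: tell the replica that the full-sync phase ended.
        out.append(FULL_SYNC_CUT)
    # RDB file save (send_full_sync_cut=False): plain entries only, so a
    # vanilla Redis server can load the resulting snapshot file.
    return out

assert FULL_SYNC_CUT not in serialize_slice(["k1", "k2"], send_full_sync_cut=False)
assert FULL_SYNC_CUT in serialize_slice(["k1", "k2"], send_full_sync_cut=True)
```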
2 changes: 1 addition & 1 deletion src/server/snapshot.h
@@ -88,7 +88,7 @@ class SliceSnapshot {
private:
// Main fiber that iterates over all buckets in the db slice
// and submits them to SerializeBucket.
-  void IterateBucketsFb(const Cancellation* cll);
+  void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut);

// Called on traversing cursor by IterateBucketsFb.
bool BucketSaveCb(PrimeIterator it);
23 changes: 22 additions & 1 deletion tests/dragonfly/conftest.py
@@ -21,7 +21,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory

-from .instance import DflyInstance, DflyParams, DflyInstanceFactory
+from .instance import DflyInstance, DflyParams, DflyInstanceFactory, RedisServer
from . import PortPicker, dfly_args
from .utility import DflySeederFactory, gen_ca_cert, gen_certificate

@@ -366,3 +366,24 @@ def run_before_and_after_test():
yield # this is where the testing happens

# Teardown


@pytest.fixture(scope="function")
def redis_server(port_picker) -> RedisServer:
s = RedisServer(port_picker.get_available_port())
try:
s.start()
except FileNotFoundError as e:
pytest.skip("Redis server not found")
return None
time.sleep(1)
yield s
s.stop()


@pytest.fixture(scope="function")
def redis_local_server(port_picker) -> RedisServer:
s = RedisServer(port_picker.get_available_port())
time.sleep(1)
yield s
s.stop()
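The two fixtures split responsibilities: `redis_server` starts the server itself and skips the test when the `redis-server-6.2.11` binary is missing, while `redis_local_server` only allocates a port and leaves `start()` to the test so it can pass custom flags. A hypothetical usage sketch, not a test from this PR:

```python
# Hypothetical tests showing the intended use of each fixture.
def test_with_running_redis(redis_server):
    # The fixture already called start() (or skipped the test entirely).
    assert redis_server.proc is not None

def test_with_manual_start(redis_local_server, tmp_path):
    # The test chooses when to start the server and which extra flags to pass.
    redis_local_server.start(dir=str(tmp_path), dbfilename="dump.rdb")
    assert redis_local_server.proc is not None
```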
33 changes: 33 additions & 0 deletions tests/dragonfly/instance.py
@@ -364,3 +364,36 @@ def stop_all(self):

def __repr__(self) -> str:
return f"Factory({self.args})"


class RedisServer:
def __init__(self, port):
self.port = port
self.proc = None

def start(self, **kwargs):
command = [
"redis-server-6.2.11",
f"--port {self.port}",
"--save ''",
"--appendonly no",
"--protected-mode no",
"--repl-diskless-sync yes",
"--repl-diskless-sync-delay 0",
]
# Convert kwargs to command-line arguments
for key, value in kwargs.items():
if value is None:
command.append(f"--{key}")
else:
command.append(f"--{key} {value}")

self.proc = subprocess.Popen(command)
logging.debug(self.proc.args)

def stop(self):
self.proc.terminate()
try:
self.proc.wait(timeout=10)
except Exception as e:
pass
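For illustration, a hypothetical call showing how `start()` turns kwargs into flags (the port and paths below are made up). Each option and its value travel as a single argv token, e.g. `"--dir /tmp/rdb"`; this appears to work because redis-server folds command-line options into a config-file-style buffer and splits on whitespace when parsing.

```python
# Hypothetical usage; the port and paths are examples, not values from the PR.
server = RedisServer(port=6399)
server.start(dir="/tmp/rdb", dbfilename="test.rdb")
# The kwargs append "--dir /tmp/rdb" and "--dbfilename test.rdb" to the
# redis-server-6.2.11 command line; a None value would append a bare "--flag".
server.stop()
```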
40 changes: 0 additions & 40 deletions tests/dragonfly/redis_replication_test.py
@@ -8,33 +8,6 @@
from .proxy import Proxy


class RedisServer:
def __init__(self, port):
self.port = port
self.proc = None

def start(self):
self.proc = subprocess.Popen(
[
"redis-server-6.2.11",
f"--port {self.port}",
"--save ''",
"--appendonly no",
"--protected-mode no",
"--repl-diskless-sync yes",
"--repl-diskless-sync-delay 0",
]
)
logging.debug(self.proc.args)

def stop(self):
self.proc.terminate()
try:
self.proc.wait(timeout=10)
except Exception as e:
pass


# Checks that master redis and dragonfly replica are synced by writing a random key to master
# and waiting for it to exist in replica. Foreach db in 0..dbcount-1.
async def await_synced(c_master: aioredis.Redis, c_replica: aioredis.Redis, dbcount=1):
@@ -71,19 +44,6 @@ async def check_data(seeder, replicas, c_replicas):
assert await seeder.compare(capture, port=replica.port)


@pytest.fixture(scope="function")
def redis_server(port_picker) -> RedisServer:
s = RedisServer(port_picker.get_available_port())
try:
s.start()
except FileNotFoundError as e:
pytest.skip("Redis server not found")
return None
time.sleep(1)
yield s
s.stop()


full_sync_replication_specs = [
([1], dict(keys=100, dbcount=1, unsupported_types=[ValueType.JSON])),
([1], dict(keys=5000, dbcount=2, unsupported_types=[ValueType.JSON])),
29 changes: 20 additions & 9 deletions tests/dragonfly/seeder/__init__.py
@@ -18,20 +18,24 @@
class SeederBase:
UID_COUNTER = 1 # multiple generators should not conflict on keys
CACHED_SCRIPTS = {}
-    TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"]
+    DEFAULT_TYPES = ["STRING", "LIST", "SET", "HASH", "ZSET", "JSON"]

-    def __init__(self):
+    def __init__(self, types: typing.Optional[typing.List[str]] = None):
self.uid = SeederBase.UID_COUNTER
SeederBase.UID_COUNTER += 1
+        self.types = types if types is not None else SeederBase.DEFAULT_TYPES

@classmethod
-    async def capture(clz, client: aioredis.Redis) -> typing.Tuple[int]:
+    async def capture(
+        clz, client: aioredis.Redis, types: typing.Optional[typing.List[str]] = None
+    ) -> typing.Tuple[int]:
"""Generate hash capture for all data stored in instance pointed by client"""

sha = await client.script_load(clz._load_script("hash"))
+        types_to_capture = types if types is not None else clz.DEFAULT_TYPES
return tuple(
await asyncio.gather(
-                *(clz._run_capture(client, sha, data_type) for data_type in clz.TYPES)
+                *(clz._run_capture(client, sha, data_type) for data_type in types_to_capture)
)
)

@@ -69,8 +73,15 @@ def _load_script(clz, fname):
class StaticSeeder(SeederBase):
"""Wrapper around DEBUG POPULATE with fuzzy key sizes and a balanced type mix"""

-    def __init__(self, key_target=10_000, data_size=100, variance=5, samples=10):
-        SeederBase.__init__(self)
+    def __init__(
+        self,
+        key_target=10_000,
+        data_size=100,
+        variance=5,
+        samples=10,
+        types: typing.Optional[typing.List[str]] = None,
+    ):
+        SeederBase.__init__(self, types)
self.key_target = key_target
self.data_size = data_size
self.variance = variance
@@ -79,7 +90,7 @@ def __init__(self, key_target=10_000, data_size=100, variance=5, samples=10):
async def run(self, client: aioredis.Redis):
"""Run with specified options until key_target is met"""
samples = [
(dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.TYPES * self.samples)
(dtype, f"k-s{self.uid}u{i}-") for i, dtype in enumerate(self.types * self.samples)
]

# Handle samples in chunks of 24 to not overload client pool and instance
@@ -89,7 +100,7 @@ async def run(self, client: aioredis.Redis):
)

async def _run_unit(self, client: aioredis.Redis, dtype: str, prefix: str):
-        key_target = self.key_target // (self.samples * len(self.TYPES))
+        key_target = self.key_target // (self.samples * len(self.types))
if dtype == "STRING":
dsize = random.uniform(self.data_size / self.variance, self.data_size * self.variance)
csize = 1
@@ -120,7 +131,7 @@ def __init__(self, units=10, key_target=10_000, data_size=100):
self.units = [
Seeder.Unit(
prefix=f"k-s{self.uid}u{i}-",
-                type=Seeder.TYPES[i % len(Seeder.TYPES)],
+                type=Seeder.DEFAULT_TYPES[i % len(Seeder.DEFAULT_TYPES)],
counter=0,
stop_key=f"_s{self.uid}u{i}-stop",
)
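The renamed `DEFAULT_TYPES` plus the new `types` parameter let a test restrict seeding and capture to types a vanilla Redis understands, e.g. everything except JSON. A short sketch under that assumption (`StaticSeeder` imported from the seeder package; the key count is arbitrary):

```python
from redis import asyncio as aioredis

async def seed_and_capture_redis_compatible(client: aioredis.Redis):
    # Skip JSON so the dataset can round-trip through a vanilla Redis server.
    seeder = StaticSeeder(
        key_target=1_000, types=["STRING", "LIST", "SET", "HASH", "ZSET"]
    )
    await seeder.run(client)
    # capture() hashes only the requested types.
    return await StaticSeeder.capture(client, types=seeder.types)
```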
2 changes: 1 addition & 1 deletion tests/dragonfly/seeder_test.py
@@ -17,7 +17,7 @@ async def test_static_seeder(async_client: aioredis.Redis):
@dfly_args({"proactor_threads": 4})
async def test_seeder_key_target(async_client: aioredis.Redis):
"""Ensure seeder reaches its key targets"""
-    s = Seeder(units=len(Seeder.TYPES) * 2, key_target=5000)
+    s = Seeder(units=len(Seeder.DEFAULT_TYPES) * 2, key_target=5000)

# Ensure tests are not reasonably slow
async with async_timeout.timeout(1 + 4):
28 changes: 28 additions & 0 deletions tests/dragonfly/snapshot_test.py
@@ -1,4 +1,5 @@
import pytest
import logging
import os
import glob
import asyncio
@@ -7,6 +8,7 @@
from redis import asyncio as aioredis
from pathlib import Path
import boto3
from .instance import RedisServer

from . import dfly_args
from .utility import wait_available_async, chunked, is_saving
@@ -124,6 +126,32 @@ async def test_dbfilenames(
assert await StaticSeeder.capture(client) == start_capture


@pytest.mark.asyncio
@dfly_args({**BASIC_ARGS, "proactor_threads": 4, "dbfilename": "test-redis-load-rdb"})
async def test_redis_load_snapshot(
async_client: aioredis.Redis, df_server, redis_local_server: RedisServer, tmp_dir: Path
):
"""
Test redis server loading dragonfly snapshot rdb format
"""
await StaticSeeder(
**LIGHTWEIGHT_SEEDER_ARGS, types=["STRING", "LIST", "SET", "HASH", "ZSET"]
).run(async_client)

await async_client.execute_command("SAVE", "rdb")
dbsize = await async_client.dbsize()

await async_client.connection_pool.disconnect()
df_server.stop()

redis_local_server.start(dir=tmp_dir, dbfilename="test-redis-load-rdb.rdb")
await asyncio.sleep(1)
c_master = aioredis.Redis(port=redis_local_server.port)
await c_master.ping()

assert await c_master.dbsize() == dbsize
Contributor:
First redis compatibility test 👍🏻

If more checks are needed, you can re-define dragonfly.ihash in Redis. It's a bit of a hassle, but doable 🙂

local dragonfly = {}
function dragonfly.ihash(hash, sort, ...)
  local res = redis.pcall(unpack(arg))
  -- if sort, sort res
  -- use any hash function on res... lua really doesn't have one :(
end

Collaborator Author:
I wanted to use the seeder capture but got an error:

redis.exceptions.ResponseError: Error compiling script (new function): user_script:1: unexpected symbol near '#'

I want the fix to go into 1.19, so I will merge this now; we can investigate later whether a simple change can make this work.

Collaborator:
@adiholden you got an error with Redis, right?

Collaborator Author:
Do you mean the script error? Yes, it's in Redis.



@pytest.mark.slow
@dfly_args({**BASIC_ARGS, "dbfilename": "test-cron", "snapshot_cron": "* * * * *"})
async def test_cron_snapshot(tmp_dir: Path, async_client: aioredis.Redis):