fix: use sessions to avoid duplicate raft messages on timeouts (#3888)
jvmakine authored Jan 3, 2025
1 parent 9e0b74c commit 4257c61
Showing 3 changed files with 56 additions and 37 deletions.
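
The core of the change is in ShardHandle.Propose: instead of proposing through a dragonboat no-op session, where a proposal that times out and is retried can be applied twice, the handle now acquires a registered client session, serializes proposals behind a mutex, and only advances the session once a proposal has completed. The sketch below paraphrases that pattern rather than quoting the commit; the dragonboat v4 import paths (consistent with the shard/replica naming in the diff) and the handle type are assumptions for illustration.

// A minimal sketch of the session-based propose pattern this commit adopts,
// assuming dragonboat v4; the handle type and field names are illustrative.
package raftexample

import (
	"context"
	"fmt"
	"sync"

	"github.com/lni/dragonboat/v4"
	"github.com/lni/dragonboat/v4/client"
)

type handle struct {
	mu      sync.Mutex // client.Session is not safe for concurrent use
	nh      *dragonboat.NodeHost
	shardID uint64
	session *client.Session
}

func (h *handle) propose(ctx context.Context, msg []byte) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Lazily register a client session. Unlike GetNoOPSession, a registered
	// session lets the shard deduplicate a proposal that is retried after a
	// timeout, which is what prevents duplicate events.
	if h.session == nil {
		s, err := h.nh.SyncGetSession(ctx, h.shardID)
		if err != nil {
			return fmt.Errorf("failed to get session: %w", err)
		}
		h.session = s
	}

	h.session.PrepareForPropose()
	if _, err := h.nh.SyncPropose(ctx, h.session, msg); err != nil {
		// Safe to retry with the same session; the proposal keeps its ID.
		return fmt.Errorf("failed to propose event: %w", err)
	}
	// Advance the session only after the proposal is known to have completed.
	h.session.ProposalCompleted()
	return nil
}
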
55 changes: 33 additions & 22 deletions internal/raft/cluster.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/jpillora/backoff"
@@ -24,7 +25,7 @@ type RaftConfig struct {
ShardReadyTimeout time.Duration `help:"Timeout for shard to be ready" default:"5s"`
// Raft configuration
RTT time.Duration `help:"Estimated average round trip time between nodes" default:"200ms"`
ElectionTimeoutRTT uint64 `help:"Election timeout RTT as a multiple of RTT" default:"10"`
ElectionRTT uint64 `help:"Election RTT as a multiple of RTT" default:"10"`
HeartbeatRTT uint64 `help:"Heartbeat RTT as a multiple of RTT" default:"1"`
SnapshotEntries uint64 `help:"Snapshot entries" default:"10"`
CompactionOverhead uint64 `help:"Compaction overhead" default:"100"`
@@ -91,24 +92,39 @@ type ShardHandle[E Event, Q any, R any] struct {
shardID uint64
cluster *Cluster
session *client.Session

mu sync.Mutex
}

// Propose an event to the shard.
func (s *ShardHandle[E, Q, R]) Propose(ctx context.Context, msg E) error {
// client session is not thread safe, so we need to lock
s.mu.Lock()
defer s.mu.Unlock()

s.verifyReady()

msgBytes, err := msg.MarshalBinary()
if err != nil {
return fmt.Errorf("failed to marshal event: %w", err)
}
if s.session == nil {
// use a no-op session for now. This means that a retry on timeout could result in duplicate events.
s.session = s.cluster.nh.GetNoOPSession(s.shardID)
if err := s.cluster.withRetry(ctx, s.shardID, s.cluster.config.ReplicaID, func(ctx context.Context) error {
s.session, err = s.cluster.nh.SyncGetSession(ctx, s.shardID)
return err //nolint:wrapcheck
}); err != nil {
return fmt.Errorf("failed to get session: %w", err)
}
}

if err := s.cluster.withRetry(ctx, s.shardID, s.cluster.config.ReplicaID, func(ctx context.Context) error {
s.session.PrepareForPropose()
_, err := s.cluster.nh.SyncPropose(ctx, s.session, msgBytes)
return err //nolint:wrapcheck
if err != nil {
return err //nolint:wrapcheck
}
s.session.ProposalCompleted()
return nil
}); err != nil {
return fmt.Errorf("failed to propose event: %w", err)
}
@@ -177,7 +193,7 @@ func (c *Cluster) start(ctx context.Context, join bool) error {
ReplicaID: c.config.ReplicaID,
ShardID: shardID,
CheckQuorum: true,
ElectionRTT: c.config.ElectionTimeoutRTT,
ElectionRTT: c.config.ElectionRTT,
HeartbeatRTT: c.config.HeartbeatRTT,
SnapshotEntries: c.config.SnapshotEntries,
CompactionOverhead: c.config.CompactionOverhead,
@@ -208,21 +224,16 @@ func (c *Cluster) start(ctx context.Context, join bool) error {
}

// Stop the node host and all shards.
func (c *Cluster) Stop(ctx context.Context) error {
if c.nh == nil {
return nil
}

for shardID := range c.shards {
if err := c.removeShardMember(ctx, shardID, c.config.ReplicaID); err != nil {
return fmt.Errorf("failed to remove shard (%d) member: %w", shardID, err)
// After this call, all the shard handles created with this cluster are invalid.
func (c *Cluster) Stop(ctx context.Context) {
if c.nh != nil {
for shardID := range c.shards {
c.removeShardMember(ctx, shardID, c.config.ReplicaID)
}
c.nh.Close()
c.nh = nil
c.shards = nil
}

c.nh.Close()
c.nh = nil

return nil
}

// AddMember to the cluster. This needs to be called on an existing running cluster member,
@@ -241,16 +252,16 @@ func (c *Cluster) AddMember(ctx context.Context, shardID uint64, replicaID uint6

// removeShardMember from the given shard. This removes the given member from the membership group
// and blocks until the change has been committed
func (c *Cluster) removeShardMember(ctx context.Context, shardID uint64, replicaID uint64) error {
func (c *Cluster) removeShardMember(ctx context.Context, shardID uint64, replicaID uint64) {
logger := log.FromContext(ctx).Scope("raft")
logger.Infof("removing replica %d from shard %d", replicaID, shardID)

if err := c.withRetry(ctx, shardID, replicaID, func(ctx context.Context) error {
return c.nh.SyncRequestDeleteReplica(ctx, shardID, replicaID, 0)
}); err != nil {
return fmt.Errorf("failed to remove member: %w", err)
// This can happen if the cluster is shutting down and no longer has quorum.
logger.Warnf("removing replica %d from shard %d failed: %s", replicaID, shardID, err)
}
return nil
}

// withTimeout runs an async dragonboat call and blocks until it succeeds or the context is cancelled.
@@ -268,7 +279,7 @@ func (c *Cluster) withRetry(ctx context.Context, shardID, replicaID uint64, f fu
// Timeout for the proposal to reach the leader and reach a quorum.
// If the leader is not available, the proposal will time out, in which case
// we retry the operation.
timeout := time.Duration(c.config.ElectionTimeoutRTT) * c.config.RTT
timeout := time.Duration(c.config.ElectionRTT) * c.config.RTT
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

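
The proposal path leans on the retry helper at the bottom of cluster.go: each attempt is bounded by ElectionRTT * RTT, so a proposal that cannot reach a leader times out and is retried. The full body of withRetry is not part of this diff, so the following is only a hypothetical sketch of a wrapper in that spirit, reusing the github.com/jpillora/backoff package imported at the top of the file; all names here are illustrative.

// Hypothetical retry wrapper in the spirit of withRetry: each attempt gets its
// own timeout (ElectionRTT * RTT in cluster.go) and failures are retried with
// exponential backoff until the parent context is cancelled.
package raftexample

import (
	"context"
	"fmt"
	"time"

	"github.com/jpillora/backoff"
)

func retryWithTimeout(ctx context.Context, attemptTimeout time.Duration, f func(ctx context.Context) error) error {
	b := &backoff.Backoff{Min: 100 * time.Millisecond, Max: time.Second, Factor: 2, Jitter: true}
	for {
		// Bound each attempt so a proposal that cannot reach a quorum times out
		// instead of blocking forever.
		attemptCtx, cancel := context.WithTimeout(ctx, attemptTimeout)
		err := f(attemptCtx)
		cancel()
		if err == nil {
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("giving up after retries: %w", err)
		case <-time.After(b.Duration()):
			// back off before the next attempt
		}
	}
}
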
30 changes: 18 additions & 12 deletions internal/raft/cluster_test.go
@@ -45,7 +45,7 @@ func (s *IntStateMachine) Close() error { return nil }
func TestCluster(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second))
defer cancel()
t.Cleanup(cancel)

members, err := local.FreeTCPAddresses(2)
assert.NoError(t, err)
@@ -64,8 +64,10 @@ func TestCluster(t *testing.T) {
wg.Go(func() error { return cluster1.Start(wctx) })
wg.Go(func() error { return cluster2.Start(wctx) })
assert.NoError(t, wg.Wait())
defer cluster1.Stop(ctx) //nolint:errcheck
defer cluster2.Stop(ctx) //nolint:errcheck
t.Cleanup(func() {
cluster1.Stop(ctx)
cluster2.Stop(ctx)
})

assert.NoError(t, shard1_1.Propose(ctx, IntEvent(1)))
assert.NoError(t, shard2_1.Propose(ctx, IntEvent(2)))
@@ -80,7 +82,7 @@ func TestJoiningExistingCluster(t *testing.T) {
func TestJoiningExistingCluster(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second))
defer cancel()
t.Cleanup(cancel)

members, err := local.FreeTCPAddresses(4)
assert.NoError(t, err)
@@ -97,8 +99,10 @@ func TestJoiningExistingCluster(t *testing.T) {
wg.Go(func() error { return cluster1.Start(wctx) })
wg.Go(func() error { return cluster2.Start(wctx) })
assert.NoError(t, wg.Wait())
defer cluster1.Stop(ctx) //nolint:errcheck
defer cluster2.Stop(ctx) //nolint:errcheck
t.Cleanup(func() {
cluster1.Stop(ctx)
cluster2.Stop(ctx)
})

t.Log("join to the existing cluster as a new member")
builder3 := testBuilder(t, nil, 3, members[2].String())
@@ -131,7 +135,7 @@ func TestLeavingCluster(t *testing.T) {
func TestLeavingCluster(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(20*time.Second))
defer cancel()
t.Cleanup(cancel)

members, err := local.FreeTCPAddresses(3)
assert.NoError(t, err)
@@ -151,16 +155,18 @@ func TestLeavingCluster(t *testing.T) {
wg.Go(func() error { return cluster2.Start(wctx) })
wg.Go(func() error { return cluster3.Start(wctx) })
assert.NoError(t, wg.Wait())
defer cluster1.Stop(ctx) //nolint:errcheck
defer cluster2.Stop(ctx) //nolint:errcheck
defer cluster3.Stop(ctx) //nolint:errcheck
t.Cleanup(func() {
cluster1.Stop(ctx)
cluster2.Stop(ctx)
cluster3.Stop(ctx)
})

t.Log("proposing event")
assert.NoError(t, shard1.Propose(ctx, IntEvent(1)))
assertShardValue(ctx, t, 1, shard1, shard2, shard3)

t.Log("removing member")
assert.NoError(t, cluster1.Stop(ctx))
cluster1.Stop(ctx)

t.Log("proposing event after removal")
assert.NoError(t, shard2.Propose(ctx, IntEvent(1)))
@@ -179,7 +185,7 @@ func testBuilder(t *testing.T, addresses []*net.TCPAddr, id uint64, address stri
DataDir: t.TempDir(),
InitialMembers: members,
HeartbeatRTT: 1,
ElectionTimeoutRTT: 10,
ElectionRTT: 5,
SnapshotEntries: 10,
CompactionOverhead: 10,
RTT: 10 * time.Millisecond,
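
The test changes above all follow the same shape: the per-test deadline's cancel and each cluster's Stop move from defer into t.Cleanup, and Stop is no longer error-checked because it no longer returns an error. One plausible reason for registering both through t.Cleanup is ordering: deferred calls run before registered cleanups, so a deferred cancel would cancel the context before a Stop registered via t.Cleanup could use it. Below is a hypothetical helper condensing that pattern; it starts clusters sequentially rather than with the errgroup the tests use, and the interface is a stand-in for the concrete Cluster type.

// Hypothetical helper illustrating the cleanup pattern the tests converge on:
// everything is registered via t.Cleanup, so cancel (registered first) runs
// last, after every cluster has been stopped.
package raftexample

import (
	"context"
	"testing"
	"time"
)

type stoppable interface {
	Start(ctx context.Context) error
	Stop(ctx context.Context)
}

func startAll(t *testing.T, clusters ...stoppable) context.Context {
	t.Helper()
	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(20*time.Second))
	t.Cleanup(cancel) // registered first, runs last
	for _, c := range clusters {
		c := c
		if err := c.Start(ctx); err != nil {
			t.Fatalf("failed to start cluster: %v", err)
		}
		t.Cleanup(func() { c.Stop(ctx) }) // runs before cancel
	}
	return ctx
}
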
8 changes: 5 additions & 3 deletions internal/raft/eventview_test.go
@@ -46,7 +46,7 @@ func (v *IntSumView) UnmarshalBinary(data []byte) error {
func TestEventView(t *testing.T) {
ctx := log.ContextWithNewDefaultLogger(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(60*time.Second))
defer cancel()
t.Cleanup(cancel)

members, err := local.FreeTCPAddresses(2)
assert.NoError(t, err)
@@ -63,8 +63,10 @@ func TestEventView(t *testing.T) {
eg.Go(func() error { return cluster1.Start(wctx) })
eg.Go(func() error { return cluster2.Start(wctx) })
assert.NoError(t, eg.Wait())
defer cluster1.Stop(ctx) //nolint:errcheck
defer cluster2.Stop(ctx) //nolint:errcheck
t.Cleanup(func() {
cluster1.Stop(ctx)
cluster2.Stop(ctx)
})

assert.NoError(t, view1.Publish(ctx, IntStreamEvent{Value: 1}))

