Skip to content

Commit

Permalink
chore: extract statemachines and state machine handles to a separate …
Browse files Browse the repository at this point in the history
…package (#3906)

This will make it easier to replace raft based implementations with in
memory versions for testing

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
jvmakine and github-actions[bot] authored Jan 6, 2025
1 parent 44fda1d commit 7eaf2aa
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 80 deletions.
5 changes: 3 additions & 2 deletions cmd/raft-tester/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/block/ftl/internal/log"
"github.com/block/ftl/internal/raft"
sm "github.com/block/ftl/internal/statemachine"
)

var cli struct {
Expand Down Expand Up @@ -57,7 +58,7 @@ func (j *joinCmd) Run() error {
return run(ctx, shard)
}

func run(ctx context.Context, shard *raft.ShardHandle[IntEvent, int64, int64]) error {
func run(ctx context.Context, shard sm.StateMachineHandle[int64, int64, IntEvent]) error {
messages := make(chan int)

wg, ctx := errgroup.WithContext(ctx)
Expand Down Expand Up @@ -85,7 +86,7 @@ func run(ctx context.Context, shard *raft.ShardHandle[IntEvent, int64, int64]) e
for {
select {
case msg := <-messages:
err := shard.Propose(ctx, IntEvent(msg))
err := shard.Update(ctx, IntEvent(msg))
if err != nil {
return fmt.Errorf("failed to propose event: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/raft-tester/statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"
"io"

"github.com/block/ftl/internal/raft"
sm "github.com/block/ftl/internal/statemachine"
)

type IntStateMachine struct {
Expand All @@ -23,7 +23,7 @@ func (i IntEvent) MarshalBinary() ([]byte, error) { //nolint:unparam
return binary.BigEndian.AppendUint64([]byte{}, uint64(i)), nil
}

var _ raft.StateMachine[int64, int64, IntEvent, *IntEvent] = &IntStateMachine{}
var _ sm.SnapshottingStateMachine[int64, int64, IntEvent] = &IntStateMachine{}

func (s IntStateMachine) Lookup(key int64) (int64, error) {
return s.sum, nil
Expand Down
31 changes: 14 additions & 17 deletions internal/raft/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/block/ftl/internal/log"
"github.com/block/ftl/internal/retry"
"github.com/block/ftl/internal/rpc"
sm "github.com/block/ftl/internal/statemachine"
)

type RaftConfig struct {
Expand Down Expand Up @@ -51,7 +52,7 @@ type Builder struct {
shards map[uint64]statemachine.CreateStateMachineFunc
controlClient *http.Client

handles []*ShardHandle[Event, any, any]
handles []*ShardHandle[sm.Marshallable, any, any]
}

func NewBuilder(cfg *RaftConfig) *Builder {
Expand All @@ -69,18 +70,18 @@ func (b *Builder) WithControlClient(client *http.Client) *Builder {
}

// AddShard adds a shard to the cluster Builder.
func AddShard[Q any, R any, E Event, EPtr Unmarshallable[E]](
func AddShard[Q any, R any, E sm.Marshallable, EPtr sm.Unmarshallable[E]](
ctx context.Context,
to *Builder,
shardID uint64,
sm StateMachine[Q, R, E, EPtr],
) *ShardHandle[E, Q, R] {
to.shards[shardID] = newStateMachineShim[Q, R, E, EPtr](sm)
statemachine sm.SnapshottingStateMachine[Q, R, E],
) sm.StateMachineHandle[Q, R, E] {
to.shards[shardID] = newStateMachineShim[Q, R, E, EPtr](statemachine)

handle := &ShardHandle[E, Q, R]{
shardID: shardID,
}
to.handles = append(to.handles, (*ShardHandle[Event, any, any])(handle))
to.handles = append(to.handles, (*ShardHandle[sm.Marshallable, any, any])(handle))
return handle
}

Expand Down Expand Up @@ -123,7 +124,7 @@ func (b *Builder) Build(ctx context.Context) *Cluster {
// E is the event type.
// Q is the query type.
// R is the query response type.
type ShardHandle[E Event, Q any, R any] struct {
type ShardHandle[E sm.Marshallable, Q any, R any] struct {
shardID uint64
cluster *Cluster
session *client.Session
Expand All @@ -133,14 +134,14 @@ type ShardHandle[E Event, Q any, R any] struct {
mu sync.Mutex
}

// Propose an event to the shard.
func (s *ShardHandle[E, Q, R]) Propose(ctx context.Context, msg E) error {
logger := log.FromContext(ctx).Scope("raft")

// Update the shard with an event.
func (s *ShardHandle[E, Q, R]) Update(ctx context.Context, msg E) error {
// client session is not thread safe, so we need to lock
s.mu.Lock()
defer s.mu.Unlock()

logger := log.FromContext(ctx).Scope("raft")

s.verifyReady()

msgBytes, err := msg.MarshalBinary()
Expand Down Expand Up @@ -207,11 +208,10 @@ func (s *ShardHandle[E, Q, R]) Changes(ctx context.Context, query Q) (chan R, er
logger := log.FromContext(ctx).Scope("raft")

// get the last known index as the starting point
reader, err := s.cluster.nh.GetLogReader(s.shardID)
last, err := s.getLastIndex()
if err != nil {
logger.Errorf(err, "failed to get log reader")
logger.Errorf(err, "failed to get last index")
}
_, last := reader.GetRange()
s.lastKnownIndex.Store(last)

go func() {
Expand Down Expand Up @@ -253,9 +253,6 @@ func (s *ShardHandle[E, Q, R]) Changes(ctx context.Context, query Q) (chan R, er
}

func (s *ShardHandle[E, Q, R]) getLastIndex() (uint64, error) {
s.mu.Lock()
defer s.mu.Unlock()

s.verifyReady()

reader, err := s.cluster.nh.GetLogReader(s.shardID)
Expand Down
33 changes: 17 additions & 16 deletions internal/raft/cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/block/ftl/internal/log"
"github.com/block/ftl/internal/raft"
"github.com/block/ftl/internal/retry"
sm "github.com/block/ftl/internal/statemachine"
"golang.org/x/sync/errgroup"
)

Expand All @@ -33,7 +34,7 @@ type IntStateMachine struct {
sum int64
}

var _ raft.StateMachine[int64, int64, IntEvent, *IntEvent] = &IntStateMachine{}
var _ sm.SnapshottingStateMachine[int64, int64, IntEvent] = &IntStateMachine{}

func (s *IntStateMachine) Update(event IntEvent) error {
s.sum += int64(event)
Expand All @@ -48,18 +49,18 @@ func (s *IntStateMachine) Close() error { return nil }
func TestClusterWith2Shards(t *testing.T) {
ctx := testContext(t)

_, shards := startClusters(ctx, t, 2, func(b *raft.Builder) []*raft.ShardHandle[IntEvent, int64, int64] {
return []*raft.ShardHandle[IntEvent, int64, int64]{
_, shards := startClusters(ctx, t, 2, func(b *raft.Builder) []sm.StateMachineHandle[int64, int64, IntEvent] {
return []sm.StateMachineHandle[int64, int64, IntEvent]{
raft.AddShard(ctx, b, 1, &IntStateMachine{}),
raft.AddShard(ctx, b, 2, &IntStateMachine{}),
}
})

assert.NoError(t, shards[0][0].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[1][0].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[0][0].Update(ctx, IntEvent(1)))
assert.NoError(t, shards[1][0].Update(ctx, IntEvent(1)))

assert.NoError(t, shards[0][1].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[1][1].Propose(ctx, IntEvent(2)))
assert.NoError(t, shards[0][1].Update(ctx, IntEvent(1)))
assert.NoError(t, shards[1][1].Update(ctx, IntEvent(2)))

assertShardValue(ctx, t, 2, shards[0][0], shards[1][0])
assertShardValue(ctx, t, 3, shards[0][1], shards[1][1])
Expand Down Expand Up @@ -102,7 +103,7 @@ func TestJoiningExistingCluster(t *testing.T) {
cluster3.Stop(ctx)
})

assert.NoError(t, shard3.Propose(ctx, IntEvent(1)))
assert.NoError(t, shard3.Update(ctx, IntEvent(1)))

assertShardValue(ctx, t, 1, shard1, shard2, shard3)

Expand All @@ -116,42 +117,42 @@ func TestJoiningExistingCluster(t *testing.T) {
cluster4.Stop(ctx)
})

assert.NoError(t, shard4.Propose(ctx, IntEvent(1)))
assert.NoError(t, shard4.Update(ctx, IntEvent(1)))

assertShardValue(ctx, t, 2, shard1, shard2, shard3, shard4)
}

func TestLeavingCluster(t *testing.T) {
ctx := testContext(t)

clusters, shards := startClusters(ctx, t, 3, func(b *raft.Builder) *raft.ShardHandle[IntEvent, int64, int64] {
clusters, shards := startClusters(ctx, t, 3, func(b *raft.Builder) sm.StateMachineHandle[int64, int64, IntEvent] {
return raft.AddShard(ctx, b, 1, &IntStateMachine{})
})

t.Log("proposing event")
assert.NoError(t, shards[0].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[0].Update(ctx, IntEvent(1)))
assertShardValue(ctx, t, 1, shards...)

t.Log("removing member")
clusters[0].Stop(ctx)

t.Log("proposing event after a member has been stopped")
assert.NoError(t, shards[1].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[1].Update(ctx, IntEvent(1)))
assertShardValue(ctx, t, 2, shards[1:]...)
}

func TestChanges(t *testing.T) {
ctx := testContext(t)

_, shards := startClusters(ctx, t, 2, func(b *raft.Builder) *raft.ShardHandle[IntEvent, int64, int64] {
_, shards := startClusters(ctx, t, 2, func(b *raft.Builder) sm.StateMachineHandle[int64, int64, IntEvent] {
return raft.AddShard(ctx, b, 1, &IntStateMachine{})
})

changes, err := shards[0].Changes(ctx, 0)
assert.NoError(t, err)

assert.NoError(t, shards[0].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[1].Propose(ctx, IntEvent(1)))
assert.NoError(t, shards[0].Update(ctx, IntEvent(1)))
assert.NoError(t, shards[1].Update(ctx, IntEvent(1)))

<-changes
assert.Equal(t, <-changes, 2)
Expand Down Expand Up @@ -215,7 +216,7 @@ func startClusters[T any](ctx context.Context, t *testing.T, count int, builderF
return clusters, result
}

func assertShardValue(ctx context.Context, t *testing.T, expected int64, shards ...*raft.ShardHandle[IntEvent, int64, int64]) {
func assertShardValue(ctx context.Context, t *testing.T, expected int64, shards ...sm.StateMachineHandle[int64, int64, IntEvent]) {
t.Helper()

for _, shard := range shards {
Expand Down
20 changes: 12 additions & 8 deletions internal/raft/eventview.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@ import (
"io"

"github.com/block/ftl/internal/eventstream"
sm "github.com/block/ftl/internal/statemachine"
)

type UnitQuery struct{}

type RaftStreamEvent[View encoding.BinaryMarshaler, VPtr Unmarshallable[View]] interface {
type RaftStreamEvent[View encoding.BinaryMarshaler, VPtr sm.Unmarshallable[View]] interface {
encoding.BinaryMarshaler
eventstream.Event[View]
}

type RaftEventView[V encoding.BinaryMarshaler, VPrt Unmarshallable[V], E RaftStreamEvent[V, VPrt]] struct {
shard *ShardHandle[E, UnitQuery, V]
type RaftEventView[V encoding.BinaryMarshaler, VPrt sm.Unmarshallable[V], E RaftStreamEvent[V, VPrt]] struct {
shard sm.StateMachineHandle[UnitQuery, V, E]
}

func (s *RaftEventView[V, VPrt, E]) Publish(ctx context.Context, event E) error {
return s.shard.Propose(ctx, event)
if err := s.shard.Update(ctx, event); err != nil {
return fmt.Errorf("failed to update shard: %w", err)
}
return nil
}

func (s *RaftEventView[V, VPrt, E]) View(ctx context.Context) (V, error) {
Expand All @@ -46,9 +50,9 @@ func (s *RaftEventView[V, VPrt, E]) Changes(ctx context.Context) (chan V, error)

type eventStreamStateMachine[
V encoding.BinaryMarshaler,
VPrt Unmarshallable[V],
VPrt sm.Unmarshallable[V],
E RaftStreamEvent[V, VPrt],
EPtr Unmarshallable[E],
EPtr sm.Unmarshallable[E],
] struct {
view V
}
Expand Down Expand Up @@ -96,9 +100,9 @@ func (s *eventStreamStateMachine[V, VPrt, E, EPtr]) Recover(reader io.Reader) er
// AddEventView to the Builder
func AddEventView[
V encoding.BinaryMarshaler,
VPtr Unmarshallable[V],
VPtr sm.Unmarshallable[V],
E RaftStreamEvent[V, VPtr],
EPtr Unmarshallable[E],
EPtr sm.Unmarshallable[E],
](ctx context.Context, builder *Builder, shardID uint64) eventstream.EventView[V, E] {
sm := &eventStreamStateMachine[V, VPtr, E, EPtr]{}
shard := AddShard[UnitQuery, V, E, EPtr](ctx, builder, shardID, sm)
Expand Down
41 changes: 6 additions & 35 deletions internal/raft/statemachine.go
Original file line number Diff line number Diff line change
@@ -1,50 +1,21 @@
package raft

import (
"encoding"
"fmt"
"io"

"github.com/lni/dragonboat/v4/statemachine"
)

// Event to update the state machine. These are stored in the Raft log.
type Event interface {
encoding.BinaryMarshaler
}

// Unmarshallable is a type that can be unmarshalled from a binary representation.
type Unmarshallable[T any] interface {
*T
encoding.BinaryUnmarshaler
}

// StateMachine is a typed interface to dragonboat's statemachine.IStateMachine.
// It is used to implement the state machine for a single shard.
//
// Q is the query type.
// R is the query response type.
// E is the event type.
type StateMachine[Q any, R any, E Event, EPtr Unmarshallable[E]] interface {
// Query the state of the state machine.
Lookup(key Q) (R, error)
// Update the state of the state machine.
Update(msg E) error
// Save the state of the state machine to a snapshot.
Save(writer io.Writer) error
// Recover the state of the state machine from a snapshot.
Recover(reader io.Reader) error
// Close the state machine.
Close() error
}
sm "github.com/block/ftl/internal/statemachine"
)

// stateMachineShim is a shim to convert a typed StateMachine to a dragonboat statemachine.IStateMachine.
type stateMachineShim[Q any, R any, E Event, EPtr Unmarshallable[E]] struct {
sm StateMachine[Q, R, E, EPtr]
type stateMachineShim[Q any, R any, E sm.Marshallable, EPtr sm.Unmarshallable[E]] struct {
sm sm.SnapshottingStateMachine[Q, R, E]
}

func newStateMachineShim[Q any, R any, E Event, EPtr Unmarshallable[E]](
sm StateMachine[Q, R, E, EPtr],
func newStateMachineShim[Q any, R any, E sm.Marshallable, EPtr sm.Unmarshallable[E]](
sm sm.SnapshottingStateMachine[Q, R, E],
) statemachine.CreateStateMachineFunc {
return func(clusterID uint64, nodeID uint64) statemachine.IStateMachine {
return &stateMachineShim[Q, R, E, EPtr]{sm: sm}
Expand Down
Loading

0 comments on commit 7eaf2aa

Please sign in to comment.