From 83b1a334c71ee8e12374166cf55f5fdab8cf704b Mon Sep 17 00:00:00 2001 From: Vytenis Darulis Date: Thu, 12 Nov 2020 23:20:42 -0500 Subject: [PATCH 1/2] [cluster] Store shards in sorted form --- src/cluster/shard/shard.go | 167 ++++++++++++--------- src/cluster/shard/shard_benchmark_test.go | 173 ++++++++++++++++++++++ src/cluster/shard/shard_test.go | 21 ++- 3 files changed, 287 insertions(+), 74 deletions(-) create mode 100644 src/cluster/shard/shard_benchmark_test.go diff --git a/src/cluster/shard/shard.go b/src/cluster/shard/shard.go index c6b95aa635..7d7f12ff5f 100644 --- a/src/cluster/shard/shard.go +++ b/src/cluster/shard/shard.go @@ -68,17 +68,19 @@ func (s State) Proto() (placementpb.ShardState, error) { func NewShard(id uint32) Shard { return &shard{id: id, state: Unknown} } // NewShardFromProto create a new shard from proto. -func NewShardFromProto(shard *placementpb.Shard) (Shard, error) { - state, err := NewShardStateFromProto(shard.State) +func NewShardFromProto(spb *placementpb.Shard) (Shard, error) { + state, err := NewShardStateFromProto(spb.State) if err != nil { return nil, err } - return NewShard(shard.Id). - SetState(state). - SetSourceID(shard.SourceId). - SetCutoverNanos(shard.CutoverNanos). - SetCutoffNanos(shard.CutoffNanos), nil + return &shard{ + id: spb.Id, + state: state, + sourceID: spb.SourceId, + cutoverNanos: spb.CutoverNanos, + cutoffNanos: spb.CutoffNanos, + }, nil } type shard struct { @@ -146,26 +148,26 @@ func (s *shard) Equals(other Shard) bool { } func (s *shard) Proto() (*placementpb.Shard, error) { - ss, err := s.State().Proto() + ss, err := s.state.Proto() if err != nil { return nil, err } return &placementpb.Shard{ - Id: s.ID(), + Id: s.id, State: ss, - SourceId: s.SourceID(), + SourceId: s.sourceID, CutoverNanos: s.cutoverNanos, CutoffNanos: s.cutoffNanos, }, nil } func (s *shard) Clone() Shard { - return NewShard(s.ID()). - SetState(s.State()). - SetSourceID(s.SourceID()). - SetCutoverNanos(s.CutoverNanos()). - SetCutoffNanos(s.CutoffNanos()) + if s == nil { + return nil + } + clone := *s + return &clone } // SortableShardsByIDAsc are sortable shards by ID in ascending order @@ -188,11 +190,20 @@ func (s SortableIDsAsc) Less(i, j int) bool { // NewShards creates a new instance of Shards func NewShards(ss []Shard) Shards { + shrd := make([]Shard, len(ss)) + copy(shrd, ss) + + sort.Sort(SortableShardsByIDAsc(shrd)) + shardMap := make(map[uint32]Shard, len(ss)) - for _, s := range ss { + for _, s := range shrd { shardMap[s.ID()] = s } - return shards{shardsMap: shardMap} + + return &shards{ + shards: shrd, + shardMap: shardMap, + } } // NewShardsFromProto creates a new set of shards from proto. @@ -209,52 +220,75 @@ func NewShardsFromProto(shards []*placementpb.Shard) (Shards, error) { } type shards struct { - shardsMap map[uint32]Shard + shards []Shard + shardMap map[uint32]Shard } -func (ss shards) All() []Shard { - shards := make([]Shard, 0, len(ss.shardsMap)) - for _, shard := range ss.shardsMap { - shards = append(shards, shard) - } - sort.Sort(SortableShardsByIDAsc(shards)) +func (ss *shards) All() []Shard { + shards := make([]Shard, len(ss.shards)) + copy(shards, ss.shards) + return shards } -func (ss shards) AllIDs() []uint32 { - ids := make([]uint32, 0, len(ss.shardsMap)) - for _, shard := range ss.shardsMap { - ids = append(ids, shard.ID()) +func (ss *shards) AllIDs() []uint32 { + shardIDs := make([]uint32, 0, len(ss.shards)) + for _, shrd := range ss.shards { + shardIDs = append(shardIDs, shrd.ID()) } - sort.Sort(SortableIDsAsc(ids)) - return ids + + return shardIDs } -func (ss shards) NumShards() int { - return len(ss.shardsMap) +func (ss *shards) NumShards() int { + return len(ss.shards) } -func (ss shards) Shard(id uint32) (Shard, bool) { - shard, ok := ss.shardsMap[id] - return shard, ok +func (ss *shards) Shard(id uint32) (Shard, bool) { + shard, ok := ss.shardMap[id] + if !ok { + return nil, false + } + + return shard, true } -func (ss shards) Add(shard Shard) { - ss.shardsMap[shard.ID()] = shard +func (ss *shards) Add(shard Shard) { + id := shard.ID() + i := sort.Search(len(ss.shards), func(i int) bool { return ss.shards[i].ID() >= id }) + if i < len(ss.shards) && ss.shards[i].ID() == id { + ss.shards[i] = shard + ss.shardMap[id] = shard + return + } + + ss.shards = append(ss.shards, shard) + ss.shardMap[id] = shard + + if i >= len(ss.shards)-1 { + return + } + + copy(ss.shards[i+1:], ss.shards[i:]) + ss.shards[i] = shard } -func (ss shards) Remove(shard uint32) { - delete(ss.shardsMap, shard) +func (ss *shards) Remove(id uint32) { + i := sort.Search(len(ss.shards), func(i int) bool { return ss.shards[i].ID() >= id }) + if i < len(ss.shards) && ss.shards[i].ID() == id { + delete(ss.shardMap, id) + ss.shards = ss.shards[:i+copy(ss.shards[i:], ss.shards[i+1:])] + } } -func (ss shards) Contains(shard uint32) bool { - _, ok := ss.shardsMap[shard] +func (ss *shards) Contains(shard uint32) bool { + _, ok := ss.shardMap[shard] return ok } -func (ss shards) NumShardsForState(state State) int { +func (ss *shards) NumShardsForState(state State) int { count := 0 - for _, s := range ss.shardsMap { + for _, s := range ss.shards { if s.State() == state { count++ } @@ -262,9 +296,9 @@ func (ss shards) NumShardsForState(state State) int { return count } -func (ss shards) ShardsForState(state State) []Shard { - var r []Shard - for _, s := range ss.shardsMap { +func (ss *shards) ShardsForState(state State) []Shard { + r := make([]Shard, 0, len(ss.shards)) + for _, s := range ss.shards { if s.State() == state { r = append(r, s) } @@ -272,14 +306,13 @@ func (ss shards) ShardsForState(state State) []Shard { return r } -func (ss shards) Equals(other Shards) bool { - shards := ss.All() - otherShards := other.All() - if len(shards) != len(otherShards) { +func (ss *shards) Equals(other Shards) bool { + if len(ss.shards) != other.NumShards() { return false } - for i, shard := range shards { + otherShards := other.All() + for i, shard := range ss.shards { otherShard := otherShards[i] if !shard.Equals(otherShard) { return false @@ -288,7 +321,7 @@ func (ss shards) Equals(other Shards) bool { return true } -func (ss shards) String() string { +func (ss *shards) String() string { var strs []string for _, state := range validStates() { ids := NewShards(ss.ShardsForState(state)).AllIDs() @@ -298,10 +331,9 @@ func (ss shards) String() string { return fmt.Sprintf("[%s]", strings.Join(strs, ", ")) } -func (ss shards) Proto() ([]*placementpb.Shard, error) { - res := make([]*placementpb.Shard, 0, len(ss.shardsMap)) - // All() returns the shards in ID ascending order. - for _, shard := range ss.All() { +func (ss *shards) Proto() ([]*placementpb.Shard, error) { + res := make([]*placementpb.Shard, 0, len(ss.shards)) + for _, shard := range ss.shards { sp, err := shard.Proto() if err != nil { return nil, err @@ -312,22 +344,21 @@ func (ss shards) Proto() ([]*placementpb.Shard, error) { return res, nil } -func (ss shards) Clone() Shards { - shards := make([]Shard, ss.NumShards()) - for i, shard := range ss.All() { - shards[i] = shard.Clone() +func (ss *shards) Clone() Shards { + shrds := make([]Shard, 0, len(ss.shards)) + shardMap := make(map[uint32]Shard, len(ss.shards)) + + for _, shrd := range ss.shards { + shrds = append(shrds, shrd.Clone()) + shardMap[shrd.ID()] = shrd } - return NewShards(shards) + return &shards{ + shards: shrds, + shardMap: shardMap, + } } -// SortableShardProtosByIDAsc sorts shard protos by their ids in ascending order. -type SortableShardProtosByIDAsc []*placementpb.Shard - -func (su SortableShardProtosByIDAsc) Len() int { return len(su) } -func (su SortableShardProtosByIDAsc) Less(i, j int) bool { return su[i].Id < su[j].Id } -func (su SortableShardProtosByIDAsc) Swap(i, j int) { su[i], su[j] = su[j], su[i] } - // validStates returns all the valid states. func validStates() []State { return []State{ diff --git a/src/cluster/shard/shard_benchmark_test.go b/src/cluster/shard/shard_benchmark_test.go new file mode 100644 index 0000000000..84e681cac8 --- /dev/null +++ b/src/cluster/shard/shard_benchmark_test.go @@ -0,0 +1,173 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package shard + +import ( + "fmt" + "math/rand" + "runtime" + "testing" +) + +func BenchmarkNewShards(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + for i := 0; i < b.N; i++ { + res := NewShards(rndShards) + runtime.KeepAlive(res) + } + }) + } +} + +func BenchmarkShardsAllShards(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res []Shard + for i := 0; i < b.N; i++ { + res = shards.All() + } + runtime.KeepAlive(res) + }) + } +} + +func BenchmarkShardsAllShardIDs(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res []uint32 + for i := 0; i < b.N; i++ { + res = shards.AllIDs() + } + runtime.KeepAlive(res) + }) + } +} + +func BenchmarkShardsAdd(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + for i := 0; i < b.N; i++ { + res := NewShards(nil) + for j := 0; j < len(rndShards); j++ { + res.Add(rndShards[j]) + } + if res.NumShards() != len(rndShards) { + b.Fail() + } + } + }) + } +} + +func BenchmarkShardsEquals(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + clone := shards.Clone() + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res bool + for i := 0; i < b.N; i++ { + shards.Equals(clone) + } + runtime.KeepAlive(res) + }) + } +} + +func BenchmarkShardsNumShardsForState(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res int + for i := 0; i < b.N; i++ { + res = shards.NumShardsForState(defaultShardState) + } + runtime.KeepAlive(res) + }) + } +} + +func BenchmarkShardsShard(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + ids := randomIDs(1, i*2) + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res Shard + for i := 0; i < b.N; i++ { + res, _ = shards.Shard(ids[i%len(ids)]) + } + runtime.KeepAlive(res) + }) + } +} + +func BenchmarkShardsContains(b *testing.B) { + for i := 16; i <= 4096; i *= 4 { + ids := randomIDs(1, i*2) + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + + b.Run(fmt.Sprintf("%d shards", i), func(b *testing.B) { + var res bool + for i := 0; i < b.N; i++ { + res = shards.Contains(ids[i%len(ids)]) + } + runtime.KeepAlive(res) + }) + } +} + +func randomIDs(seed int64, num int) []uint32 { + rnd := rand.New(rand.NewSource(seed)) // #nosec + ids := make([]uint32, num) + + for i := uint32(0); i < uint32(num); i++ { + ids[i] = i + } + + rnd.Shuffle(len(ids), func(i, j int) { + ids[i], ids[j] = ids[j], ids[i] + }) + return ids +} + +func makeTestShards(num int) []Shard { + shardIDs := randomIDs(0, num) + s := make([]Shard, num) + for i, shardID := range shardIDs { + s[i] = NewShard(shardID) + } + return s +} diff --git a/src/cluster/shard/shard_test.go b/src/cluster/shard/shard_test.go index 2a6016f265..c3c4335be1 100644 --- a/src/cluster/shard/shard_test.go +++ b/src/cluster/shard/shard_test.go @@ -22,7 +22,6 @@ package shard import ( "math" - "sort" "testing" "github.com/stretchr/testify/assert" @@ -38,7 +37,7 @@ func TestShard(t *testing.T) { assert.Equal(t, int64(100), s.CutoverNanos()) } -func TestShardEqualts(t *testing.T) { +func TestShardEquals(t *testing.T) { s := NewShard(1).SetState(Initializing).SetSourceID("id").SetCutoffNanos(1000).SetCutoverNanos(100) assert.False(t, s.Equals(NewShard(1).SetState(Initializing).SetSourceID("id").SetCutoffNanos(1000))) assert.False(t, s.Equals(NewShard(1).SetState(Initializing).SetSourceID("id").SetCutoverNanos(100))) @@ -87,6 +86,11 @@ func TestShards(t *testing.T) { shards.Add(NewShard(3).SetState(Leaving)) assert.Equal(t, "[Initializing=[1], Available=[2], Leaving=[3]]", shards.String()) + + shards.Add(NewShard(4)) + shards.Add(NewShard(4)) + shards.Add(NewShard(4)) + assert.Equal(t, 4, shards.NumShards()) } func TestShardsEquals(t *testing.T) { @@ -110,10 +114,15 @@ func TestSort(t *testing.T) { shards = append(shards, NewShard(1)) shards = append(shards, NewShard(2)) shards = append(shards, NewShard(0)) - sortable := SortableShardsByIDAsc(shards) - sort.Sort(sortable) - for i := range shards { - assert.Equal(t, uint32(i), shards[i].ID()) + shards = append(shards, NewShard(3)) + + var prev int = -1 + for _, shard := range NewShards(shards).All() { + id := int(shard.ID()) + if id <= prev { + t.Fatalf("expected id to be greater than %d, got %d", prev, id) + } + prev = id } } From 9532dee1a8c5028b0704f0068b8eb45838591612 Mon Sep 17 00:00:00 2001 From: Vytenis Darulis Date: Fri, 13 Nov 2020 17:13:53 -0500 Subject: [PATCH 2/2] address comments --- src/cluster/placement/algo/sharded_test.go | 4 +- src/cluster/shard/shard.go | 24 +++++-- src/cluster/shard/shard_benchmark_test.go | 24 ------- src/cluster/shard/shard_test.go | 78 +++++++++++++++++++--- 4 files changed, 89 insertions(+), 41 deletions(-) diff --git a/src/cluster/placement/algo/sharded_test.go b/src/cluster/placement/algo/sharded_test.go index 8d61736838..3f529e3fe0 100644 --- a/src/cluster/placement/algo/sharded_test.go +++ b/src/cluster/placement/algo/sharded_test.go @@ -1230,8 +1230,8 @@ func verifyAllShardsInAvailableState(t *testing.T, p placement.Placement) { for _, instance := range p.Instances() { s := instance.Shards() require.Equal(t, len(s.All()), len(s.ShardsForState(shard.Available))) - require.Nil(t, s.ShardsForState(shard.Initializing)) - require.Nil(t, s.ShardsForState(shard.Leaving)) + require.Empty(t, s.ShardsForState(shard.Initializing)) + require.Empty(t, s.ShardsForState(shard.Leaving)) } } diff --git a/src/cluster/shard/shard.go b/src/cluster/shard/shard.go index 7d7f12ff5f..315cbed3e8 100644 --- a/src/cluster/shard/shard.go +++ b/src/cluster/shard/shard.go @@ -190,18 +190,21 @@ func (s SortableIDsAsc) Less(i, j int) bool { // NewShards creates a new instance of Shards func NewShards(ss []Shard) Shards { - shrd := make([]Shard, len(ss)) - copy(shrd, ss) - - sort.Sort(SortableShardsByIDAsc(shrd)) - + // deduplicate first, last one wins shardMap := make(map[uint32]Shard, len(ss)) - for _, s := range shrd { + for _, s := range ss { shardMap[s.ID()] = s } + shrds := make([]Shard, 0, len(shardMap)) + for _, s := range shardMap { + shrds = append(shrds, s) + } + + sort.Sort(SortableShardsByIDAsc(shrds)) + return &shards{ - shards: shrd, + shards: shrds, shardMap: shardMap, } } @@ -255,6 +258,8 @@ func (ss *shards) Shard(id uint32) (Shard, bool) { func (ss *shards) Add(shard Shard) { id := shard.ID() + // we keep a sorted slice of shards, do a binary search to either find the index + // of an existing shard for replacement, or the target index position i := sort.Search(len(ss.shards), func(i int) bool { return ss.shards[i].ID() >= id }) if i < len(ss.shards) && ss.shards[i].ID() == id { ss.shards[i] = shard @@ -262,21 +267,26 @@ func (ss *shards) Add(shard Shard) { return } + // extend the sorted shard slice by 1 ss.shards = append(ss.shards, shard) ss.shardMap[id] = shard + // target position was at the end, so extending with the new shard was enough if i >= len(ss.shards)-1 { return } + // if not, copy over all slice elements shifted by 1 and overwrite data at index copy(ss.shards[i+1:], ss.shards[i:]) ss.shards[i] = shard } func (ss *shards) Remove(id uint32) { + // we keep a sorted slice of shards, do a binary search to find the index i := sort.Search(len(ss.shards), func(i int) bool { return ss.shards[i].ID() >= id }) if i < len(ss.shards) && ss.shards[i].ID() == id { delete(ss.shardMap, id) + // shift all other elements back after removal ss.shards = ss.shards[:i+copy(ss.shards[i:], ss.shards[i+1:])] } } diff --git a/src/cluster/shard/shard_benchmark_test.go b/src/cluster/shard/shard_benchmark_test.go index 84e681cac8..d61f745581 100644 --- a/src/cluster/shard/shard_benchmark_test.go +++ b/src/cluster/shard/shard_benchmark_test.go @@ -22,7 +22,6 @@ package shard import ( "fmt" - "math/rand" "runtime" "testing" ) @@ -148,26 +147,3 @@ func BenchmarkShardsContains(b *testing.B) { }) } } - -func randomIDs(seed int64, num int) []uint32 { - rnd := rand.New(rand.NewSource(seed)) // #nosec - ids := make([]uint32, num) - - for i := uint32(0); i < uint32(num); i++ { - ids[i] = i - } - - rnd.Shuffle(len(ids), func(i, j int) { - ids[i], ids[j] = ids[j], ids[i] - }) - return ids -} - -func makeTestShards(num int) []Shard { - shardIDs := randomIDs(0, num) - s := make([]Shard, num) - for i, shardID := range shardIDs { - s[i] = NewShard(shardID) - } - return s -} diff --git a/src/cluster/shard/shard_test.go b/src/cluster/shard/shard_test.go index c3c4335be1..62c0379426 100644 --- a/src/cluster/shard/shard_test.go +++ b/src/cluster/shard/shard_test.go @@ -22,6 +22,7 @@ package shard import ( "math" + "math/rand" "testing" "github.com/stretchr/testify/assert" @@ -116,14 +117,7 @@ func TestSort(t *testing.T) { shards = append(shards, NewShard(0)) shards = append(shards, NewShard(3)) - var prev int = -1 - for _, shard := range NewShards(shards).All() { - id := int(shard.ID()) - if id <= prev { - t.Fatalf("expected id to be greater than %d, got %d", prev, id) - } - prev = id - } + shardsAreSorted(t, NewShards(shards)) } func TestShardCutoverTimes(t *testing.T) { @@ -218,3 +212,71 @@ func TestClone(t *testing.T) { ss1.Add(NewShard(2).SetState(Leaving)) require.False(t, ss1.Equals(ss2)) } + +func TestShardAdd(t *testing.T) { + for i := 1; i < 500; i++ { + rndShards := makeTestShards(i) + shards := NewShards(nil) + for j := 0; j < len(rndShards); j++ { + id := rndShards[j].ID() + require.False(t, shards.Contains(id)) + + shards.Add(rndShards[j]) + require.True(t, shards.Contains(id)) + + shrd, ok := shards.Shard(id) + require.True(t, ok) + require.Equal(t, id, shrd.ID()) + } + shardsAreSorted(t, shards) + } +} + +func TestShardRemove(t *testing.T) { + for i := 1; i < 500; i++ { + rndShards := makeTestShards(i) + shards := NewShards(rndShards) + for j := 0; j < len(rndShards); j++ { + id := rndShards[j].ID() + require.True(t, shards.Contains(id)) + shards.Remove(id) + require.Equal(t, len(rndShards)-j-1, shards.NumShards()) + require.False(t, shards.Contains(id)) + } + shardsAreSorted(t, shards) + } +} + +func randomIDs(seed int64, num int) []uint32 { + rnd := rand.New(rand.NewSource(seed)) // #nosec + ids := make([]uint32, num) + + for i := uint32(0); i < uint32(num); i++ { + ids[i] = i + } + + rnd.Shuffle(len(ids), func(i, j int) { + ids[i], ids[j] = ids[j], ids[i] + }) + return ids +} + +func makeTestShards(num int) []Shard { + shardIDs := randomIDs(0, num) + s := make([]Shard, num) + for i, shardID := range shardIDs { + s[i] = NewShard(shardID) + } + return s +} + +func shardsAreSorted(t *testing.T, shards Shards) { + var prev int = -1 + for _, shard := range shards.All() { + id := int(shard.ID()) + if id <= prev { + t.Fatalf("expected id to be greater than %d, got %d", prev, id) + } + prev = id + } +}