Skip to content

Commit

Permalink
util/shuffle: improve shuffle code to not lock
Browse files Browse the repository at this point in the history
There was no reason to use the global rand with its mutex when shuffling.
Also, math/rand's algorithm is better than ours, so use it.

Release note: None
  • Loading branch information
ajwerner committed Oct 18, 2022
1 parent 51fb439 commit b8a318a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
20 changes: 15 additions & 5 deletions pkg/util/shuffle/shuffle.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

package shuffle

import "math/rand"
import (
"math/rand"
"sync"
"sync/atomic"
)

// Interface for shuffle. When it is satisfied, a collection can be shuffled by
// the routines in this package. The methods require that the elements of the
Expand All @@ -23,10 +27,16 @@ type Interface interface {
Swap(i, j int)
}

var (
	// seedSource is a counter that is atomically incremented every time the
	// pool has to construct a new generator, so that each generator is created
	// with a different seed.
	seedSource int64

	// randSyncPool hands out *rand.Rand instances for use by Shuffle. When the
	// pool has nothing to reuse, it builds a new generator seeded with the next
	// value of seedSource.
	randSyncPool = sync.Pool{
		New: func() interface{} {
			seed := atomic.AddInt64(&seedSource, 1)
			return rand.New(rand.NewSource(seed))
		},
	}
)

// Shuffle randomizes the order of the elements of data in place.
//
// The permutation is produced by a single call to (*rand.Rand).Shuffle on a
// generator borrowed from randSyncPool, which is returned to the pool when
// Shuffle exits. The package-level math/rand functions are deliberately not
// used here: they share one mutex-guarded source, so concurrent callers would
// contend on it.
func Shuffle(data Interface) {
	r := randSyncPool.Get().(*rand.Rand)
	defer randSyncPool.Put(r)
	r.Shuffle(data.Len(), data.Swap)
}
96 changes: 77 additions & 19 deletions pkg/util/shuffle/shuffle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
package shuffle

import (
"fmt"
"math/rand"
"reflect"
"sync"
"testing"
"unsafe"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
)
Expand All @@ -26,9 +29,15 @@ func (ts testSlice) Swap(i, j int) { ts[i], ts[j] = ts[j], ts[i] }

func TestShuffle(t *testing.T) {
defer leaktest.AfterTest(t)()
rand.Seed(0)
old := randSyncPool.New
defer func() { randSyncPool.New = old }()
r := rand.New(rand.NewSource(0))
randSyncPool.New = func() interface{} {
return r
}

verify := func(original, expected testSlice) {
t.Helper()
Shuffle(original)
if !reflect.DeepEqual(original, expected) {
t.Errorf("expected %v, got %v", expected, original)
Expand All @@ -44,33 +53,82 @@ func TestShuffle(t *testing.T) {
verify(ts, testSlice{1})

ts = testSlice{1, 2}
verify(ts, testSlice{2, 1})
verify(ts, testSlice{1, 2})
verify(ts, testSlice{2, 1})

ts = testSlice{1, 2, 3}
verify(ts, testSlice{3, 1, 2})
verify(ts, testSlice{2, 3, 1})
verify(ts, testSlice{1, 3, 2})
verify(ts, testSlice{1, 2, 3})
verify(ts, testSlice{1, 2, 3})
verify(ts, testSlice{3, 1, 2})

ts = testSlice{1, 2, 3, 4, 5}
verify(ts, testSlice{2, 1, 3, 5, 4})
verify(ts, testSlice{4, 2, 1, 5, 3})
verify(ts, testSlice{1, 4, 2, 3, 5})
verify(ts, testSlice{2, 5, 4, 1, 3})
verify(ts, testSlice{4, 2, 3, 1, 5})
verify(ts, testSlice{5, 1, 3, 4, 2})
verify(ts, testSlice{2, 5, 3, 1, 4})
verify(ts, testSlice{3, 2, 5, 4, 1})
verify(ts, testSlice{1, 2, 4, 3, 5})
verify(ts, testSlice{3, 1, 5, 2, 4})

verify(ts[2:2], testSlice{})
verify(ts[0:0], testSlice{})
verify(ts[5:5], testSlice{})
verify(ts[3:5], testSlice{1, 5})
verify(ts[3:5], testSlice{5, 1})
verify(ts[0:2], testSlice{4, 2})
verify(ts[0:2], testSlice{2, 4})
verify(ts[1:4], testSlice{3, 5, 4})
verify(ts[1:4], testSlice{5, 4, 3})
verify(ts[0:4], testSlice{4, 5, 2, 3})
verify(ts[0:4], testSlice{2, 4, 3, 5})

verify(ts, testSlice{1, 3, 4, 2, 5})
verify(ts[3:5], testSlice{4, 2})
verify(ts[3:5], testSlice{2, 4})
verify(ts[0:2], testSlice{1, 3})
verify(ts[0:2], testSlice{1, 3})
verify(ts[1:4], testSlice{5, 2, 3})
verify(ts[1:4], testSlice{3, 5, 2})
verify(ts[0:4], testSlice{2, 3, 1, 5})
verify(ts[0:4], testSlice{5, 3, 1, 2})

verify(ts, testSlice{2, 3, 1, 5, 4})
}

// ints is an []int with the Len and Swap methods needed to hand it to Shuffle.
type ints []int

// Len reports how many elements the slice holds.
func (s ints) Len() int {
	return len(s)
}

// Swap exchanges the elements at positions a and b.
func (s ints) Swap(a, b int) {
	s[a], s[b] = s[b], s[a]
}

// BenchmarkConcurrentShuffle runs Shuffle from several goroutines at once, for
// a range of goroutine counts and slice sizes, to demonstrate that Shuffle
// scales with cores. Once upon a time, it did not.
func BenchmarkConcurrentShuffle(b *testing.B) {
	concurrencies := []int{1, 4, 8}
	sizes := []int{1 << 7, 1 << 10, 1 << 13}
	for _, concurrency := range concurrencies {
		concurrency := concurrency
		b.Run(fmt.Sprintf("concurrency=%d", concurrency), func(b *testing.B) {
			for _, size := range sizes {
				size := size
				b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
					runConcurrentShuffle(b, concurrency, size)
				})
			}
		})
	}
}

// runConcurrentShuffle performs b.N shuffles of size-element slices, split
// across concurrency goroutines that each own a separate slice.
func runConcurrentShuffle(b *testing.B, concurrency, size int) {
	// Report throughput in terms of the bytes occupied by one shuffled slice.
	b.SetBytes(int64(size * int(unsafe.Sizeof(0))))

	// One independent buffer per goroutine so the workers never share data.
	bufs := make([]ints, concurrency)
	for i := range bufs {
		bufs[i] = rand.Perm(size)
	}
	iterations := distribute(b.N, concurrency)

	var wg sync.WaitGroup
	wg.Add(concurrency)
	// Exclude the setup above from the measurement.
	b.ResetTimer()
	for i := range bufs {
		go func(buf *ints, n int) {
			defer wg.Done()
			for ; n > 0; n-- {
				Shuffle(buf)
			}
		}(&bufs[i], iterations[i])
	}
	wg.Wait()
}

// distribute splits <total> into <num> integer shares that add up to <total>
// and are within +/-1 of each other.
func distribute(total, num int) []int {
	shares := make([]int, num)
	remaining := total
	for idx, left := 0, num; idx < num; idx, left = idx+1, left-1 {
		// Give this slot the rounded average of what is still unassigned over
		// the slots that have not been filled yet.
		share := (remaining + left/2) / left
		shares[idx] = share
		remaining -= share
	}
	return shares
}

0 comments on commit b8a318a

Please sign in to comment.