[DBNode Client] Fix race in V2 batching APIs + integration test (#1991)
Richard Artoul authored Oct 11, 2019
1 parent 57eba0c commit 5a830ae
Showing 4 changed files with 189 additions and 16 deletions.
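
Background for the host_queue.go change below: a write op is shared by every host queue that holds a replica of the series, so each queue's drain goroutine was stamping its own namespace index onto the same requestV2 struct, which is the race the commit title refers to. The sketch below illustrates the copy-before-mutate pattern the fix adopts in drainWriteOpV2/drainTaggedWriteOpV2; the element and drain names and the pool-free batching are illustrative stand-ins, not the real client types:

package main

import (
	"fmt"
	"sync"
)

// element stands in for the per-op V2 write request; in the real client this
// is v.requestV2 on an op shared by several host queues.
type element struct {
	NameSpace int64
	ID        string
}

// drain models one host queue's drain loop building its batch. It copies the
// shared element before setting the queue-local namespace index, mirroring
// requestCopy := v.requestV2 in the fixed drainWriteOpV2.
func drain(shared *element, nsIdx int64, batch *[]*element) {
	cp := *shared // copy first, never mutate the shared op
	cp.NameSpace = nsIdx
	*batch = append(*batch, &cp)
}

func main() {
	shared := &element{ID: "foo"}

	var wg sync.WaitGroup
	batches := make([][]*element, 3)
	for q := 0; q < 3; q++ {
		q := q
		wg.Add(1)
		go func() {
			defer wg.Done()
			drain(shared, int64(q), &batches[q])
		}()
	}
	wg.Wait()

	for q, b := range batches {
		fmt.Println("queue", q, "namespace index", b[0].NameSpace)
	}
}

With the copy, each queue only ever observes the namespace index it set itself; the pre-fix pattern of writing shared.NameSpace directly inside drain would be flagged by the race detector.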
27 changes: 19 additions & 8 deletions src/dbnode/client/host_queue.go
@@ -412,8 +412,13 @@ func (q *queue) drainWriteOpV2(
currV2WriteReq.NameSpaces = append(currV2WriteReq.NameSpaces, namespace.Bytes())
nsIdx = len(currV2WriteReq.NameSpaces) - 1
}
- v.requestV2.NameSpace = int64(nsIdx)
- currV2WriteReq.Elements = append(currV2WriteReq.Elements, &v.requestV2)
+
+ // Copy the request because operations are shared across multiple host queues so mutating
+ // them directly is racey.
+ // TODO(rartoul): Consider adding a pool for this.
+ requestCopy := v.requestV2
+ requestCopy.NameSpace = int64(nsIdx)
+ currV2WriteReq.Elements = append(currV2WriteReq.Elements, &requestCopy)
currV2WriteOps = append(currV2WriteOps, op)
if len(currV2WriteReq.Elements) == q.opts.WriteBatchSize() {
// Reached write batch limit, write async and reset.
@@ -448,8 +453,13 @@ func (q *queue) drainTaggedWriteOpV2(
currV2WriteTaggedReq.NameSpaces = append(currV2WriteTaggedReq.NameSpaces, namespace.Bytes())
nsIdx = len(currV2WriteTaggedReq.NameSpaces) - 1
}
- v.requestV2.NameSpace = int64(nsIdx)
- currV2WriteTaggedReq.Elements = append(currV2WriteTaggedReq.Elements, &v.requestV2)
+
+ // Copy the request because operations are shared across multiple host queues so mutating
+ // them directly is racey.
+ // TODO(rartoul): Consider adding a pool for this.
+ requestCopy := v.requestV2
+ requestCopy.NameSpace = int64(nsIdx)
+ currV2WriteTaggedReq.Elements = append(currV2WriteTaggedReq.Elements, &requestCopy)
currV2WriteTaggedOps = append(currV2WriteTaggedOps, op)
if len(currV2WriteTaggedReq.Elements) == q.opts.WriteBatchSize() {
// Reached write batch limit, write async and reset.
@@ -485,6 +495,7 @@ func (q *queue) drainFetchBatchRawV2Op(
nsIdx = len(currV2FetchBatchRawReq.NameSpaces) - 1
}
for i := range v.requestV2Elements {
+ // Each host queue gets its own fetchBatchOp so mutating the NameSpace field here is safe.
v.requestV2Elements[i].NameSpace = int64(nsIdx)
currV2FetchBatchRawReq.Elements = append(currV2FetchBatchRawReq.Elements, &v.requestV2Elements[i])
}
@@ -515,8 +526,8 @@ func (q *queue) asyncTaggedWrite(

// NB(r): Defer is slow in the hot path unfortunately
cleanup := func() {
- q.writeTaggedBatchRawRequestPool.Put(req)
q.writeTaggedBatchRawRequestElementArrayPool.Put(elems)
+ q.writeTaggedBatchRawRequestPool.Put(req)
q.opsArrayPool.Put(ops)
q.Done()
}
@@ -576,8 +587,8 @@ func (q *queue) asyncTaggedWriteV2(
q.workerPool.Go(func() {
// NB(r): Defer is slow in the hot path unfortunately
cleanup := func() {
- q.writeTaggedBatchRawV2RequestPool.Put(req)
q.writeTaggedBatchRawV2RequestElementArrayPool.Put(req.Elements)
+ q.writeTaggedBatchRawV2RequestPool.Put(req)
q.opsArrayPool.Put(ops)
q.Done()
}
@@ -640,8 +651,8 @@ func (q *queue) asyncWrite(

// NB(r): Defer is slow in the hot path unfortunately
cleanup := func() {
- q.writeBatchRawRequestPool.Put(req)
q.writeBatchRawRequestElementArrayPool.Put(elems)
+ q.writeBatchRawRequestPool.Put(req)
q.opsArrayPool.Put(ops)
q.Done()
}
@@ -700,8 +711,8 @@ func (q *queue) asyncWriteV2(
q.workerPool.Go(func() {
// NB(r): Defer is slow in the hot path unfortunately
cleanup := func() {
- q.writeBatchRawV2RequestPool.Put(req)
q.writeBatchRawV2RequestElementArrayPool.Put(req.Elements)
+ q.writeBatchRawV2RequestPool.Put(req)
q.opsArrayPool.Put(ops)
q.Done()
}
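The reordered cleanup functions above are the other half of the fix. In the V2 paths the elements array lives on req itself, so once req has been Put back into its pool another goroutine can Get it and begin reusing it, and the subsequent read of req.Elements races with that reuse; returning the elements first keeps every read of req ahead of the handoff, and the V1 paths get the same ordering. A minimal sync.Pool sketch of that ordering, with made-up batchReq and pool names rather than the client's real writeBatchRaw*Request pools:

package main

import "sync"

// batchReq stands in for a pooled batch request that carries a separately
// pooled Elements slice, like rpc.WriteBatchRawV2Request in the client.
type batchReq struct {
	Elements []int
}

var (
	reqPool   = sync.Pool{New: func() interface{} { return &batchReq{} }}
	elemsPool = sync.Pool{New: func() interface{} { return make([]int, 0, 128) }}
)

// cleanup mirrors the fixed ordering in asyncWriteV2/asyncTaggedWriteV2:
// everything reachable through req is returned to its pool before req itself,
// so no field of req is read after another goroutine may have taken it.
func cleanup(req *batchReq) {
	elemsPool.Put(req.Elements[:0]) // read req.Elements while req is still ours
	req.Elements = nil
	reqPool.Put(req) // hand req back to its pool last
}

func main() {
	req := reqPool.Get().(*batchReq)
	req.Elements = append(elemsPool.Get().([]int), 1, 2, 3)
	cleanup(req)
}

The unsafe variant is simply the two Puts swapped: reqPool.Put(req) followed by elemsPool.Put(req.Elements) reads a field of an object that may already belong to another goroutine.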
2 changes: 2 additions & 0 deletions src/dbnode/client/host_queue_write_batch_test.go
@@ -188,6 +188,8 @@ func TestHostQueueWriteBatchesDifferentNamespaces(t *testing.T) {

if opts.UseV2BatchAPIs() {
writeBatch := func(ctx thrift.Context, req *rpc.WriteBatchRawV2Request) {
+ assert.Equal(t, 2, len(req.NameSpaces))
+ assert.Equal(t, len(writes), len(req.Elements))
for i, write := range writes {
if i < 3 {
assert.Equal(t, req.Elements[i].NameSpace, int64(0))
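The two new assertions pin down the V2 batch layout that the drain functions build: each distinct namespace appears exactly once in req.NameSpaces and every element refers to its namespace by index. A toy version of that grouping, using illustrative local types rather than the real rpc.WriteBatchRawV2Request:

package main

import "fmt"

// v2Request models the namespace-deduplicating layout asserted by the test:
// NameSpaces holds each namespace once, Elements reference it by index.
type v2Request struct {
	NameSpaces []string
	Elements   []v2Element
}

type v2Element struct {
	NameSpace int64 // index into v2Request.NameSpaces
	ID        string
}

type write struct {
	namespace, id string
}

func buildRequest(writes []write) v2Request {
	var (
		req     v2Request
		nsIndex = map[string]int64{}
	)
	for _, w := range writes {
		idx, ok := nsIndex[w.namespace]
		if !ok {
			// First time this namespace is seen: record it once and remember its index.
			req.NameSpaces = append(req.NameSpaces, w.namespace)
			idx = int64(len(req.NameSpaces) - 1)
			nsIndex[w.namespace] = idx
		}
		req.Elements = append(req.Elements, v2Element{NameSpace: idx, ID: w.id})
	}
	return req
}

func main() {
	req := buildRequest([]write{
		{"ns1", "a"}, {"ns1", "b"}, {"ns1", "c"},
		{"ns2", "d"}, {"ns2", "e"},
	})
	fmt.Println(len(req.NameSpaces), len(req.Elements)) // 2 5
}

With three ns1 writes followed by two ns2 writes the output is 2 5: two namespaces and one element per write, matching the strengthened assertions and the per-element index checks that follow them in the test.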
17 changes: 9 additions & 8 deletions src/dbnode/integration/fetch_tagged_quorum_test.go
@@ -249,15 +249,15 @@ func makeMultiNodeSetup(
asyncInserts bool,
instances []services.ServiceInstance,
) (testSetups, closeFn, client.Options) {
- var (
- nsOpts  = namespace.NewOptions()
- md, err = namespace.NewMetadata(testNamespaces[0],
- nsOpts.SetRetentionOptions(nsOpts.RetentionOptions().SetRetentionPeriod(6*time.Hour)).
- SetIndexOptions(namespace.NewIndexOptions().SetEnabled(indexingEnabled)))
- )
+ nsOpts := namespace.NewOptions()
+ nsOpts = nsOpts.SetRetentionOptions(nsOpts.RetentionOptions().SetRetentionPeriod(6 * time.Hour)).
+ SetIndexOptions(namespace.NewIndexOptions().SetEnabled(indexingEnabled))
+ md1, err := namespace.NewMetadata(testNamespaces[0], nsOpts)
require.NoError(t, err)
+ md2, err := namespace.NewMetadata(testNamespaces[1], nsOpts)
+ require.NoError(t, err)

- nspaces := []namespace.Metadata{md}
+ nspaces := []namespace.Metadata{md1, md2}
nodes, topoInit, closeFn := newNodes(t, numShards, instances, nspaces, asyncInserts)
for _, node := range nodes {
node.opts = node.opts.SetNumShards(numShards)
@@ -268,7 +268,8 @@ func makeMultiNodeSetup(
SetClusterConnectTimeout(2 * time.Second).
SetWriteRequestTimeout(2 * time.Second).
SetFetchRequestTimeout(2 * time.Second).
- SetTopologyInitializer(topoInit)
+ SetTopologyInitializer(topoInit).
+ SetUseV2BatchAPIs(true)

return nodes, closeFn, clientopts
}
159 changes: 159 additions & 0 deletions src/dbnode/integration/write_read_high_concurrency_test.go
@@ -0,0 +1,159 @@
// +build integration
//
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package integration

import (
"fmt"
"sync"
"testing"
"time"

"github.com/m3db/m3/src/cluster/services"
"github.com/m3db/m3/src/cluster/shard"
"github.com/m3db/m3/src/dbnode/client"
"github.com/m3db/m3/src/dbnode/topology"
xclock "github.com/m3db/m3/src/x/clock"
"github.com/m3db/m3/src/x/ident"
xtime "github.com/m3db/m3/src/x/time"

"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

// TestWriteReadHighConcurrencyTestMultiNS stress tests the concurrent write and read pathways in M3DB by spinning
// up hundreds of goroutines that all write to and read from M3DB. It was added as a regression test to catch bugs in
// the M3DB client batching logic and lifecycles, but it is useful for detecting various kinds of concurrency issues
// at the integration level.
func TestWriteReadHighConcurrencyTestMultiNS(t *testing.T) {
if testing.Short() {
t.SkipNow() // Just skip if we're doing a short run
}
var (
concurrency = 100
writeEach = 1000
numShards = defaultNumShards
minShard = uint32(0)
maxShard = uint32(numShards - 1)
)
nodes, closeFn, clientopts := makeMultiNodeSetup(t, numShards, true, true, []services.ServiceInstance{
node(t, 0, newClusterShardsRange(minShard, maxShard, shard.Available)),
node(t, 1, newClusterShardsRange(minShard, maxShard, shard.Available)),
node(t, 2, newClusterShardsRange(minShard, maxShard, shard.Available)),
})
clientopts = clientopts.
SetWriteConsistencyLevel(topology.ConsistencyLevelAll).
SetReadConsistencyLevel(topology.ReadConsistencyLevelAll)

defer closeFn()
log := nodes[0].storageOpts.InstrumentOptions().Logger()
for _, n := range nodes {
require.NoError(t, n.startServer())
}

c, err := client.NewClient(clientopts)
require.NoError(t, err)
session, err := c.NewSession()
require.NoError(t, err)
defer session.Close()

var (
insertWg sync.WaitGroup
)
now := nodes[0].db.Options().ClockOptions().NowFn()()
start := time.Now()
log.Info("starting data write")

newNs1GenIDs := func(idx int) func(j int) ident.ID {
return func(j int) ident.ID {
id, _ := genIDTags(idx, j, 0)
return id
}
}
newNs2GenIDs := func(idx int) func(j int) ident.ID {
return func(j int) ident.ID {
id, _ := genIDTags(concurrency+idx, writeEach+j, 0)
return id
}
}
for i := 0; i < concurrency; i++ {
insertWg.Add(2)
idx := i
ns1GenIDs := newNs1GenIDs(idx)
ns2GenIDs := newNs2GenIDs(idx)
go func() {
defer insertWg.Done()
for j := 0; j < writeEach; j++ {
id := ns1GenIDs(j)
err := session.Write(testNamespaces[0], id, now, float64(1.0), xtime.Second, nil)
if err != nil {
panic(err)
}
}
}()
go func() {
defer insertWg.Done()
for j := 0; j < writeEach; j++ {
id := ns2GenIDs(j)
err := session.Write(testNamespaces[1], id, now, float64(1.0), xtime.Second, nil)
if err != nil {
panic(err)
}
}
}()
}

insertWg.Wait()
log.Info("test data written", zap.Duration("took", time.Since(start)))

var (
fetchWg sync.WaitGroup
)
for i := 0; i < concurrency; i++ {
fetchWg.Add(2)
idx := i
verify := func(genID func(j int) ident.ID, ns ident.ID) {
defer fetchWg.Done()
for j := 0; j < writeEach; j++ {
id := genID(j)
found := xclock.WaitUntil(func() bool {
iter, err := session.Fetch(ns, id, now.Add(-time.Hour), now.Add(time.Hour))
if err != nil {
panic(err)
}
if !iter.Next() {
return false
}
return true

}, 10*time.Second)
if !found {
panic(fmt.Sprintf("timed out waiting to fetch id: %s", id))
}
}
}
go verify(newNs1GenIDs(idx), testNamespaces[0])
go verify(newNs2GenIDs(idx), testNamespaces[1])
}
fetchWg.Wait()
log.Info("data is readable", zap.Duration("took", time.Since(start)))
}
