Skip to content

Commit

Permalink
[dbnode] Client borrow connection API (#3019)
Browse files Browse the repository at this point in the history
  • Loading branch information
robskillington authored Dec 22, 2020
1 parent 2752154 commit 31a21e3
Show file tree
Hide file tree
Showing 16 changed files with 250 additions and 106 deletions.
43 changes: 37 additions & 6 deletions src/dbnode/client/client_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 16 additions & 10 deletions src/dbnode/client/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ import (
"sync/atomic"
"time"

"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
xresource "github.com/m3db/m3/src/x/resource"
murmur3 "github.com/m3db/stackmurmur3/v2"

"github.com/uber-go/tally"
"github.com/uber/tchannel-go"
"github.com/uber/tchannel-go/thrift"
"go.uber.org/zap"

"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
)

const (
Expand Down Expand Up @@ -67,15 +67,21 @@ type connPool struct {
healthStatus tally.Gauge
}

// PooledChannel is a tchannel.Channel for a pooled connection.
type PooledChannel interface {
GetSubChannel(serviceName string, opts ...tchannel.SubChannelOption) *tchannel.SubChannel
Close()
}

type conn struct {
channel xresource.SimpleCloser
channel PooledChannel
client rpc.TChanNode
}

// NewConnectionFn is a function that creates a connection.
type NewConnectionFn func(
channelName string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error)
) (PooledChannel, rpc.TChanNode, error)

type healthCheckFn func(client rpc.TChanNode, opts Options) error

Expand Down Expand Up @@ -134,20 +140,20 @@ func (p *connPool) ConnectionCount() int {
return int(poolLen)
}

func (p *connPool) NextClient() (rpc.TChanNode, error) {
func (p *connPool) NextClient() (rpc.TChanNode, PooledChannel, error) {
p.RLock()
if p.status != statusOpen {
p.RUnlock()
return nil, errConnectionPoolClosed
return nil, nil, errConnectionPoolClosed
}
if p.poolLen < 1 {
p.RUnlock()
return nil, errConnectionPoolHasNoConnections
return nil, nil, errConnectionPoolHasNoConnections
}
n := atomic.AddInt64(&p.used, 1)
conn := p.pool[n%p.poolLen]
p.RUnlock()
return conn.client, nil
return conn.client, conn.channel, nil
}

func (p *connPool) Close() {
Expand Down
39 changes: 22 additions & 17 deletions src/dbnode/client/connection_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (
"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
"github.com/m3db/m3/src/dbnode/topology"
xclock "github.com/m3db/m3/src/x/clock"
xresource "github.com/m3db/m3/src/x/resource"
"github.com/stretchr/testify/require"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/uber/tchannel-go"
)

const (
Expand All @@ -42,10 +42,19 @@ const (
)

var (
h = topology.NewHost(testHostStr, testHostAddr)
channelNone = &nullChannel{}
h = topology.NewHost(testHostStr, testHostAddr)
)

type noopPooledChannel struct{}

func (c *noopPooledChannel) Close() {}
func (c *noopPooledChannel) GetSubChannel(
serviceName string,
opts ...tchannel.SubChannelOption,
) *tchannel.SubChannel {
return nil
}

func newConnectionPoolTestOptions() Options {
return newSessionTestOptions().
SetBackgroundConnectInterval(5 * time.Millisecond).
Expand Down Expand Up @@ -85,12 +94,12 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) {

fn := func(
ch string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error) {
) (PooledChannel, rpc.TChanNode, error) {
attempt := int(atomic.AddInt32(&attempts, 1))
if attempt == 1 {
return nil, nil, fmt.Errorf("a connect error")
}
return channelNone, nil, nil
return &noopPooledChannel{}, nil, nil
}

opts = opts.SetNewConnectionFn(fn)
Expand Down Expand Up @@ -151,7 +160,7 @@ func TestConnectionPoolConnectsAndRetriesConnects(t *testing.T) {
conns.Close()
doneWg.Done()

nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolClosed, err)
}
Expand Down Expand Up @@ -237,12 +246,12 @@ func TestConnectionPoolHealthChecks(t *testing.T) {

fn := func(
ch string, addr string, opts Options,
) (xresource.SimpleCloser, rpc.TChanNode, error) {
) (PooledChannel, rpc.TChanNode, error) {
attempt := atomic.AddInt32(&newConnAttempt, 1)
if attempt == 1 {
return channelNone, client1, nil
return &noopPooledChannel{}, client1, nil
} else if attempt == 2 {
return channelNone, client2, nil
return &noopPooledChannel{}, client2, nil
}
return nil, nil, fmt.Errorf("spawning only 2 connections")
}
Expand Down Expand Up @@ -307,7 +316,7 @@ func TestConnectionPoolHealthChecks(t *testing.T) {
return conns.ConnectionCount() == 1
}, 5*time.Second)
for i := 0; i < 2; i++ {
nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.NoError(t, err)
require.Equal(t, client2, nextClient)
}
Expand All @@ -324,17 +333,13 @@ func TestConnectionPoolHealthChecks(t *testing.T) {
// and the connection actually being removed.
return conns.ConnectionCount() == 0
}, 5*time.Second)
nextClient, err := conns.NextClient()
nextClient, _, err := conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolHasNoConnections, err)

conns.Close()

nextClient, err = conns.NextClient()
nextClient, _, err = conns.NextClient()
require.Nil(t, nextClient)
require.Equal(t, errConnectionPoolClosed, err)
}

type nullChannel struct{}

func (*nullChannel) Close() {}
24 changes: 12 additions & 12 deletions src/dbnode/client/host_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func (q *queue) asyncTaggedWrite(
// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -591,7 +591,7 @@ func (q *queue) asyncTaggedWriteV2(

// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to.
client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -656,7 +656,7 @@ func (q *queue) asyncWrite(
// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -715,7 +715,7 @@ func (q *queue) asyncWriteV2(

// NB(bl): host is passed to writeState to determine the state of the
// shard on the node we're writing to.
client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available.
callAllCompletionFns(ops, q.host, err)
Expand Down Expand Up @@ -768,7 +768,7 @@ func (q *queue) asyncFetch(op *fetchBatchOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.completeAll(nil, err)
Expand Down Expand Up @@ -821,7 +821,7 @@ func (q *queue) asyncFetchV2(
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available.
callAllCompletionFns(ops, nil, err)
Expand Down Expand Up @@ -868,7 +868,7 @@ func (q *queue) asyncFetchTagged(op *fetchTaggedOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.CompletionFn()(fetchTaggedResultAccumulatorOpts{host: q.host}, err)
Expand Down Expand Up @@ -901,7 +901,7 @@ func (q *queue) asyncAggregate(op *aggregateOp) {
q.Done()
}

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.CompletionFn()(aggregateResultAccumulatorOpts{host: q.host}, err)
Expand Down Expand Up @@ -931,7 +931,7 @@ func (q *queue) asyncTruncate(op *truncateOp) {
q.workerPool.Go(func() {
cleanup := q.Done

client, err := q.connPool.NextClient()
client, _, err := q.connPool.NextClient()
if err != nil {
// No client available
op.completionFn(nil, err)
Expand Down Expand Up @@ -1003,7 +1003,7 @@ func (q *queue) ConnectionPool() connectionPool {
return q.connPool
}

func (q *queue) BorrowConnection(fn withConnectionFn) error {
func (q *queue) BorrowConnection(fn WithConnectionFn) error {
q.RLock()
if q.status != statusOpen {
q.RUnlock()
Expand All @@ -1014,12 +1014,12 @@ func (q *queue) BorrowConnection(fn withConnectionFn) error {
defer q.Done()
q.RUnlock()

conn, err := q.connPool.NextClient()
conn, ch, err := q.connPool.NextClient()
if err != nil {
return err
}

fn(conn)
fn(conn, ch)
return nil
}

Expand Down
Loading

0 comments on commit 31a21e3

Please sign in to comment.