Skip to content

Commit

Permalink
Expose DialStrategy function to user for custom connection routing (#855
Browse files Browse the repository at this point in the history
)

* Expose DialStrategy function to user for custom connection routing

* fix changes

* test: example of usage

* test: example of usage

* do not expose connect type

* fix tests

* remove dial strategy from docs

* fix test
  • Loading branch information
jkaflik authored Jan 10, 2023
1 parent 98d0eae commit 197b589
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
35 changes: 29 additions & 6 deletions clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,45 @@ func (ch *clickhouse) Stats() driver.Stats {

func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) {
connID := int(atomic.AddInt64(&ch.connID, 1))
for i := range ch.opt.Addr {

dialFunc := func(ctx context.Context, addr string, opt *Options) (DialResult, error) {
conn, err := dial(ctx, addr, connID, opt)

return DialResult{conn}, err
}

dialStrategy := DefaultDialStrategy
if ch.opt.DialStrategy != nil {
dialStrategy = ch.opt.DialStrategy
}

result, err := dialStrategy(ctx, connID, ch.opt, dialFunc)
if err != nil {
return nil, err
}
return result.conn, nil
}

func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dial) (r DialResult, err error) {
for i := range opt.Addr {
var num int
switch ch.opt.ConnOpenStrategy {
switch opt.ConnOpenStrategy {
case ConnOpenInOrder:
num = i
case ConnOpenRoundRobin:
num = (int(connID) + i) % len(ch.opt.Addr)
num = (int(connID) + i) % len(opt.Addr)
}
if conn, err = dial(ctx, ch.opt.Addr[num], connID, ch.opt); err == nil {
return conn, nil

if r, err = dial(ctx, opt.Addr[num], opt); err == nil {
return r, nil
}
}

if err == nil {
err = ErrAcquireConnNoAddress
}
return nil, err

return r, err
}

func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) {
Expand Down
6 changes: 6 additions & 0 deletions clickhouse_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,19 @@ func ParseDSN(dsn string) (*Options, error) {
return opt, nil
}

type Dial func(ctx context.Context, addr string, opt *Options) (DialResult, error)
type DialResult struct {
conn *connect
}

type Options struct {
Protocol Protocol

TLS *tls.Config
Addr []string
Auth Auth
DialContext func(ctx context.Context, addr string) (net.Conn, error)
DialStrategy func(ctx context.Context, connID int, options *Options, dial Dial) (DialResult, error)
Debug bool
Debugf func(format string, v ...interface{}) // only works when Debug is true
Settings Settings
Expand Down
20 changes: 20 additions & 0 deletions tests/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,23 @@ func TestBlockBufferSize(t *testing.T) {
}
require.Equal(t, 10000000, i)
}

func TestConnCustomDialStrategy(t *testing.T) {
env, err := GetTestEnvironment(testSet)
require.NoError(t, err)

opts := clientOptionsFromEnv(env, clickhouse.Settings{})
validAddr := opts.Addr[0]
opts.Addr = []string{"invalid.host:9001"}

opts.DialStrategy = func(ctx context.Context, connID int, opts *clickhouse.Options, dial clickhouse.Dial) (clickhouse.DialResult, error) {
return dial(ctx, validAddr, opts)
}

conn, err := clickhouse.Open(&opts)
require.NoError(t, err)

require.NoError(t, err)
require.NoError(t, conn.Ping(context.Background()))
require.NoError(t, conn.Close())
}

0 comments on commit 197b589

Please sign in to comment.