From 197b589335fef77467fe14475b2eeded1eb7bfd7 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Tue, 10 Jan 2023 19:48:41 +0100 Subject: [PATCH] Expose DialStrategy function to user for custom connection routing (#855) * Expose DialStrategy function to user for custom connection routing * fix changes * test: example of usage * test: example of usage * do not expose connect type * fix tests * remove dial strategy from docs * fix test --- clickhouse.go | 35 +++++++++++++++++++++++++++++------ clickhouse_options.go | 6 ++++++ tests/conn_test.go | 20 ++++++++++++++++++++ 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index e397fa1093..5523146245 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -195,22 +195,45 @@ func (ch *clickhouse) Stats() driver.Stats { func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { connID := int(atomic.AddInt64(&ch.connID, 1)) - for i := range ch.opt.Addr { + + dialFunc := func(ctx context.Context, addr string, opt *Options) (DialResult, error) { + conn, err := dial(ctx, addr, connID, opt) + + return DialResult{conn}, err + } + + dialStrategy := DefaultDialStrategy + if ch.opt.DialStrategy != nil { + dialStrategy = ch.opt.DialStrategy + } + + result, err := dialStrategy(ctx, connID, ch.opt, dialFunc) + if err != nil { + return nil, err + } + return result.conn, nil +} + +func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dial) (r DialResult, err error) { + for i := range opt.Addr { var num int - switch ch.opt.ConnOpenStrategy { + switch opt.ConnOpenStrategy { case ConnOpenInOrder: num = i case ConnOpenRoundRobin: - num = (int(connID) + i) % len(ch.opt.Addr) + num = (int(connID) + i) % len(opt.Addr) } - if conn, err = dial(ctx, ch.opt.Addr[num], connID, ch.opt); err == nil { - return conn, nil + + if r, err = dial(ctx, opt.Addr[num], opt); err == nil { + return r, nil } } + if err == nil { err = ErrAcquireConnNoAddress } - return nil, err + + return r, err } func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) { diff --git a/clickhouse_options.go b/clickhouse_options.go index 8a148113ac..e34e933427 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -115,6 +115,11 @@ func ParseDSN(dsn string) (*Options, error) { return opt, nil } +type Dial func(ctx context.Context, addr string, opt *Options) (DialResult, error) +type DialResult struct { + conn *connect +} + type Options struct { Protocol Protocol @@ -122,6 +127,7 @@ type Options struct { Addr []string Auth Auth DialContext func(ctx context.Context, addr string) (net.Conn, error) + DialStrategy func(ctx context.Context, connID int, options *Options, dial Dial) (DialResult, error) Debug bool Debugf func(format string, v ...interface{}) // only works when Debug is true Settings Settings diff --git a/tests/conn_test.go b/tests/conn_test.go index 0422815ca3..a5600955a7 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -242,3 +242,23 @@ func TestBlockBufferSize(t *testing.T) { } require.Equal(t, 10000000, i) } + +func TestConnCustomDialStrategy(t *testing.T) { + env, err := GetTestEnvironment(testSet) + require.NoError(t, err) + + opts := clientOptionsFromEnv(env, clickhouse.Settings{}) + validAddr := opts.Addr[0] + opts.Addr = []string{"invalid.host:9001"} + + opts.DialStrategy = func(ctx context.Context, connID int, opts *clickhouse.Options, dial clickhouse.Dial) (clickhouse.DialResult, error) { + return dial(ctx, validAddr, opts) + } + + conn, err := clickhouse.Open(&opts) + require.NoError(t, err) + + require.NoError(t, err) + require.NoError(t, conn.Ping(context.Background())) + require.NoError(t, conn.Close()) +}