From b1a37584d5a5da4e7d6cfe91fab900f094603199 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Sat, 17 Dec 2022 00:05:11 +0100 Subject: [PATCH 1/8] Expose DialStrategy function to user for custom connection routing --- clickhouse.go | 19 ++++++++++++++----- clickhouse_options.go | 3 ++- clickhouse_options_test.go | 2 +- clickhouse_std.go | 2 +- conn.go | 6 +++--- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index e397fa1093..e323d584a8 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -43,7 +43,7 @@ type ( var ( ErrBatchInvalid = errors.New("clickhouse: batch is invalid. check appended data is correct") ErrBatchAlreadySent = errors.New("clickhouse: batch has already been sent") - ErrAcquireConnTimeout = errors.New("clickhouse: acquire conn timeout. you can increase the number of max open conn or the dial timeout") + ErrAcquireConnTimeout = errors.New("clickhouse: acquire conn timeout. you can increase the number of max open conn or the Dial timeout") ErrUnsupportedServerRevision = errors.New("clickhouse: unsupported server revision") ErrBindMixedParamsFormats = errors.New("clickhouse [bind]: mixed named, numeric or positional parameters") ErrAcquireConnNoAddress = errors.New("clickhouse: no valid address supplied") @@ -195,15 +195,24 @@ func (ch *clickhouse) Stats() driver.Stats { func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { connID := int(atomic.AddInt64(&ch.connID, 1)) - for i := range ch.opt.Addr { + + if ch.opt.DialStrategy != nil { + return ch.opt.DialStrategy(ctx, ch.opt, connID) + } + + return ch.defaultDialStrategy(ctx, ch.opt, connID) +} + +func (ch *clickhouse) defaultDialStrategy(ctx context.Context, opt *Options, connID int) (conn *connect, err error) { + for i := range opt.Addr { var num int - switch ch.opt.ConnOpenStrategy { + switch opt.ConnOpenStrategy { case ConnOpenInOrder: num = i case ConnOpenRoundRobin: - num = (int(connID) + i) % len(ch.opt.Addr) + num = (int(connID) + i) % len(opt.Addr) } - if conn, err = dial(ctx, ch.opt.Addr[num], connID, ch.opt); err == nil { + if conn, err = Dial(ctx, opt.Addr[num], connID, ch.opt); err == nil { return conn, nil } } diff --git a/clickhouse_options.go b/clickhouse_options.go index 7537257d3b..6cb2e0fca4 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -121,6 +121,7 @@ type Options struct { Addr []string Auth Auth DialContext func(ctx context.Context, addr string) (net.Conn, error) + DialStrategy func(ctx context.Context, options *Options, ID int) (*connect, error) Debug bool Debugf func(format string, v ...interface{}) // only works when Debug is true Settings Settings @@ -210,7 +211,7 @@ func (o *Options) fromDSN(in string) error { case "dial_timeout": duration, err := time.ParseDuration(params.Get(v)) if err != nil { - return fmt.Errorf("clickhouse [dsn parse]: dial timeout: %s", err) + return fmt.Errorf("clickhouse [dsn parse]: Dial timeout: %s", err) } o.DialTimeout = duration case "block_buffer_size": diff --git a/clickhouse_options_test.go b/clickhouse_options_test.go index 3c65d0deed..5552ad5f6d 100644 --- a/clickhouse_options_test.go +++ b/clickhouse_options_test.go @@ -307,7 +307,7 @@ func TestParseDSN(t *testing.T) { "compress_level invalid value: strconv.ParseInt: parsing \"first\": invalid syntax", }, { - "native protocol dial timeout", + "native protocol Dial timeout", "clickhouse://127.0.0.1/test_database?max_compression_buffer=1024", &Options{ Protocol: Native, diff --git a/clickhouse_std.go b/clickhouse_std.go index 18d7df8115..64a45b1fbe 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -60,7 +60,7 @@ func (o *stdConnOpener) Connect(ctx context.Context) (_ driver.Conn, err error) } default: dialFunc = func(ctx context.Context, addr string, num int, opt *Options) (stdConnect, error) { - return dial(ctx, addr, num, opt) + return Dial(ctx, addr, num, opt) } } diff --git a/conn.go b/conn.go index fa557634be..2dcb119cd3 100644 --- a/conn.go +++ b/conn.go @@ -33,7 +33,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) { +func Dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) { var ( err error conn net.Conn @@ -72,7 +72,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er var ( connect = &connect{ - id: num, + id: num, opt: opt, conn: conn, debugf: debugf, @@ -103,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp type connect struct { - id int + id int opt *Options conn net.Conn debugf func(format string, v ...interface{}) From 2256d43b3b1068b5bc8a07d3942af803fe07b6f7 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Sat, 17 Dec 2022 00:14:25 +0100 Subject: [PATCH 2/8] fix changes --- clickhouse.go | 2 +- clickhouse_options.go | 2 +- clickhouse_options_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index e323d584a8..848dca08e0 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -43,7 +43,7 @@ type ( var ( ErrBatchInvalid = errors.New("clickhouse: batch is invalid. check appended data is correct") ErrBatchAlreadySent = errors.New("clickhouse: batch has already been sent") - ErrAcquireConnTimeout = errors.New("clickhouse: acquire conn timeout. you can increase the number of max open conn or the Dial timeout") + ErrAcquireConnTimeout = errors.New("clickhouse: acquire conn timeout. you can increase the number of max open conn or the dial timeout") ErrUnsupportedServerRevision = errors.New("clickhouse: unsupported server revision") ErrBindMixedParamsFormats = errors.New("clickhouse [bind]: mixed named, numeric or positional parameters") ErrAcquireConnNoAddress = errors.New("clickhouse: no valid address supplied") diff --git a/clickhouse_options.go b/clickhouse_options.go index 6cb2e0fca4..90824b8499 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -211,7 +211,7 @@ func (o *Options) fromDSN(in string) error { case "dial_timeout": duration, err := time.ParseDuration(params.Get(v)) if err != nil { - return fmt.Errorf("clickhouse [dsn parse]: Dial timeout: %s", err) + return fmt.Errorf("clickhouse [dsn parse]: dial timeout: %s", err) } o.DialTimeout = duration case "block_buffer_size": diff --git a/clickhouse_options_test.go b/clickhouse_options_test.go index 5552ad5f6d..3c65d0deed 100644 --- a/clickhouse_options_test.go +++ b/clickhouse_options_test.go @@ -307,7 +307,7 @@ func TestParseDSN(t *testing.T) { "compress_level invalid value: strconv.ParseInt: parsing \"first\": invalid syntax", }, { - "native protocol Dial timeout", + "native protocol dial timeout", "clickhouse://127.0.0.1/test_database?max_compression_buffer=1024", &Options{ Protocol: Native, From cadd1c98c1e8c6f2b0234f8d504b4fe6085a7ab3 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Thu, 22 Dec 2022 17:53:43 +0100 Subject: [PATCH 3/8] test: example of usage --- clickhouse.go | 20 ++++++++++---------- clickhouse_options.go | 2 +- clickhouse_std.go | 10 +++++----- conn.go | 24 ++++++++++++------------ conn_async_insert.go | 2 +- conn_batch.go | 6 +++--- conn_check.go | 2 +- conn_check_ping.go | 2 +- conn_exec.go | 2 +- conn_handshake.go | 2 +- conn_http.go | 4 ++-- conn_http_batch.go | 2 +- conn_http_query.go | 2 +- conn_logs.go | 2 +- conn_ping.go | 2 +- conn_process.go | 8 ++++---- conn_profile_events.go | 2 +- conn_query.go | 4 ++-- conn_send_query.go | 2 +- tests/conn_test.go | 19 +++++++++++++++++++ 20 files changed, 69 insertions(+), 50 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index 848dca08e0..8105843586 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -80,14 +80,14 @@ func Open(opt *Options) (driver.Conn, error) { o := opt.setDefaults() return &clickhouse{ opt: o, - idle: make(chan *connect, o.MaxIdleConns), + idle: make(chan *Connect, o.MaxIdleConns), open: make(chan struct{}, o.MaxOpenConns), }, nil } type clickhouse struct { opt *Options - idle chan *connect + idle chan *Connect open chan struct{} connID int64 } @@ -193,17 +193,17 @@ func (ch *clickhouse) Stats() driver.Stats { } } -func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { - connID := int(atomic.AddInt64(&ch.connID, 1)) - +func (ch *clickhouse) dial(ctx context.Context) (conn *Connect, err error) { if ch.opt.DialStrategy != nil { - return ch.opt.DialStrategy(ctx, ch.opt, connID) + return ch.opt.DialStrategy(ctx, ch.opt) } - return ch.defaultDialStrategy(ctx, ch.opt, connID) + return ch.defaultDialStrategy(ctx, ch.opt) } -func (ch *clickhouse) defaultDialStrategy(ctx context.Context, opt *Options, connID int) (conn *connect, err error) { +func (ch *clickhouse) defaultDialStrategy(ctx context.Context, opt *Options) (conn *Connect, err error) { + connID := int(atomic.AddInt64(&ch.connID, 1)) + for i := range opt.Addr { var num int switch opt.ConnOpenStrategy { @@ -222,7 +222,7 @@ func (ch *clickhouse) defaultDialStrategy(ctx context.Context, opt *Options, con return nil, err } -func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) { +func (ch *clickhouse) acquire(ctx context.Context) (conn *Connect, err error) { timer := time.NewTimer(ch.opt.DialTimeout) defer timer.Stop() select { @@ -263,7 +263,7 @@ func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) { return conn, nil } -func (ch *clickhouse) release(conn *connect, err error) { +func (ch *clickhouse) release(conn *Connect, err error) { if conn.released { return } diff --git a/clickhouse_options.go b/clickhouse_options.go index 90824b8499..d3e21b7001 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -121,7 +121,7 @@ type Options struct { Addr []string Auth Auth DialContext func(ctx context.Context, addr string) (net.Conn, error) - DialStrategy func(ctx context.Context, options *Options, ID int) (*connect, error) + DialStrategy func(ctx context.Context, options *Options) (*Connect, error) Debug bool Debugf func(format string, v ...interface{}) // only works when Debug is true Settings Settings diff --git a/clickhouse_std.go b/clickhouse_std.go index 64a45b1fbe..41840a1198 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -116,7 +116,7 @@ func OpenDB(opt *Options) *sql.DB { } if len(settings) != 0 { return sql.OpenDB(&stdConnOpener{ - err: fmt.Errorf("cannot connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")), + err: fmt.Errorf("cannot Connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")), }) } o := opt.setDefaults() @@ -128,10 +128,10 @@ func OpenDB(opt *Options) *sql.DB { type stdConnect interface { isBad() bool close() error - query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) + query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) exec(ctx context.Context, query string, args ...interface{}) error ping(ctx context.Context) (err error) - prepareBatch(ctx context.Context, query string, release func(*connect, error)) (ldriver.Batch, error) + prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (ldriver.Batch, error) asyncInsert(ctx context.Context, query string, wait bool) error } @@ -192,7 +192,7 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv } func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...) + r, err := std.conn.query(ctx, func(*Connect, error) {}, query, rebind(args)...) if err != nil { return nil, err } @@ -206,7 +206,7 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) { } func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - batch, err := std.conn.prepareBatch(ctx, query, func(*connect, error) {}) + batch, err := std.conn.prepareBatch(ctx, query, func(*Connect, error) {}) if err != nil { return nil, err } diff --git a/conn.go b/conn.go index 2dcb119cd3..9a316908c2 100644 --- a/conn.go +++ b/conn.go @@ -33,7 +33,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func Dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) { +func Dial(ctx context.Context, addr string, num int, opt *Options) (*Connect, error) { var ( err error conn net.Conn @@ -71,7 +71,7 @@ func Dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er } var ( - connect = &connect{ + connect = &Connect{ id: num, opt: opt, conn: conn, @@ -102,7 +102,7 @@ func Dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er } // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -type connect struct { +type Connect struct { id int opt *Options conn net.Conn @@ -122,7 +122,7 @@ type connect struct { maxCompressionBuffer int } -func (c *connect) settings(querySettings Settings) []proto.Setting { +func (c *Connect) settings(querySettings Settings) []proto.Setting { settings := make([]proto.Setting, 0, len(c.opt.Settings)+len(querySettings)) for k, v := range c.opt.Settings { settings = append(settings, proto.Setting{ @@ -139,7 +139,7 @@ func (c *connect) settings(querySettings Settings) []proto.Setting { return settings } -func (c *connect) isBad() bool { +func (c *Connect) isBad() bool { switch { case c.closed: return true @@ -150,7 +150,7 @@ func (c *connect) isBad() bool { return false } -func (c *connect) close() error { +func (c *Connect) close() error { if c.closed { return nil } @@ -163,7 +163,7 @@ func (c *connect) close() error { return nil } -func (c *connect) progress() (*Progress, error) { +func (c *Connect) progress() (*Progress, error) { var progress proto.Progress if err := progress.Decode(c.reader, c.revision); err != nil { return nil, err @@ -172,7 +172,7 @@ func (c *connect) progress() (*Progress, error) { return &progress, nil } -func (c *connect) exception() error { +func (c *Connect) exception() error { var e Exception if err := e.Decode(c.reader); err != nil { return err @@ -181,7 +181,7 @@ func (c *connect) exception() error { return &e } -func (c *connect) compressBuffer(start int) error { +func (c *Connect) compressBuffer(start int) error { if c.compression != CompressionNone && len(c.buffer.Buf) > 0 { data := c.buffer.Buf[start:] if err := c.compressor.Compress(compress.Method(c.compression), data); err != nil { @@ -192,7 +192,7 @@ func (c *connect) compressBuffer(start int) error { return nil } -func (c *connect) sendData(block *proto.Block, name string) error { +func (c *Connect) sendData(block *proto.Block, name string) error { c.debugf("[send data] compression=%t", c.compression) c.buffer.PutByte(proto.ClientData) c.buffer.PutString(name) @@ -229,7 +229,7 @@ func (c *connect) sendData(block *proto.Block, name string) error { return nil } -func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error) { +func (c *Connect) readData(packet byte, compressible bool) (*proto.Block, error) { if _, err := c.reader.Str(); err != nil { return nil, err } @@ -246,7 +246,7 @@ func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error) return &block, nil } -func (c *connect) flush() error { +func (c *Connect) flush() error { if len(c.buffer.Buf) == 0 { // Nothing to flush. return nil diff --git a/conn_async_insert.go b/conn_async_insert.go index a324bd1da6..ebf709f741 100644 --- a/conn_async_insert.go +++ b/conn_async_insert.go @@ -21,7 +21,7 @@ import ( "context" ) -func (c *connect) asyncInsert(ctx context.Context, query string, wait bool) error { +func (c *Connect) asyncInsert(ctx context.Context, query string, wait bool) error { options := queryOptions(ctx) { options.settings["async_insert"] = 1 diff --git a/conn_batch.go b/conn_batch.go index cc2461891b..32db766516 100644 --- a/conn_batch.go +++ b/conn_batch.go @@ -34,7 +34,7 @@ import ( var splitInsertRe = regexp.MustCompile(`(?i)\sVALUES\s*\(`) var columnMatch = regexp.MustCompile(`.*\((?P.+)\)$`) -func (c *connect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) { +func (c *Connect) prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (driver.Batch, error) { //defer func() { // if err := recover(); err != nil { // fmt.Printf("panic occurred on %d:\n", c.num) @@ -86,11 +86,11 @@ func (c *connect) prepareBatch(ctx context.Context, query string, release func(* type batch struct { err error ctx context.Context - conn *connect + conn *Connect sent bool released bool block *proto.Block - connRelease func(*connect, error) + connRelease func(*Connect, error) onProcess *onProcess } diff --git a/conn_check.go b/conn_check.go index f30194f847..cfe3cc24cd 100644 --- a/conn_check.go +++ b/conn_check.go @@ -26,7 +26,7 @@ import ( "syscall" ) -func (c *connect) connCheck() error { +func (c *Connect) connCheck() error { var sysErr error sysConn, ok := c.conn.(syscall.Conn) if !ok { diff --git a/conn_check_ping.go b/conn_check_ping.go index fe1fe7bfaa..53791b960f 100644 --- a/conn_check_ping.go +++ b/conn_check_ping.go @@ -25,7 +25,7 @@ import ( "time" ) -func (c *connect) connCheck() error { +func (c *Connect) connCheck() error { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) defer cancel() if err := c.ping(ctx); err != nil { diff --git a/conn_exec.go b/conn_exec.go index 1a39da0d05..68bc4f15c8 100644 --- a/conn_exec.go +++ b/conn_exec.go @@ -22,7 +22,7 @@ import ( "time" ) -func (c *connect) exec(ctx context.Context, query string, args ...interface{}) error { +func (c *Connect) exec(ctx context.Context, query string, args ...interface{}) error { var ( options = queryOptions(ctx) body, err = bind(c.server.Timezone, query, args...) diff --git a/conn_handshake.go b/conn_handshake.go index a476bdba09..3cb2fae680 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -25,7 +25,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func (c *connect) handshake(database, username, password string) error { +func (c *Connect) handshake(database, username, password string) error { defer c.buffer.Reset() c.debugf("[handshake] -> %s", proto.ClientHandshake{}) // set a read deadline - alternative to context.Read operation will fail if no data is received after deadline. diff --git a/conn_http.go b/conn_http.go index 419691fe4c..5420bb8dcc 100644 --- a/conn_http.go +++ b/conn_http.go @@ -262,7 +262,7 @@ func (h *httpConnect) isBad() bool { } func (h *httpConnect) readTimeZone(ctx context.Context) (*time.Location, error) { - rows, err := h.query(ctx, func(*connect, error) {}, "SELECT timezone()") + rows, err := h.query(ctx, func(*Connect, error) {}, "SELECT timezone()") if err != nil { return nil, err } @@ -280,7 +280,7 @@ func (h *httpConnect) readTimeZone(ctx context.Context) (*time.Location, error) } func (h *httpConnect) readVersion(ctx context.Context) (proto.Version, error) { - rows, err := h.query(ctx, func(*connect, error) {}, "SELECT version()") + rows, err := h.query(ctx, func(*Connect, error) {}, "SELECT version()") if err != nil { return proto.Version{}, err } diff --git a/conn_http_batch.go b/conn_http_batch.go index d9692a1c84..aa89f72118 100644 --- a/conn_http_batch.go +++ b/conn_http_batch.go @@ -34,7 +34,7 @@ import ( var httpInsertRe = regexp.MustCompile(`(?i)^INSERT INTO\s+\x60?([\w.^\(]+)\x60?\s*(\([^\)]*\))?`) // release is ignored, because http used by std with empty release function -func (h *httpConnect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) { +func (h *httpConnect) prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (driver.Batch, error) { matches := httpInsertRe.FindStringSubmatch(query) if len(matches) < 3 { return nil, errors.New("cannot get table name from query") diff --git a/conn_http_query.go b/conn_http_query.go index 828fc58321..c0fa9bdb26 100644 --- a/conn_http_query.go +++ b/conn_http_query.go @@ -29,7 +29,7 @@ import ( ) // release is ignored, because http used by std with empty release function -func (h *httpConnect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { +func (h *httpConnect) query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) { query, err := bind(h.location, query, args...) if err != nil { return nil, err diff --git a/conn_logs.go b/conn_logs.go index 4f234ccfb3..4375b145da 100644 --- a/conn_logs.go +++ b/conn_logs.go @@ -34,7 +34,7 @@ type Log struct { Text string } -func (c *connect) logs() ([]Log, error) { +func (c *Connect) logs() ([]Log, error) { block, err := c.readData(proto.ServerLog, false) if err != nil { return nil, err diff --git a/conn_ping.go b/conn_ping.go index cf8103086e..ea5ad1dda7 100644 --- a/conn_ping.go +++ b/conn_ping.go @@ -27,7 +27,7 @@ import ( // Connection::ping // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -func (c *connect) ping(ctx context.Context) (err error) { +func (c *Connect) ping(ctx context.Context) (err error) { // set a read deadline - alternative to context.Read operation will fail if no data is received after deadline. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) defer c.conn.SetReadDeadline(time.Time{}) diff --git a/conn_process.go b/conn_process.go index b673c78240..1f5ba307f7 100644 --- a/conn_process.go +++ b/conn_process.go @@ -32,7 +32,7 @@ type onProcess struct { profileEvents func([]ProfileEvent) } -func (c *connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, error) { +func (c *Connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, error) { for { select { case <-ctx.Done(): @@ -58,7 +58,7 @@ func (c *connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, } } -func (c *connect) process(ctx context.Context, on *onProcess) error { +func (c *Connect) process(ctx context.Context, on *onProcess) error { for { select { case <-ctx.Done(): @@ -81,7 +81,7 @@ func (c *connect) process(ctx context.Context, on *onProcess) error { } } -func (c *connect) handle(packet byte, on *onProcess) error { +func (c *Connect) handle(packet byte, on *onProcess) error { switch packet { case proto.ServerData, proto.ServerTotals, proto.ServerExtremes: block, err := c.readData(packet, true) @@ -134,7 +134,7 @@ func (c *connect) handle(packet byte, on *onProcess) error { return nil } -func (c *connect) cancel() error { +func (c *Connect) cancel() error { c.debugf("[cancel]") c.buffer.PutUVarInt(proto.ClientCancel) wErr := c.flush() diff --git a/conn_profile_events.go b/conn_profile_events.go index 3fa7db74d6..c8f28ace7a 100644 --- a/conn_profile_events.go +++ b/conn_profile_events.go @@ -33,7 +33,7 @@ type ProfileEvent struct { Value int64 } -func (c *connect) profileEvents() ([]ProfileEvent, error) { +func (c *Connect) profileEvents() ([]ProfileEvent, error) { block, err := c.readData(proto.ServerProfileEvents, false) if err != nil { return nil, err diff --git a/conn_query.go b/conn_query.go index 9dfa9f7999..b5df87f4b9 100644 --- a/conn_query.go +++ b/conn_query.go @@ -24,7 +24,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func (c *connect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { +func (c *Connect) query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) { var ( options = queryOptions(ctx) onProcess = options.onProcess() @@ -88,7 +88,7 @@ func (c *connect) query(ctx context.Context, release func(*connect, error), quer }, nil } -func (c *connect) queryRow(ctx context.Context, release func(*connect, error), query string, args ...interface{}) *row { +func (c *Connect) queryRow(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) *row { rows, err := c.query(ctx, release, query, args...) if err != nil { return &row{ diff --git a/conn_send_query.go b/conn_send_query.go index f2367b798a..d61c02a48c 100644 --- a/conn_send_query.go +++ b/conn_send_query.go @@ -23,7 +23,7 @@ import ( // Connection::sendQuery // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -func (c *connect) sendQuery(body string, o *QueryOptions) error { +func (c *Connect) sendQuery(body string, o *QueryOptions) error { c.debugf("[send query] compression=%t %s", c.compression, body) c.buffer.PutByte(proto.ClientQuery) q := proto.Query{ diff --git a/tests/conn_test.go b/tests/conn_test.go index 0422815ca3..60e5870fa1 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -242,3 +242,22 @@ func TestBlockBufferSize(t *testing.T) { } require.Equal(t, 10000000, i) } + +func TestConnCustomDialStrategy(t *testing.T) { + env, err := GetTestEnvironment(testSet) + require.NoError(t, err) + + actualAddr := fmt.Sprintf("%s:%d", env.Host, env.Port) + env.Host = "non-existent.host" + opts := clientOptionsFromEnv(env, clickhouse.Settings{}) + opts.DialStrategy = func(ctx context.Context, opts *clickhouse.Options) (*clickhouse.Connect, error) { + return clickhouse.Dial(ctx, actualAddr, 1, opts) + } + + conn, err := clickhouse.Open(&opts) + require.NoError(t, err) + + require.NoError(t, err) + require.NoError(t, conn.Ping(context.Background())) + require.NoError(t, conn.Close()) +} From 631660dc06ef5eaa20e9e61e4eadb90f9691e13c Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Thu, 22 Dec 2022 18:13:02 +0100 Subject: [PATCH 4/8] test: example of usage --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 8295b3693f..13420b2029 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,9 @@ Support for the ClickHouse protocol advanced features using `Context`: var d net.Dialer return d.DialContext(ctx, "tcp", addr) }, + opts.DialStrategy = func(ctx context.Context, opts *clickhouse.Options) (*clickhouse.Connect, error) { + return clickhouse.Dial(ctx, "127.0.0.1:5678", 1, opts) + } Debug: true, Debugf: func(format string, v ...interface{}) { fmt.Printf(format, v) From 95ec12c28b94bc3fdc55d38192d3eb73f038b897 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Fri, 23 Dec 2022 18:39:54 +0100 Subject: [PATCH 5/8] do not expose connect type --- clickhouse.go | 10 +++++----- clickhouse_options.go | 2 +- clickhouse_std.go | 10 +++++----- conn.go | 24 ++++++++++++------------ conn_async_insert.go | 2 +- conn_batch.go | 6 +++--- conn_check.go | 2 +- conn_check_ping.go | 2 +- conn_exec.go | 2 +- conn_handshake.go | 4 ++-- conn_http.go | 4 ++-- conn_http_batch.go | 2 +- conn_http_query.go | 2 +- conn_logs.go | 2 +- conn_ping.go | 2 +- conn_process.go | 8 ++++---- conn_profile_events.go | 2 +- conn_query.go | 4 ++-- conn_send_query.go | 2 +- 19 files changed, 46 insertions(+), 46 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index fdc4e35173..79e5537d32 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -80,14 +80,14 @@ func Open(opt *Options) (driver.Conn, error) { o := opt.setDefaults() return &clickhouse{ opt: o, - idle: make(chan *Connect, o.MaxIdleConns), + idle: make(chan *connect, o.MaxIdleConns), open: make(chan struct{}, o.MaxOpenConns), }, nil } type clickhouse struct { opt *Options - idle chan *Connect + idle chan *connect open chan struct{} connID int64 } @@ -193,7 +193,7 @@ func (ch *clickhouse) Stats() driver.Stats { } } -func (ch *clickhouse) dial(ctx context.Context) (conn *Connect, err error) { +func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { connID := int(atomic.AddInt64(&ch.connID, 1)) dialFunc := func(ctx context.Context, addr string, opt *Options) DialResult { @@ -236,7 +236,7 @@ func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dia return DialResult{}, ErrAcquireConnNoAddress } -func (ch *clickhouse) acquire(ctx context.Context) (conn *Connect, err error) { +func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) { timer := time.NewTimer(ch.opt.DialTimeout) defer timer.Stop() select { @@ -277,7 +277,7 @@ func (ch *clickhouse) acquire(ctx context.Context) (conn *Connect, err error) { return conn, nil } -func (ch *clickhouse) release(conn *Connect, err error) { +func (ch *clickhouse) release(conn *connect, err error) { if conn.released { return } diff --git a/clickhouse_options.go b/clickhouse_options.go index 34513c903e..f33f8a6c1e 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -117,7 +117,7 @@ func ParseDSN(dsn string) (*Options, error) { type Dial func(ctx context.Context, addr string, opt *Options) DialResult type DialResult struct { - conn *Connect + conn *connect err error } diff --git a/clickhouse_std.go b/clickhouse_std.go index 0971de7ce8..18d7df8115 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -116,7 +116,7 @@ func OpenDB(opt *Options) *sql.DB { } if len(settings) != 0 { return sql.OpenDB(&stdConnOpener{ - err: fmt.Errorf("cannot Connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")), + err: fmt.Errorf("cannot connect. invalid settings. use %s (see https://pkg.go.dev/database/sql)", strings.Join(settings, ",")), }) } o := opt.setDefaults() @@ -128,10 +128,10 @@ func OpenDB(opt *Options) *sql.DB { type stdConnect interface { isBad() bool close() error - query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) + query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) exec(ctx context.Context, query string, args ...interface{}) error ping(ctx context.Context) (err error) - prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (ldriver.Batch, error) + prepareBatch(ctx context.Context, query string, release func(*connect, error)) (ldriver.Batch, error) asyncInsert(ctx context.Context, query string, wait bool) error } @@ -192,7 +192,7 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv } func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - r, err := std.conn.query(ctx, func(*Connect, error) {}, query, rebind(args)...) + r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...) if err != nil { return nil, err } @@ -206,7 +206,7 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) { } func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - batch, err := std.conn.prepareBatch(ctx, query, func(*Connect, error) {}) + batch, err := std.conn.prepareBatch(ctx, query, func(*connect, error) {}) if err != nil { return nil, err } diff --git a/conn.go b/conn.go index 3f059b0338..75eaf607cf 100644 --- a/conn.go +++ b/conn.go @@ -33,7 +33,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func dial(ctx context.Context, addr string, num int, opt *Options) (*Connect, error) { +func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, error) { var ( err error conn net.Conn @@ -71,7 +71,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*Connect, er } var ( - connect = &Connect{ + connect = &connect{ id: num, opt: opt, conn: conn, @@ -107,7 +107,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*Connect, er } // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -type Connect struct { +type connect struct { id int opt *Options conn net.Conn @@ -127,7 +127,7 @@ type Connect struct { maxCompressionBuffer int } -func (c *Connect) settings(querySettings Settings) []proto.Setting { +func (c *connect) settings(querySettings Settings) []proto.Setting { settings := make([]proto.Setting, 0, len(c.opt.Settings)+len(querySettings)) for k, v := range c.opt.Settings { settings = append(settings, proto.Setting{ @@ -144,7 +144,7 @@ func (c *Connect) settings(querySettings Settings) []proto.Setting { return settings } -func (c *Connect) isBad() bool { +func (c *connect) isBad() bool { switch { case c.closed: return true @@ -155,7 +155,7 @@ func (c *Connect) isBad() bool { return false } -func (c *Connect) close() error { +func (c *connect) close() error { if c.closed { return nil } @@ -168,7 +168,7 @@ func (c *Connect) close() error { return nil } -func (c *Connect) progress() (*Progress, error) { +func (c *connect) progress() (*Progress, error) { var progress proto.Progress if err := progress.Decode(c.reader, c.revision); err != nil { return nil, err @@ -177,7 +177,7 @@ func (c *Connect) progress() (*Progress, error) { return &progress, nil } -func (c *Connect) exception() error { +func (c *connect) exception() error { var e Exception if err := e.Decode(c.reader); err != nil { return err @@ -186,7 +186,7 @@ func (c *Connect) exception() error { return &e } -func (c *Connect) compressBuffer(start int) error { +func (c *connect) compressBuffer(start int) error { if c.compression != CompressionNone && len(c.buffer.Buf) > 0 { data := c.buffer.Buf[start:] if err := c.compressor.Compress(compress.Method(c.compression), data); err != nil { @@ -197,7 +197,7 @@ func (c *Connect) compressBuffer(start int) error { return nil } -func (c *Connect) sendData(block *proto.Block, name string) error { +func (c *connect) sendData(block *proto.Block, name string) error { c.debugf("[send data] compression=%t", c.compression) c.buffer.PutByte(proto.ClientData) c.buffer.PutString(name) @@ -234,7 +234,7 @@ func (c *Connect) sendData(block *proto.Block, name string) error { return nil } -func (c *Connect) readData(packet byte, compressible bool) (*proto.Block, error) { +func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error) { if _, err := c.reader.Str(); err != nil { return nil, err } @@ -251,7 +251,7 @@ func (c *Connect) readData(packet byte, compressible bool) (*proto.Block, error) return &block, nil } -func (c *Connect) flush() error { +func (c *connect) flush() error { if len(c.buffer.Buf) == 0 { // Nothing to flush. return nil diff --git a/conn_async_insert.go b/conn_async_insert.go index ebf709f741..a324bd1da6 100644 --- a/conn_async_insert.go +++ b/conn_async_insert.go @@ -21,7 +21,7 @@ import ( "context" ) -func (c *Connect) asyncInsert(ctx context.Context, query string, wait bool) error { +func (c *connect) asyncInsert(ctx context.Context, query string, wait bool) error { options := queryOptions(ctx) { options.settings["async_insert"] = 1 diff --git a/conn_batch.go b/conn_batch.go index 32db766516..cc2461891b 100644 --- a/conn_batch.go +++ b/conn_batch.go @@ -34,7 +34,7 @@ import ( var splitInsertRe = regexp.MustCompile(`(?i)\sVALUES\s*\(`) var columnMatch = regexp.MustCompile(`.*\((?P.+)\)$`) -func (c *Connect) prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (driver.Batch, error) { +func (c *connect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) { //defer func() { // if err := recover(); err != nil { // fmt.Printf("panic occurred on %d:\n", c.num) @@ -86,11 +86,11 @@ func (c *Connect) prepareBatch(ctx context.Context, query string, release func(* type batch struct { err error ctx context.Context - conn *Connect + conn *connect sent bool released bool block *proto.Block - connRelease func(*Connect, error) + connRelease func(*connect, error) onProcess *onProcess } diff --git a/conn_check.go b/conn_check.go index cfe3cc24cd..f30194f847 100644 --- a/conn_check.go +++ b/conn_check.go @@ -26,7 +26,7 @@ import ( "syscall" ) -func (c *Connect) connCheck() error { +func (c *connect) connCheck() error { var sysErr error sysConn, ok := c.conn.(syscall.Conn) if !ok { diff --git a/conn_check_ping.go b/conn_check_ping.go index 53791b960f..fe1fe7bfaa 100644 --- a/conn_check_ping.go +++ b/conn_check_ping.go @@ -25,7 +25,7 @@ import ( "time" ) -func (c *Connect) connCheck() error { +func (c *connect) connCheck() error { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) defer cancel() if err := c.ping(ctx); err != nil { diff --git a/conn_exec.go b/conn_exec.go index 2ae0eb4b6e..b7666214b8 100644 --- a/conn_exec.go +++ b/conn_exec.go @@ -23,7 +23,7 @@ import ( "time" ) -func (c *Connect) exec(ctx context.Context, query string, args ...interface{}) error { +func (c *connect) exec(ctx context.Context, query string, args ...interface{}) error { var ( options = queryOptions(ctx) queryParamsProtocolSupport = c.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS diff --git a/conn_handshake.go b/conn_handshake.go index fa2633095b..ff3be24645 100644 --- a/conn_handshake.go +++ b/conn_handshake.go @@ -25,7 +25,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func (c *Connect) handshake(database, username, password string) error { +func (c *connect) handshake(database, username, password string) error { defer c.buffer.Reset() c.debugf("[handshake] -> %s", proto.ClientHandshake{}) // set a read deadline - alternative to context.Read operation will fail if no data is received after deadline. @@ -77,7 +77,7 @@ func (c *Connect) handshake(database, username, password string) error { return nil } -func (c *Connect) sendAddendum() error { +func (c *connect) sendAddendum() error { if c.revision >= proto.DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY { c.buffer.PutString("") // todo quota key support } diff --git a/conn_http.go b/conn_http.go index 132b870766..3947a9cdcf 100644 --- a/conn_http.go +++ b/conn_http.go @@ -262,7 +262,7 @@ func (h *httpConnect) isBad() bool { } func (h *httpConnect) readTimeZone(ctx context.Context) (*time.Location, error) { - rows, err := h.query(ctx, func(*Connect, error) {}, "SELECT timezone()") + rows, err := h.query(ctx, func(*connect, error) {}, "SELECT timezone()") if err != nil { return nil, err } @@ -280,7 +280,7 @@ func (h *httpConnect) readTimeZone(ctx context.Context) (*time.Location, error) } func (h *httpConnect) readVersion(ctx context.Context) (proto.Version, error) { - rows, err := h.query(ctx, func(*Connect, error) {}, "SELECT version()") + rows, err := h.query(ctx, func(*connect, error) {}, "SELECT version()") if err != nil { return proto.Version{}, err } diff --git a/conn_http_batch.go b/conn_http_batch.go index aa89f72118..d9692a1c84 100644 --- a/conn_http_batch.go +++ b/conn_http_batch.go @@ -34,7 +34,7 @@ import ( var httpInsertRe = regexp.MustCompile(`(?i)^INSERT INTO\s+\x60?([\w.^\(]+)\x60?\s*(\([^\)]*\))?`) // release is ignored, because http used by std with empty release function -func (h *httpConnect) prepareBatch(ctx context.Context, query string, release func(*Connect, error)) (driver.Batch, error) { +func (h *httpConnect) prepareBatch(ctx context.Context, query string, release func(*connect, error)) (driver.Batch, error) { matches := httpInsertRe.FindStringSubmatch(query) if len(matches) < 3 { return nil, errors.New("cannot get table name from query") diff --git a/conn_http_query.go b/conn_http_query.go index bcb6828b8a..9d7d746aee 100644 --- a/conn_http_query.go +++ b/conn_http_query.go @@ -29,7 +29,7 @@ import ( ) // release is ignored, because http used by std with empty release function -func (h *httpConnect) query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) { +func (h *httpConnect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { options := queryOptions(ctx) query, err := bindQueryOrAppendParameters(true, &options, query, h.location, args...) if err != nil { diff --git a/conn_logs.go b/conn_logs.go index 4375b145da..4f234ccfb3 100644 --- a/conn_logs.go +++ b/conn_logs.go @@ -34,7 +34,7 @@ type Log struct { Text string } -func (c *Connect) logs() ([]Log, error) { +func (c *connect) logs() ([]Log, error) { block, err := c.readData(proto.ServerLog, false) if err != nil { return nil, err diff --git a/conn_ping.go b/conn_ping.go index ea5ad1dda7..cf8103086e 100644 --- a/conn_ping.go +++ b/conn_ping.go @@ -27,7 +27,7 @@ import ( // Connection::ping // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -func (c *Connect) ping(ctx context.Context) (err error) { +func (c *connect) ping(ctx context.Context) (err error) { // set a read deadline - alternative to context.Read operation will fail if no data is received after deadline. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) defer c.conn.SetReadDeadline(time.Time{}) diff --git a/conn_process.go b/conn_process.go index 1f5ba307f7..b673c78240 100644 --- a/conn_process.go +++ b/conn_process.go @@ -32,7 +32,7 @@ type onProcess struct { profileEvents func([]ProfileEvent) } -func (c *Connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, error) { +func (c *connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, error) { for { select { case <-ctx.Done(): @@ -58,7 +58,7 @@ func (c *Connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block, } } -func (c *Connect) process(ctx context.Context, on *onProcess) error { +func (c *connect) process(ctx context.Context, on *onProcess) error { for { select { case <-ctx.Done(): @@ -81,7 +81,7 @@ func (c *Connect) process(ctx context.Context, on *onProcess) error { } } -func (c *Connect) handle(packet byte, on *onProcess) error { +func (c *connect) handle(packet byte, on *onProcess) error { switch packet { case proto.ServerData, proto.ServerTotals, proto.ServerExtremes: block, err := c.readData(packet, true) @@ -134,7 +134,7 @@ func (c *Connect) handle(packet byte, on *onProcess) error { return nil } -func (c *Connect) cancel() error { +func (c *connect) cancel() error { c.debugf("[cancel]") c.buffer.PutUVarInt(proto.ClientCancel) wErr := c.flush() diff --git a/conn_profile_events.go b/conn_profile_events.go index c8f28ace7a..3fa7db74d6 100644 --- a/conn_profile_events.go +++ b/conn_profile_events.go @@ -33,7 +33,7 @@ type ProfileEvent struct { Value int64 } -func (c *Connect) profileEvents() ([]ProfileEvent, error) { +func (c *connect) profileEvents() ([]ProfileEvent, error) { block, err := c.readData(proto.ServerProfileEvents, false) if err != nil { return nil, err diff --git a/conn_query.go b/conn_query.go index 57b2d6d146..b673ca4a0d 100644 --- a/conn_query.go +++ b/conn_query.go @@ -24,7 +24,7 @@ import ( "github.com/ClickHouse/clickhouse-go/v2/lib/proto" ) -func (c *Connect) query(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) (*rows, error) { +func (c *connect) query(ctx context.Context, release func(*connect, error), query string, args ...interface{}) (*rows, error) { var ( options = queryOptions(ctx) onProcess = options.onProcess() @@ -89,7 +89,7 @@ func (c *Connect) query(ctx context.Context, release func(*Connect, error), quer }, nil } -func (c *Connect) queryRow(ctx context.Context, release func(*Connect, error), query string, args ...interface{}) *row { +func (c *connect) queryRow(ctx context.Context, release func(*connect, error), query string, args ...interface{}) *row { rows, err := c.query(ctx, release, query, args...) if err != nil { return &row{ diff --git a/conn_send_query.go b/conn_send_query.go index 7e8ae54951..df21b7ac87 100644 --- a/conn_send_query.go +++ b/conn_send_query.go @@ -23,7 +23,7 @@ import ( // Connection::sendQuery // https://github.com/ClickHouse/ClickHouse/blob/master/src/Client/Connection.cpp -func (c *Connect) sendQuery(body string, o *QueryOptions) error { +func (c *connect) sendQuery(body string, o *QueryOptions) error { c.debugf("[send query] compression=%t %s", c.compression, body) c.buffer.PutByte(proto.ClientQuery) q := proto.Query{ From c9e5a2c15952c18d475697c1d2ee4ad3f2587218 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Fri, 23 Dec 2022 18:48:59 +0100 Subject: [PATCH 6/8] fix tests --- clickhouse.go | 22 +++++++++++----------- clickhouse_options.go | 3 +-- tests/conn_test.go | 2 +- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/clickhouse.go b/clickhouse.go index 79e5537d32..5523146245 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -196,13 +196,10 @@ func (ch *clickhouse) Stats() driver.Stats { func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { connID := int(atomic.AddInt64(&ch.connID, 1)) - dialFunc := func(ctx context.Context, addr string, opt *Options) DialResult { + dialFunc := func(ctx context.Context, addr string, opt *Options) (DialResult, error) { conn, err := dial(ctx, addr, connID, opt) - return DialResult{ - conn: conn, - err: err, - } + return DialResult{conn}, err } dialStrategy := DefaultDialStrategy @@ -214,13 +211,10 @@ func (ch *clickhouse) dial(ctx context.Context) (conn *connect, err error) { if err != nil { return nil, err } - if result.err != nil { - return nil, err - } return result.conn, nil } -func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dial) (DialResult, error) { +func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dial) (r DialResult, err error) { for i := range opt.Addr { var num int switch opt.ConnOpenStrategy { @@ -230,10 +224,16 @@ func DefaultDialStrategy(ctx context.Context, connID int, opt *Options, dial Dia num = (int(connID) + i) % len(opt.Addr) } - return dial(ctx, opt.Addr[num], opt), nil + if r, err = dial(ctx, opt.Addr[num], opt); err == nil { + return r, nil + } + } + + if err == nil { + err = ErrAcquireConnNoAddress } - return DialResult{}, ErrAcquireConnNoAddress + return r, err } func (ch *clickhouse) acquire(ctx context.Context) (conn *connect, err error) { diff --git a/clickhouse_options.go b/clickhouse_options.go index f33f8a6c1e..e34e933427 100644 --- a/clickhouse_options.go +++ b/clickhouse_options.go @@ -115,10 +115,9 @@ func ParseDSN(dsn string) (*Options, error) { return opt, nil } -type Dial func(ctx context.Context, addr string, opt *Options) DialResult +type Dial func(ctx context.Context, addr string, opt *Options) (DialResult, error) type DialResult struct { conn *connect - err error } type Options struct { diff --git a/tests/conn_test.go b/tests/conn_test.go index 3ef6f80a4b..0b23a3d0b1 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -251,7 +251,7 @@ func TestConnCustomDialStrategy(t *testing.T) { env.Host = "non-existent.host" opts := clientOptionsFromEnv(env, clickhouse.Settings{}) opts.DialStrategy = func(ctx context.Context, connID int, opts *clickhouse.Options, dial clickhouse.Dial) (clickhouse.DialResult, error) { - return dial(ctx, actualAddr, opts), nil + return dial(ctx, actualAddr, opts) } conn, err := clickhouse.Open(&opts) From fe28f72577a3bf32669bf7c5593fd2a85b2d6a79 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Fri, 23 Dec 2022 18:50:11 +0100 Subject: [PATCH 7/8] remove dial strategy from docs --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index d6f8bca75b..c80b71a142 100644 --- a/README.md +++ b/README.md @@ -72,9 +72,6 @@ Support for the ClickHouse protocol advanced features using `Context`: var d net.Dialer return d.DialContext(ctx, "tcp", addr) }, - opts.DialStrategy = func(ctx context.Context, opts *clickhouse.Options) (*clickhouse.Connect, error) { - return clickhouse.Dial(ctx, "127.0.0.1:5678", 1, opts) - } Debug: true, Debugf: func(format string, v ...interface{}) { fmt.Printf(format, v) From 65896c2ed79087294be13e17d2edf841e6c0aa70 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Tue, 10 Jan 2023 19:21:50 +0100 Subject: [PATCH 8/8] fix test --- tests/conn_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/conn_test.go b/tests/conn_test.go index 0b23a3d0b1..a5600955a7 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -247,11 +247,12 @@ func TestConnCustomDialStrategy(t *testing.T) { env, err := GetTestEnvironment(testSet) require.NoError(t, err) - actualAddr := fmt.Sprintf("%s:%d", env.Host, env.Port) - env.Host = "non-existent.host" opts := clientOptionsFromEnv(env, clickhouse.Settings{}) + validAddr := opts.Addr[0] + opts.Addr = []string{"invalid.host:9001"} + opts.DialStrategy = func(ctx context.Context, connID int, opts *clickhouse.Options, dial clickhouse.Dial) (clickhouse.DialResult, error) { - return dial(ctx, actualAddr, opts) + return dial(ctx, validAddr, opts) } conn, err := clickhouse.Open(&opts)