diff --git a/CHANGELOG.md b/CHANGELOG.md index 211ea25a2..58e6878c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. decoded to a varbinary object (#313). - Use objects of the Decimal type instead of pointers (#238) - Use objects of the Datetime type instead of pointers (#238) +- `connection.Connect` and `pool.Connect` no longer return non-working + connection objects (#136). Those functions now accept context as their first + arguments, which user may cancel in process. ### Deprecated diff --git a/README.md b/README.md index aa4c6deac..b2f007cff 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,19 @@ about what it does. package tarantool import ( + "context" "fmt" + "time" + "github.com/tarantool/go-tarantool/v2" ) func main() { opts := tarantool.Opts{User: "guest"} - conn, err := tarantool.Connect("127.0.0.1:3301", opts) + ctx, cancel := context.WithTimeout(context.Background(), + 500 * time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts) if err != nil { fmt.Println("Connection refused:", err) } @@ -139,6 +145,10 @@ starting a session. There are two parameters: * a string with `host:port` format, and * the option structure that was set up earlier. +There will be only one attempt to connect. If multiple attempts needed, +"`tarantool.Connect`" could be placed inside the loop with some timeout +between each try. Example could be found in [example_test](./example_test.go). + **Observation 4:** The `err` structure will be `nil` if there is no error, otherwise it will have a description which can be retrieved with `err.Error()`. @@ -167,10 +177,15 @@ The subpackage has been deleted. You could use `pool` instead. #### pool package -The logic has not changed, but there are a few renames: - * The `connection_pool` subpackage has been renamed to `pool`. * The type `PoolOpts` has been renamed to `Opts`. +* `pool.Connect` no longer return non-working connection objects. This function + now accept context as first argument, which user may cancel in process. + +#### connection package + +`connection.Connect` no longer return non-working connection objects. This +function now accept context as first argument, which user may cancel in process. #### msgpack.v5 diff --git a/connection.go b/connection.go index 9bb42626a..229650d26 100644 --- a/connection.go +++ b/connection.go @@ -378,13 +378,9 @@ func (opts Opts) Clone() Opts { // // Notes: // -// - If opts.Reconnect is zero (default), then connection either already connected -// or error is returned. -// -// - If opts.Reconnect is non-zero, then error will be returned only if authorization -// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later -// and will not finish to make attempts on authorization failures. -func Connect(addr string, opts Opts) (conn *Connection, err error) { +// - There will be only one attempt to connect. If multiple attempts needed, +// Connect could be placed inside the loop with some timeout between each try. +func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) { conn = &Connection{ addr: addr, requestId: 0, @@ -432,25 +428,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { conn.cond = sync.NewCond(&conn.mutex) - if err = conn.createConnection(false); err != nil { - ter, ok := err.(Error) - if conn.opts.Reconnect <= 0 { - return nil, err - } else if ok && (ter.Code == iproto.ER_NO_SUCH_USER || - ter.Code == iproto.ER_CREDS_MISMATCH) { - // Reported auth errors immediately. - return nil, err - } else { - // Without SkipSchema it is useless. - go func(conn *Connection) { - conn.mutex.Lock() - defer conn.mutex.Unlock() - if err := conn.createConnection(true); err != nil { - conn.closeConnection(err, true) - } - }(conn) - err = nil - } + if err = conn.createConnection(ctx); err != nil { + return nil, err } go conn.pinger() @@ -534,18 +513,11 @@ func (conn *Connection) cancelFuture(fut *Future, err error) { } } -func (conn *Connection) dial() (err error) { +func (conn *Connection) dial(ctx context.Context) (err error) { opts := conn.opts - dialTimeout := opts.Reconnect / 2 - if dialTimeout == 0 { - dialTimeout = 500 * time.Millisecond - } else if dialTimeout > 5*time.Second { - dialTimeout = 5 * time.Second - } var c Conn - c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{ - DialTimeout: dialTimeout, + c, err = conn.opts.Dialer.Dial(ctx, conn.addr, DialOpts{ IoTimeout: opts.Timeout, Transport: opts.Transport, Ssl: opts.Ssl, @@ -658,34 +630,19 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, return } -func (conn *Connection) createConnection(reconnect bool) (err error) { - var reconnects uint - for conn.c == nil && conn.state == connDisconnected { - now := time.Now() - err = conn.dial() - if err == nil || !reconnect { - if err == nil { - conn.notify(Connected) - } - return - } - if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects { - conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) - err = ClientError{ErrConnectionClosed, "last reconnect failed"} - // mark connection as closed to avoid reopening by another goroutine - return +func (conn *Connection) createConnection(ctx context.Context) error { + var err error + if conn.c == nil && conn.state == connDisconnected { + err = conn.dial(ctx) + if err == nil { + conn.notify(Connected) + return nil } - conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) - conn.notify(ReconnectFailed) - reconnects++ - conn.mutex.Unlock() - time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) - conn.mutex.Lock() } if conn.state == connClosed { err = ClientError{ErrConnectionClosed, "using closed connection"} } - return + return err } func (conn *Connection) closeConnection(neterr error, forever bool) (err error) { @@ -727,11 +684,58 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) return } +func (conn *Connection) getDialTimeout() time.Duration { + dialTimeout := conn.opts.Reconnect / 2 + if dialTimeout == 0 { + dialTimeout = 500 * time.Millisecond + } else if dialTimeout > 5*time.Second { + dialTimeout = 5 * time.Second + } + return dialTimeout +} + +func (conn *Connection) runReconnects() error { + dialTimeout := conn.getDialTimeout() + var reconnects uint + var err error + + for reconnects <= conn.opts.MaxReconnects { + now := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + err = conn.createConnection(ctx) + cancel() + + if err == nil { + return nil + } + if clientErr, ok := err.(ClientError); ok && + clientErr.Code == ErrConnectionClosed { + return err + } + + conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) + conn.notify(ReconnectFailed) + if conn.opts.MaxReconnects != 0 { + reconnects++ + } + conn.mutex.Unlock() + + time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) + + conn.mutex.Lock() + } + + conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) + // mark connection as closed to avoid reopening by another goroutine + return ClientError{ErrConnectionClosed, "last reconnect failed"} +} + func (conn *Connection) reconnectImpl(neterr error, c Conn) { if conn.opts.Reconnect > 0 { if c == conn.c { conn.closeConnection(neterr, false) - if err := conn.createConnection(true); err != nil { + if err := conn.runReconnects(); err != nil { conn.closeConnection(err, true) } } diff --git a/crud/example_test.go b/crud/example_test.go index 363d0570d..79e9f42a1 100644 --- a/crud/example_test.go +++ b/crud/example_test.go @@ -1,6 +1,7 @@ package crud_test import ( + "context" "fmt" "reflect" "time" @@ -21,7 +22,9 @@ var exampleOpts = tarantool.Opts{ } func exampleConnect() *tarantool.Connection { - conn, err := tarantool.Connect(exampleServer, exampleOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, exampleServer, exampleOpts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/crud/tarantool_test.go b/crud/tarantool_test.go index 5cf29f66a..facae3980 100644 --- a/crud/tarantool_test.go +++ b/crud/tarantool_test.go @@ -1,6 +1,7 @@ package crud_test import ( + "context" "fmt" "log" "os" @@ -108,7 +109,10 @@ var object = crud.MapObject{ func connect(t testing.TB) *tarantool.Connection { for i := 0; i < 10; i++ { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), + 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { t.Fatalf("Failed to connect: %s", err) } diff --git a/datetime/example_test.go b/datetime/example_test.go index 346551629..954f43548 100644 --- a/datetime/example_test.go +++ b/datetime/example_test.go @@ -9,6 +9,7 @@ package datetime_test import ( + "context" "fmt" "time" @@ -23,7 +24,9 @@ func Example() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { fmt.Printf("Error in connect is %v", err) return diff --git a/decimal/example_test.go b/decimal/example_test.go index a355767f1..20b5486ca 100644 --- a/decimal/example_test.go +++ b/decimal/example_test.go @@ -9,6 +9,7 @@ package decimal_test import ( + "context" "log" "time" @@ -28,7 +29,9 @@ func Example() { User: "test", Pass: "test", } - client, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), opts.Reconnect/2) + client, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/dial.go b/dial.go index 3ba493ac7..5b17c0534 100644 --- a/dial.go +++ b/dial.go @@ -3,6 +3,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -56,8 +57,6 @@ type Conn interface { // DialOpts is a way to configure a Dial method to create a new Conn. type DialOpts struct { - // DialTimeout is a timeout for an initial network dial. - DialTimeout time.Duration // IoTimeout is a timeout per a network read/write. IoTimeout time.Duration // Transport is a connect transport type. @@ -86,7 +85,7 @@ type DialOpts struct { type Dialer interface { // Dial connects to a Tarantool instance to the address with specified // options. - Dial(address string, opts DialOpts) (Conn, error) + Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) } type tntConn struct { @@ -104,11 +103,11 @@ type TtDialer struct { // Dial connects to a Tarantool instance to the address with specified // options. -func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) { +func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) { var err error conn := new(tntConn) - if conn.net, err = dial(address, opts); err != nil { + if conn.net, err = dial(ctx, address, opts); err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } @@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo { } // dial connects to a Tarantool instance. -func dial(address string, opts DialOpts) (net.Conn, error) { +func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) { network, address := parseAddress(address) switch opts.Transport { case dialTransportNone: - return net.DialTimeout(network, address, opts.DialTimeout) + dialer := net.Dialer{} + return dialer.DialContext(ctx, network, address) case dialTransportSsl: - return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl) + return sslDialContext(ctx, network, address, opts.Ssl) default: return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport) } diff --git a/dial_test.go b/dial_test.go index ff8a50aab..0bf7558e5 100644 --- a/dial_test.go +++ b/dial_test.go @@ -2,8 +2,11 @@ package tarantool_test import ( "bytes" + "context" "errors" + "fmt" "net" + "strings" "sync" "testing" "time" @@ -18,7 +21,7 @@ type mockErrorDialer struct { err error } -func (m mockErrorDialer) Dial(address string, +func (m mockErrorDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { return nil, m.err } @@ -29,31 +32,37 @@ func TestDialer_Dial_error(t *testing.T) { err: errors.New(errMsg), } - conn, err := tarantool.Connect("any", tarantool.Opts{ - Dialer: dialer, - }) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "any", + tarantool.Opts{ + Dialer: dialer, + }) assert.Nil(t, conn) assert.ErrorContains(t, err, errMsg) } type mockPassedDialer struct { + ctx context.Context address string opts tarantool.DialOpts } -func (m *mockPassedDialer) Dial(address string, +func (m *mockPassedDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.address = address m.opts = opts + if ctx.Value("foo") != m.ctx.Value("foo") { + return nil, errors.New("wrong context") + } return nil, errors.New("does not matter") } func TestDialer_Dial_passedOpts(t *testing.T) { const addr = "127.0.0.1:8080" opts := tarantool.DialOpts{ - DialTimeout: 500 * time.Millisecond, - IoTimeout: 2, - Transport: "any", + IoTimeout: 2, + Transport: "any", Ssl: tarantool.SslOpts{ KeyFile: "a", CertFile: "b", @@ -73,7 +82,16 @@ func TestDialer_Dial_passedOpts(t *testing.T) { } dialer := &mockPassedDialer{} - conn, err := tarantool.Connect(addr, tarantool.Opts{ + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + type key int + const ctxKey key = iota + + ctx = context.WithValue(ctx, ctxKey, "bar") + dialer.ctx = ctx + + conn, err := tarantool.Connect(ctx, addr, tarantool.Opts{ Dialer: dialer, Timeout: opts.IoTimeout, Transport: opts.Transport, @@ -86,6 +104,7 @@ func TestDialer_Dial_passedOpts(t *testing.T) { assert.Nil(t, conn) assert.NotNil(t, err) + assert.NotEqual(t, err.Error(), "wrong context") assert.Equal(t, addr, dialer.address) assert.Equal(t, opts, dialer.opts) } @@ -187,7 +206,7 @@ func newMockIoConn() *mockIoConn { return conn } -func (m *mockIoDialer) Dial(address string, +func (m *mockIoDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.conn = newMockIoConn() if m.init != nil { @@ -203,11 +222,14 @@ func dialIo(t *testing.T, dialer := mockIoDialer{ init: init, } - conn, err := tarantool.Connect("any", tarantool.Opts{ - Dialer: &dialer, - Timeout: 1000 * time.Second, // Avoid pings. - SkipSchema: true, - }) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "any", + tarantool.Opts{ + Dialer: &dialer, + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) require.Nil(t, err) require.NotNil(t, conn) @@ -338,3 +360,19 @@ func TestConn_ReadWrite(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, resp) } + +func TestConn_ContextCancel(t *testing.T) { + const addr = "127.0.0.1:8080" + + dialer := tarantool.TtDialer{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + conn, err := dialer.Dial(ctx, addr, tarantool.DialOpts{}) + + assert.Nil(t, conn) + assert.NotNil(t, err) + assert.Truef(t, strings.Contains(err.Error(), "operation was canceled"), + fmt.Sprintf("unexpected error, expected to contain %s, got %v", + "operation was canceled", err)) +} diff --git a/example_custom_unpacking_test.go b/example_custom_unpacking_test.go index 1189e16a3..8ab8f5370 100644 --- a/example_custom_unpacking_test.go +++ b/example_custom_unpacking_test.go @@ -1,6 +1,7 @@ package tarantool_test import ( + "context" "fmt" "log" "time" @@ -84,7 +85,9 @@ func Example_customUnpacking() { User: "test", Pass: "test", } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), opts.Reconnect/2) + conn, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/example_test.go b/example_test.go index d871d578a..a774b530d 100644 --- a/example_test.go +++ b/example_test.go @@ -19,7 +19,9 @@ type Tuple struct { } func exampleConnect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -38,7 +40,9 @@ func ExampleSslOpts() { CaFile: "testdata/ca.crt", }, } - _, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -913,12 +917,51 @@ func ExampleFuture_GetIterator() { } func ExampleConnect() { - conn, err := tarantool.Connect("127.0.0.1:3013", tarantool.Opts{ - Timeout: 5 * time.Second, - User: "test", - Pass: "test", - Concurrency: 32, - }) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", + tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Concurrency: 32, + }) + if err != nil { + fmt.Println("No connection available") + return + } + defer conn.Close() + if conn != nil { + fmt.Println("Connection is ready") + } + // Output: + // Connection is ready +} + +func ExampleConnect_reconnects() { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Concurrency: 32, + Reconnect: time.Second, + MaxReconnects: 10, + } + + var conn *tarantool.Connection + var err error + + for i := uint(0); i < opts.MaxReconnects; i++ { + conn, err = tarantool.Connect(ctx, "127.0.0.1:3013", opts) + if err == nil { + break + } + time.Sleep(opts.Reconnect) + } if err != nil { fmt.Println("No connection available") return @@ -1081,7 +1124,9 @@ func ExampleConnection_NewPrepared() { User: "test", Pass: "test", } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { fmt.Printf("Failed to connect: %s", err.Error()) } @@ -1127,7 +1172,9 @@ func ExampleConnection_NewWatcher() { Features: []tarantool.ProtocolFeature{tarantool.WatchersFeature}, }, } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { fmt.Printf("Failed to connect: %s\n", err) return diff --git a/export_test.go b/export_test.go index 10d194840..adb67976c 100644 --- a/export_test.go +++ b/export_test.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "net" "time" @@ -12,6 +13,11 @@ func SslDialTimeout(network, address string, timeout time.Duration, return sslDialTimeout(network, address, timeout, opts) } +func SslDialContext(ctx context.Context, network, address string, + opts SslOpts) (connection net.Conn, err error) { + return sslDialContext(ctx, network, address, opts) +} + func SslCreateContext(opts SslOpts) (ctx interface{}, err error) { return sslCreateContext(opts) } diff --git a/go.mod b/go.mod index bd848308c..68f46c634 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.1 github.com/tarantool/go-iproto v0.1.0 - github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a + github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect diff --git a/go.sum b/go.sum index 1810c2b3a..414b13c56 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMT github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarantool/go-iproto v0.1.0 h1:zHN9AA8LDawT+JBD0/Nxgr/bIsWkkpDzpcMuaNPSIAQ= github.com/tarantool/go-iproto v0.1.0/go.mod h1:LNCtdyZxojUed8SbOiYHoc3v9NvaZTB7p96hUySMlIo= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a h1:eeElglRXJ3xWKkHmDbeXrQWlZyQ4t3Ca1YlZsrfdXFU= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= +github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf h1:oCQZliFthJ2j/4TgD3PFwazfZjsn+wCA4xOLi2yO7cI= +github.com/tarantool/go-openssl v0.0.8-0.20231002130016-e88579e113cf/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= diff --git a/pool/connection_pool.go b/pool/connection_pool.go index 26e2199e9..042ddc1a5 100644 --- a/pool/connection_pool.go +++ b/pool/connection_pool.go @@ -11,6 +11,7 @@ package pool import ( + "context" "errors" "log" "sync" @@ -33,6 +34,7 @@ var ( ErrClosed = errors.New("pool is closed") ErrUnknownRequest = errors.New("the passed connected request doesn't belong to " + "the current connection pool") + ErrContextCanceled = errors.New("operation was canceled") ) // ConnectionHandler provides callbacks for components interested in handling @@ -103,6 +105,7 @@ type ConnectionPool struct { anyPool *roundRobinStrategy poolsMutex sync.RWMutex watcherContainer watcherContainer + ctxCancels []context.CancelFunc } var _ Pooler = (*ConnectionPool)(nil) @@ -133,7 +136,8 @@ func newEndpoint(addr string) *endpoint { // ConnectWithOpts creates pool for instances with addresses addrs // with options opts. -func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { +func ConnectWithOpts(ctx context.Context, addrs []string, + connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { if len(addrs) == 0 { return nil, ErrEmptyAddrs } @@ -161,16 +165,21 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne connPool.addrs[addr] = nil } - somebodyAlive := connPool.fillPools() + somebodyAlive, ctxCanceled := connPool.fillPools(ctx) if !somebodyAlive { connPool.state.set(closedState) + if ctxCanceled { + return nil, ErrContextCanceled + } return nil, ErrNoConnection } connPool.state.set(connectedState) + controllerCtx, cancel := context.WithCancel(context.Background()) + connPool.ctxCancels = append(connPool.ctxCancels, cancel) for _, s := range connPool.addrs { - go connPool.controller(s) + go connPool.controller(controllerCtx, s) } return connPool, nil @@ -181,11 +190,12 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne // It is useless to set up tarantool.Opts.Reconnect value for a connection. // The connection pool has its own reconnection logic. See // Opts.CheckTimeout description. -func Connect(addrs []string, connOpts tarantool.Opts) (*ConnectionPool, error) { +func Connect(ctx context.Context, addrs []string, + connOpts tarantool.Opts) (*ConnectionPool, error) { opts := Opts{ CheckTimeout: 1 * time.Second, } - return ConnectWithOpts(addrs, connOpts, opts) + return ConnectWithOpts(ctx, addrs, connOpts, opts) } // ConnectedNow gets connected status of pool. @@ -224,7 +234,7 @@ func (p *ConnectionPool) ConfiguredTimeout(mode Mode) (time.Duration, error) { // Add adds a new endpoint with the address into the pool. This function // adds the endpoint only after successful connection. -func (p *ConnectionPool) Add(addr string) error { +func (p *ConnectionPool) Add(ctx context.Context, addr string) error { e := newEndpoint(addr) p.addrsMutex.Lock() @@ -240,7 +250,7 @@ func (p *ConnectionPool) Add(addr string) error { p.addrs[addr] = e p.addrsMutex.Unlock() - if err := p.tryConnect(e); err != nil { + if err := p.tryConnect(ctx, e); err != nil { p.addrsMutex.Lock() delete(p.addrs, addr) p.addrsMutex.Unlock() @@ -248,7 +258,9 @@ func (p *ConnectionPool) Add(addr string) error { return err } - go p.controller(e) + controllerCtx, cancel := context.WithCancel(context.Background()) + p.ctxCancels = append(p.ctxCancels, cancel) + go p.controller(controllerCtx, e) return nil } @@ -306,6 +318,9 @@ func (p *ConnectionPool) Close() []error { } p.addrsMutex.RUnlock() } + for _, cancel := range p.ctxCancels { + cancel() + } return p.waitClose() } @@ -1109,8 +1124,16 @@ func (p *ConnectionPool) handlerDeactivated(conn *tarantool.Connection, } } -func (p *ConnectionPool) fillPools() bool { +func (p *ConnectionPool) deactivateConnection(addr string, + conn *tarantool.Connection, role Role) { + p.deleteConnection(addr) + conn.Close() + p.handlerDeactivated(conn, role) +} + +func (p *ConnectionPool) fillPools(ctx context.Context) (bool, bool) { somebodyAlive := false + ctxCanceled := false // It is called before controller() goroutines so we don't expect // concurrency issues here. @@ -1120,9 +1143,26 @@ func (p *ConnectionPool) fillPools() bool { connOpts := p.connOpts connOpts.Notify = end.notify - conn, err := tarantool.Connect(addr, connOpts) + conn, err := tarantool.Connect(ctx, addr, connOpts) if err != nil { log.Printf("tarantool: connect to %s failed: %s\n", addr, err.Error()) + select { + case <-ctx.Done(): + ctxCanceled = true + + p.addrs[addr] = nil + log.Printf("tarantool: operation was canceled") + + for address, endpoint := range p.addrs { + if endpoint == nil { + continue + } + p.deactivateConnection(address, endpoint.conn, endpoint.role) + } + + return false, ctxCanceled + default: + } } else if conn != nil { role, err := p.getConnectionRole(conn) if err != nil { @@ -1142,9 +1182,7 @@ func (p *ConnectionPool) fillPools() bool { end.role = role somebodyAlive = true } else { - p.deleteConnection(addr) - conn.Close() - p.handlerDeactivated(conn, role) + p.deactivateConnection(addr, conn, role) } } else { conn.Close() @@ -1152,7 +1190,7 @@ func (p *ConnectionPool) fillPools() bool { } } - return somebodyAlive + return somebodyAlive, ctxCanceled } func (p *ConnectionPool) updateConnection(e *endpoint) { @@ -1213,7 +1251,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { } } -func (p *ConnectionPool) tryConnect(e *endpoint) error { +func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1226,7 +1264,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { connOpts := p.connOpts connOpts.Notify = e.notify - conn, err := tarantool.Connect(e.addr, connOpts) + conn, err := tarantool.Connect(ctx, e.addr, connOpts) if err == nil { role, err := p.getConnectionRole(conn) p.poolsMutex.Unlock() @@ -1265,7 +1303,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { return err } -func (p *ConnectionPool) reconnect(e *endpoint) { +func (p *ConnectionPool) reconnect(ctx context.Context, e *endpoint) { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1280,10 +1318,10 @@ func (p *ConnectionPool) reconnect(e *endpoint) { e.conn = nil e.role = UnknownRole - p.tryConnect(e) + p.tryConnect(ctx, e) } -func (p *ConnectionPool) controller(e *endpoint) { +func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { timer := time.NewTicker(p.opts.CheckTimeout) defer timer.Stop() @@ -1367,11 +1405,11 @@ func (p *ConnectionPool) controller(e *endpoint) { // Relocate connection between subpools // if ro/rw was updated. if e.conn == nil { - p.tryConnect(e) + p.tryConnect(ctx, e) } else if !e.conn.ClosedNow() { p.updateConnection(e) } else { - p.reconnect(e) + p.reconnect(ctx, e) } } } diff --git a/pool/connection_pool_test.go b/pool/connection_pool_test.go index dd0210d62..8ae07c048 100644 --- a/pool/connection_pool_test.go +++ b/pool/connection_pool_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "log" "os" @@ -46,17 +47,23 @@ var defaultTimeoutRetry = 500 * time.Millisecond var instances []test_helpers.TarantoolInstance func TestConnError_IncorrectParams(t *testing.T) { - connPool, err := pool.Connect([]string{}, tarantool.Opts{}) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{}, tarantool.Opts{}) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "addrs (first argument) should not be empty", err.Error()) - connPool, err = pool.Connect([]string{"err1", "err2"}, connOpts) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err = pool.Connect(ctx, []string{"err1", "err2"}, connOpts) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "no active connections", err.Error()) - connPool, err = pool.ConnectWithOpts(servers, tarantool.Opts{}, pool.Opts{}) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err = pool.ConnectWithOpts(ctx, servers, tarantool.Opts{}, pool.Opts{}) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "wrong check timeout, must be greater than 0", err.Error()) @@ -64,7 +71,9 @@ func TestConnError_IncorrectParams(t *testing.T) { func TestConnSuccessfully(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{"err", server}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{"err", server}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -84,9 +93,77 @@ func TestConnSuccessfully(t *testing.T) { require.Nil(t, err) } +func TestConnErrorAfterCtxCancel(t *testing.T) { + var connLongReconnectOpts = tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var connPool *pool.ConnectionPool + var err error + + cancel() + connPool, err = pool.Connect(ctx, servers, connLongReconnectOpts) + + if connPool != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + +type mockClosingDialer struct { + cnt int + ctx context.Context + ctxCancel context.CancelFunc +} + +func (m *mockClosingDialer) Dial(ctx context.Context, address string, + opts tarantool.DialOpts) (tarantool.Conn, error) { + + dialer := tarantool.TtDialer{} + conn, err := dialer.Dial(m.ctx, address, tarantool.DialOpts{ + User: "test", + Password: "test", + }) + + if m.cnt == 0 { + m.ctxCancel() + } + m.cnt++ + + return conn, err +} + +func TestContextCancelInProgress(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dialer := &mockClosingDialer{0, ctx, cancel} + + connPool, err := pool.Connect(ctx, servers, tarantool.Opts{ + Dialer: dialer, + }) + require.NotNilf(t, err, "expected err after ctx cancel") + assert.Truef(t, strings.Contains(err.Error(), "operation was canceled"), + fmt.Sprintf("unexpected error, expected to contain %s, got %v", + "operation was canceled", err)) + require.Nilf(t, connPool, "conn is not nil after ctx cancel") +} + func TestConnSuccessfullyDuplicates(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{server, server, server, server}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{server, server, server, server}, + connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -112,7 +189,9 @@ func TestConnSuccessfullyDuplicates(t *testing.T) { func TestReconnect(t *testing.T) { server := servers[0] - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -158,7 +237,9 @@ func TestDisconnect_withReconnect(t *testing.T) { opts := connOpts opts.Reconnect = 10 * time.Second - connPool, err := pool.Connect([]string{servers[serverId]}, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[serverId]}, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -202,7 +283,9 @@ func TestDisconnectAll(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -249,14 +332,18 @@ func TestDisconnectAll(t *testing.T) { } func TestAdd(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() for _, server := range servers[1:] { - err = connPool.Add(server) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, server) + cancel() require.Nil(t, err) } @@ -280,13 +367,17 @@ func TestAdd(t *testing.T) { } func TestAdd_exist(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add(servers[0]) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, servers[0]) + cancel() require.Equal(t, pool.ErrExists, err) args := test_helpers.CheckStatusesArgs{ @@ -305,13 +396,17 @@ func TestAdd_exist(t *testing.T) { } func TestAdd_unreachable(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add("127.0.0.2:6667") + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, "127.0.0.2:6667") + cancel() // The OS-dependent error so we just check for existence. require.NotNil(t, err) @@ -331,17 +426,23 @@ func TestAdd_unreachable(t *testing.T) { } func TestAdd_afterClose(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") connPool.Close() - err = connPool.Add(servers[0]) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, servers[0]) + cancel() assert.Equal(t, err, pool.ErrClosed) } func TestAdd_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -350,7 +451,9 @@ func TestAdd_Close_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, servers[1]) + cancel() if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -362,7 +465,9 @@ func TestAdd_Close_concurrent(t *testing.T) { } func TestAdd_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -371,7 +476,9 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + err = connPool.Add(ctx, servers[1]) + cancel() if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -383,7 +490,9 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { } func TestRemove(t *testing.T) { - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -410,7 +519,9 @@ func TestRemove(t *testing.T) { } func TestRemove_double(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -437,7 +548,9 @@ func TestRemove_double(t *testing.T) { } func TestRemove_unknown(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -463,7 +576,9 @@ func TestRemove_unknown(t *testing.T) { } func TestRemove_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -510,7 +625,9 @@ func TestRemove_concurrent(t *testing.T) { } func TestRemove_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -529,7 +646,9 @@ func TestRemove_Close_concurrent(t *testing.T) { } func TestRemove_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -551,7 +670,9 @@ func TestClose(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -591,7 +712,9 @@ func TestCloseGraceful(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -730,7 +853,9 @@ func TestConnectionHandlerOpenUpdateClose(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -804,7 +929,9 @@ func TestConnectionHandlerOpenError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) if err == nil { defer connPool.Close() } @@ -846,7 +973,9 @@ func TestConnectionHandlerUpdateError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -891,7 +1020,9 @@ func TestRequestOnClosed(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -930,7 +1061,9 @@ func TestGetPoolInfo(t *testing.T) { srvs := []string{server1, server2} expected := []string{server1, server2} - connPool, err := pool.Connect(srvs, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, srvs, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -948,7 +1081,9 @@ func TestCall(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1005,7 +1140,9 @@ func TestCall16(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1062,7 +1199,9 @@ func TestCall17(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1119,7 +1258,9 @@ func TestEval(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1197,7 +1338,9 @@ func TestExecute(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1254,7 +1397,9 @@ func TestRoundRobinStrategy(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1331,7 +1476,9 @@ func TestRoundRobinStrategy_NoReplica(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1402,7 +1549,9 @@ func TestRoundRobinStrategy_NoMaster(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1485,7 +1634,9 @@ func TestUpdateInstancesRoles(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1629,7 +1780,9 @@ func TestInsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1728,7 +1881,9 @@ func TestDelete(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1792,7 +1947,9 @@ func TestUpsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1864,7 +2021,9 @@ func TestUpdate(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1954,7 +2113,9 @@ func TestReplace(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2040,7 +2201,9 @@ func TestSelect(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2160,7 +2323,9 @@ func TestPing(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2198,7 +2363,9 @@ func TestDo(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2237,7 +2404,9 @@ func TestDo_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2268,7 +2437,9 @@ func TestNewPrepared(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2336,7 +2507,9 @@ func TestDoWithStrangerConn(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2365,7 +2538,9 @@ func TestStream_Commit(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2464,7 +2639,9 @@ func TestStream_Rollback(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2563,7 +2740,9 @@ func TestStream_TxnIsolationLevel(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2657,7 +2836,9 @@ func TestConnectionPool_NewWatcher_noWatchersFeature(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2685,7 +2866,9 @@ func TestConnectionPool_NewWatcher_modes(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2767,7 +2950,9 @@ func TestConnectionPool_NewWatcher_update(t *testing.T) { poolOpts := pool.Opts{ CheckTimeout: 500 * time.Millisecond, } - pool, err := pool.ConnectWithOpts(servers, opts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + pool, err := pool.ConnectWithOpts(ctx, servers, opts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") @@ -2851,7 +3036,9 @@ func TestWatcher_Unregister(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - pool, err := pool.Connect(servers, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + pool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") defer pool.Close() @@ -2910,7 +3097,9 @@ func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2950,7 +3139,9 @@ func TestWatcher_Unregister_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() diff --git a/pool/example_test.go b/pool/example_test.go index 84a41ff7b..75204f22b 100644 --- a/pool/example_test.go +++ b/pool/example_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "time" @@ -24,7 +25,8 @@ func examplePool(roles []bool, connOpts tarantool.Opts) (*pool.ConnectionPool, e if err != nil { return nil, fmt.Errorf("ConnectionPool is not established") } - connPool, err := pool.Connect(servers, connOpts) + ctx := context.Background() + connPool, err := pool.Connect(ctx, servers, connOpts) if err != nil || connPool == nil { return nil, fmt.Errorf("ConnectionPool is not established") } diff --git a/queue/example_connection_pool_test.go b/queue/example_connection_pool_test.go index 51fb967a5..a3bfaf0a4 100644 --- a/queue/example_connection_pool_test.go +++ b/queue/example_connection_pool_test.go @@ -1,6 +1,7 @@ package queue_test import ( + "context" "fmt" "sync" "sync/atomic" @@ -164,7 +165,9 @@ func Example_connectionPool() { CheckTimeout: 5 * time.Second, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(servers, connOpts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, servers, connOpts, poolOpts) if err != nil { fmt.Printf("Unable to connect to the pool: %s", err) return diff --git a/queue/example_msgpack_test.go b/queue/example_msgpack_test.go index 6fd101e09..6d3637417 100644 --- a/queue/example_msgpack_test.go +++ b/queue/example_msgpack_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "log" "time" @@ -55,7 +56,9 @@ func Example_simpleQueueCustomMsgPack() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) + cancel() if err != nil { log.Fatalf("connection: %s", err) return diff --git a/queue/example_test.go b/queue/example_test.go index 711ee31d4..e81acca40 100644 --- a/queue/example_test.go +++ b/queue/example_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "time" @@ -31,7 +32,9 @@ func Example_simpleQueue() { Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { fmt.Printf("error in prepare is %v", err) return diff --git a/settings/example_test.go b/settings/example_test.go index b1d0e5d4f..29be33bfe 100644 --- a/settings/example_test.go +++ b/settings/example_test.go @@ -1,7 +1,9 @@ package settings_test import ( + "context" "fmt" + "time" "github.com/tarantool/go-tarantool/v2" "github.com/tarantool/go-tarantool/v2/settings" @@ -9,7 +11,9 @@ import ( ) func example_connect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/shutdown_test.go b/shutdown_test.go index bb4cfa099..4f685cbd8 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -6,6 +6,7 @@ package tarantool_test import ( + "context" "fmt" "sync" "syscall" @@ -460,9 +461,10 @@ func TestGracefulShutdownCloseConcurrent(t *testing.T) { go func(i int) { defer caseWg.Done() - // Do not wait till Tarantool register out watcher, - // test everything is ok even on async. - conn, err := Connect(shtdnServer, shtdnClntOpts) + ctx, cancel := context.WithTimeout(context.Background(), + 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, shtdnServer, shtdnClntOpts) if err != nil { t.Errorf("Failed to connect: %s", err) } else { diff --git a/ssl.go b/ssl.go index a23238849..a0743edd1 100644 --- a/ssl.go +++ b/ssl.go @@ -5,6 +5,7 @@ package tarantool import ( "bufio" + "context" "errors" "io/ioutil" "net" @@ -25,6 +26,16 @@ func sslDialTimeout(network, address string, timeout time.Duration, return openssl.DialTimeout(network, address, timeout, ctx.(*openssl.Ctx), 0) } +func sslDialContext(ctx context.Context, network, address string, + opts SslOpts) (connection net.Conn, err error) { + var sslCtx interface{} + if sslCtx, err = sslCreateContext(opts); err != nil { + return + } + + return openssl.DialContext(ctx, network, address, sslCtx.(*openssl.Ctx), 0) +} + // interface{} is a hack. It helps to avoid dependency of go-openssl in build // of tests with the tag 'go_tarantool_ssl_disable'. func sslCreateContext(opts SslOpts) (ctx interface{}, err error) { diff --git a/ssl_test.go b/ssl_test.go index 30078703c..9dd79e0e8 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -4,6 +4,7 @@ package tarantool_test import ( + "context" "errors" "fmt" "io/ioutil" @@ -60,12 +61,16 @@ func serverSslRecv(msgs <-chan string, errs <-chan error) (string, error) { return <-msgs, <-errs } -func clientSsl(network, address string, opts SslOpts) (net.Conn, error) { - timeout := 5 * time.Second - return SslDialTimeout(network, address, timeout, opts) +func clientSsl(ctx context.Context, network, address string, + opts SslOpts) (net.Conn, error) { + if ctx == nil { + timeout := 5 * time.Second + return SslDialTimeout(network, address, timeout, opts) + } + return SslDialContext(ctx, network, address, opts) } -func createClientServerSsl(t testing.TB, serverOpts, +func createClientServerSsl(ctx context.Context, t testing.TB, serverOpts, clientOpts SslOpts) (net.Listener, net.Conn, <-chan string, <-chan error, error) { t.Helper() @@ -77,16 +82,16 @@ func createClientServerSsl(t testing.TB, serverOpts, msgs, errs := serverSslAccept(l) port := l.Addr().(*net.TCPAddr).Port - c, err := clientSsl("tcp", sslHost+":"+strconv.Itoa(port), clientOpts) + c, err := clientSsl(ctx, "tcp", sslHost+":"+strconv.Itoa(port), clientOpts) return l, c, msgs, errs, err } -func createClientServerSslOk(t testing.TB, serverOpts, +func createClientServerSslOk(ctx context.Context, t testing.TB, serverOpts, clientOpts SslOpts) (net.Listener, net.Conn, <-chan string, <-chan error) { t.Helper() - l, c, msgs, errs, err := createClientServerSsl(t, serverOpts, clientOpts) + l, c, msgs, errs, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) if err != nil { t.Fatalf("Unable to create client, error %q", err.Error()) } @@ -150,7 +155,9 @@ func serverTntStop(inst test_helpers.TarantoolInstance) { } func checkTntConn(clientOpts SslOpts) error { - conn, err := Connect(tntHost, Opts{ + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, tntHost, Opts{ Auth: AutoAuth, Timeout: 500 * time.Millisecond, User: "test", @@ -166,10 +173,11 @@ func checkTntConn(clientOpts SslOpts) error { return nil } -func assertConnectionSslFail(t testing.TB, serverOpts, clientOpts SslOpts) { +func assertConnectionSslFail(ctx context.Context, t testing.TB, serverOpts, + clientOpts SslOpts) { t.Helper() - l, c, _, _, err := createClientServerSsl(t, serverOpts, clientOpts) + l, c, _, _, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) l.Close() if err == nil { c.Close() @@ -177,10 +185,11 @@ func assertConnectionSslFail(t testing.TB, serverOpts, clientOpts SslOpts) { } } -func assertConnectionSslOk(t testing.TB, serverOpts, clientOpts SslOpts) { +func assertConnectionSslOk(ctx context.Context, t testing.TB, serverOpts, + clientOpts SslOpts) { t.Helper() - l, c, msgs, errs := createClientServerSslOk(t, serverOpts, clientOpts) + l, c, msgs, errs := createClientServerSslOk(ctx, t, serverOpts, clientOpts) const message = "any test string" c.Write([]byte(message)) c.Close() @@ -230,6 +239,7 @@ type test struct { ok bool serverOpts SslOpts clientOpts SslOpts + needCtx bool } /* @@ -257,6 +267,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", }, SslOpts{}, + false, }, { "key_crt_server_and_client", @@ -269,6 +280,7 @@ var tests = []test{ KeyFile: "testdata/localhost.key", CertFile: "testdata/localhost.crt", }, + false, }, { "key_crt_ca_server", @@ -279,6 +291,7 @@ var tests = []test{ CaFile: "testdata/ca.crt", }, SslOpts{}, + false, }, { "key_crt_ca_server_key_crt_client", @@ -292,6 +305,7 @@ var tests = []test{ KeyFile: "testdata/localhost.key", CertFile: "testdata/localhost.crt", }, + false, }, { "key_crt_ca_server_and_client", @@ -306,6 +320,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_server_and_client_invalid_path_key", @@ -320,6 +335,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_server_and_client_invalid_path_crt", @@ -334,6 +350,7 @@ var tests = []test{ CertFile: "any_invalid_path", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_server_and_client_invalid_path_ca", @@ -348,6 +365,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "any_invalid_path", }, + false, }, { "key_crt_ca_server_and_client_empty_key", @@ -362,6 +380,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_server_and_client_empty_crt", @@ -376,6 +395,7 @@ var tests = []test{ CertFile: "testdata/empty", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_server_and_client_empty_ca", @@ -390,6 +410,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/empty", }, + false, }, { "key_crt_server_and_key_crt_ca_client", @@ -403,6 +424,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_ciphers_server_key_crt_ca_client", @@ -418,6 +440,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", CaFile: "testdata/ca.crt", }, + false, }, { "key_crt_ca_ciphers_server_and_client", @@ -434,6 +457,7 @@ var tests = []test{ CaFile: "testdata/ca.crt", Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", }, + false, }, { "non_equal_ciphers_client", @@ -450,6 +474,7 @@ var tests = []test{ CaFile: "testdata/ca.crt", Ciphers: "TLS_AES_128_GCM_SHA256", }, + false, }, { "pass_key_encrypt_client", @@ -464,6 +489,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", Password: "mysslpassword", }, + false, }, { "passfile_key_encrypt_client", @@ -478,6 +504,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", PasswordFile: "testdata/passwords", }, + false, }, { "pass_and_passfile_key_encrypt_client", @@ -493,6 +520,7 @@ var tests = []test{ Password: "mysslpassword", PasswordFile: "testdata/passwords", }, + false, }, { "inv_pass_and_passfile_key_encrypt_client", @@ -508,6 +536,7 @@ var tests = []test{ Password: "invalidpassword", PasswordFile: "testdata/passwords", }, + false, }, { "pass_and_inv_passfile_key_encrypt_client", @@ -523,6 +552,7 @@ var tests = []test{ Password: "mysslpassword", PasswordFile: "testdata/invalidpasswords", }, + false, }, { "pass_and_not_existing_passfile_key_encrypt_client", @@ -538,6 +568,7 @@ var tests = []test{ Password: "mysslpassword", PasswordFile: "testdata/notafile", }, + false, }, { "inv_pass_and_inv_passfile_key_encrypt_client", @@ -553,6 +584,7 @@ var tests = []test{ Password: "invalidpassword", PasswordFile: "testdata/invalidpasswords", }, + false, }, { "not_existing_passfile_key_encrypt_client", @@ -567,6 +599,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", PasswordFile: "testdata/notafile", }, + false, }, { "no_pass_key_encrypt_client", @@ -580,6 +613,7 @@ var tests = []test{ KeyFile: "testdata/localhost.enc.key", CertFile: "testdata/localhost.crt", }, + false, }, { "pass_key_non_encrypt_client", @@ -594,6 +628,7 @@ var tests = []test{ CertFile: "testdata/localhost.crt", Password: "invalidpassword", }, + false, }, { "passfile_key_non_encrypt_client", @@ -608,6 +643,24 @@ var tests = []test{ CertFile: "testdata/localhost.crt", PasswordFile: "testdata/invalidpasswords", }, + false, + }, + { + "use_dial_with_context", + true, + SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + }, + SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + }, + true, }, } @@ -621,15 +674,23 @@ func TestSslOpts(t *testing.T) { isTntSsl := isTestTntSsl() for _, test := range tests { + var ctx context.Context + var cancel context.CancelFunc + if test.needCtx { + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + } if test.ok { t.Run("ok_ssl_"+test.name, func(t *testing.T) { - assertConnectionSslOk(t, test.serverOpts, test.clientOpts) + assertConnectionSslOk(ctx, t, test.serverOpts, test.clientOpts) }) } else { t.Run("fail_ssl_"+test.name, func(t *testing.T) { - assertConnectionSslFail(t, test.serverOpts, test.clientOpts) + assertConnectionSslFail(ctx, t, test.serverOpts, test.clientOpts) }) } + if cancel != nil { + cancel() + } if !isTntSsl { continue } @@ -645,6 +706,35 @@ func TestSslOpts(t *testing.T) { } } +func TestSslDialContextCancel(t *testing.T) { + serverOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + clientOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + l, c, _, _, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) + l.Close() + + if err == nil { + c.Close() + t.Fatalf("Expected error, dial was not canceled") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + func TestOpts_PapSha256Auth(t *testing.T) { isTntSsl := isTestTntSsl() if !isTntSsl { diff --git a/tarantool_test.go b/tarantool_test.go index 6339164f1..61ae11034 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -708,7 +708,9 @@ func TestTtDialer(t *testing.T) { assert := assert.New(t) require := require.New(t) - conn, err := TtDialer{}.Dial(server, DialOpts{}) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := TtDialer{}.Dial(ctx, server, DialOpts{}) require.Nil(err) require.NotNil(conn) defer conn.Close() @@ -778,7 +780,9 @@ func TestOptsAuth_PapSha256AuthForbit(t *testing.T) { papSha256Opts := opts papSha256Opts.Auth = PapSha256Auth - conn, err := Connect(server, papSha256Opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, papSha256Opts) if err == nil { t.Error("An error expected.") conn.Close() @@ -3408,7 +3412,9 @@ func TestConnectionProtocolVersionRequirementSuccess(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, err, "No errors on connect") require.NotNilf(t, conn, "Connect success") @@ -3424,7 +3430,9 @@ func TestConnectionProtocolVersionRequirementFail(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3439,7 +3447,9 @@ func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.NotNilf(t, conn, "Connect success") require.Nilf(t, err, "No errors on connect") @@ -3455,7 +3465,9 @@ func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3471,7 +3483,9 @@ func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature, ProtocolFeature(15532)}, } - conn, err := Connect(server, connOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -4001,21 +4015,52 @@ func TestConnect_schema_update(t *testing.T) { defer conn.Close() for i := 0; i < 100; i++ { - fut := conn.Do(NewCallRequest("create_spaces")) + func() { + fut := conn.Do(NewCallRequest("create_spaces")) + + ctx, cancel := context.WithTimeout(context.Background(), + 500*time.Millisecond) + defer cancel() + if conn, err := Connect(ctx, server, opts); err != nil { + if err.Error() != "concurrent schema update" { + t.Errorf("unexpected error: %s", err) + } + } else if conn == nil { + t.Errorf("conn is nil") + } else { + conn.Close() + } - if conn, err := Connect(server, opts); err != nil { - if err.Error() != "concurrent schema update" { - t.Errorf("unexpected error: %s", err) + if _, err := fut.Get(); err != nil { + t.Errorf("Failed to call create_spaces: %s", err) } - } else if conn == nil { - t.Errorf("conn is nil") - } else { - conn.Close() - } + }() + } +} - if _, err := fut.Get(); err != nil { - t.Errorf("Failed to call create_spaces: %s", err) - } +func TestConnect_context_cancel(t *testing.T) { + var connLongReconnectOpts = Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var conn *Connection + var err error + + cancel() + conn, err = Connect(ctx, server, connLongReconnectOpts) + + if conn != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) } } diff --git a/test_helpers/main.go b/test_helpers/main.go index 894ebb653..f63149779 100644 --- a/test_helpers/main.go +++ b/test_helpers/main.go @@ -11,6 +11,7 @@ package test_helpers import ( + "context" "errors" "fmt" "io" @@ -97,7 +98,9 @@ func isReady(server string, opts *tarantool.Opts) error { var conn *tarantool.Connection var resp *tarantool.Response - conn, err = tarantool.Connect(server, *opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err = tarantool.Connect(ctx, server, *opts) if err != nil { return err } diff --git a/test_helpers/pool_helper.go b/test_helpers/pool_helper.go index c44df2f6a..398f5c4ba 100644 --- a/test_helpers/pool_helper.go +++ b/test_helpers/pool_helper.go @@ -1,6 +1,7 @@ package test_helpers import ( + "context" "fmt" "reflect" "time" @@ -130,9 +131,9 @@ func Retry(f func(interface{}) error, args interface{}, count int, timeout time. return err } -func InsertOnInstance(server string, connOpts tarantool.Opts, space interface{}, - tuple interface{}) error { - conn, err := tarantool.Connect(server, connOpts) +func InsertOnInstance(ctx context.Context, server string, connOpts tarantool.Opts, + space interface{}, tuple interface{}) error { + conn, err := tarantool.Connect(ctx, server, connOpts) if err != nil { return fmt.Errorf("fail to connect to %s: %s", server, err.Error()) } @@ -182,7 +183,10 @@ func InsertOnInstances(servers []string, connOpts tarantool.Opts, space interfac } for _, server := range servers { - err := InsertOnInstance(server, connOpts, space, tuple) + ctx, cancel := context.WithTimeout(context.Background(), + 500*time.Millisecond) + err := InsertOnInstance(ctx, server, connOpts, space, tuple) + cancel() if err != nil { return err } @@ -191,8 +195,9 @@ func InsertOnInstances(servers []string, connOpts tarantool.Opts, space interfac return nil } -func SetInstanceRO(server string, connOpts tarantool.Opts, isReplica bool) error { - conn, err := tarantool.Connect(server, connOpts) +func SetInstanceRO(ctx context.Context, server string, connOpts tarantool.Opts, + isReplica bool) error { + conn, err := tarantool.Connect(ctx, server, connOpts) if err != nil { return err } @@ -214,7 +219,10 @@ func SetClusterRO(servers []string, connOpts tarantool.Opts, roles []bool) error } for i, server := range servers { - err := SetInstanceRO(server, connOpts, roles[i]) + ctx, cancel := context.WithTimeout(context.Background(), + 500*time.Millisecond) + err := SetInstanceRO(ctx, server, connOpts, roles[i]) + cancel() if err != nil { return err } diff --git a/test_helpers/utils.go b/test_helpers/utils.go index 3771a5f9e..0da1e38a4 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -1,6 +1,7 @@ package test_helpers import ( + "context" "fmt" "testing" "time" @@ -17,7 +18,9 @@ func ConnectWithValidation(t testing.TB, opts tarantool.Opts) *tarantool.Connection { t.Helper() - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { t.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/uuid/example_test.go b/uuid/example_test.go index 632f620be..08bd64aae 100644 --- a/uuid/example_test.go +++ b/uuid/example_test.go @@ -9,8 +9,10 @@ package uuid_test import ( + "context" "fmt" "log" + "time" "github.com/google/uuid" "github.com/tarantool/go-tarantool/v2" @@ -25,7 +27,9 @@ func Example() { User: "test", Pass: "test", } - client, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) }