diff --git a/CHANGELOG.md b/CHANGELOG.md index 211ea25a2..07e976bba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,12 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. decoded to a varbinary object (#313). - Use objects of the Decimal type instead of pointers (#238) - Use objects of the Datetime type instead of pointers (#238) +- `connection.Connect` no longer return non-working + connection objects (#136). This function now does not attempt to reconnect + and tries to establish a connection only once based on the context object. + Context accepted as first argument, and user may cancel it in process. + `pool.Connect` and `pool.Add` now accept context as first argument, which + user may cancel in process. ### Deprecated diff --git a/README.md b/README.md index aa4c6deac..5d4fcc139 100644 --- a/README.md +++ b/README.md @@ -105,13 +105,19 @@ about what it does. package tarantool import ( + "context" "fmt" + "time" + "github.com/tarantool/go-tarantool/v2" ) func main() { opts := tarantool.Opts{User: "guest"} - conn, err := tarantool.Connect("127.0.0.1:3301", opts) + ctx, cancel := context.WithTimeout(context.Background(), + 500 * time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts) if err != nil { fmt.Println("Connection refused:", err) } @@ -134,11 +140,17 @@ username. The structure may also contain other settings, see more in [documentation][godoc-opts-url] for the "`Opts`" structure. **Observation 3:** The line containing "`tarantool.Connect`" is essential for -starting a session. There are two parameters: +starting a session. There are three parameters: -* a string with `host:port` format, and +* a context, +* a string with `host:port` format, * the option structure that was set up earlier. +There will be only one attempt to connect. If multiple attempts needed, +"`tarantool.Connect`" could be placed inside the loop with some timeout +between each try. Example could be found in the [example_test](./example_test.go), +name - `ExampleConnect_reconnects`. + **Observation 4:** The `err` structure will be `nil` if there is no error, otherwise it will have a description which can be retrieved with `err.Error()`. @@ -167,10 +179,12 @@ The subpackage has been deleted. You could use `pool` instead. #### pool package -The logic has not changed, but there are a few renames: - * The `connection_pool` subpackage has been renamed to `pool`. * The type `PoolOpts` has been renamed to `Opts`. +* `pool.Connect` now accepts context as first argument, which user may cancel + in process. +* `pool.Add` now accepts context as first argument, which user may cancel in + process. #### msgpack.v5 @@ -212,6 +226,13 @@ IPROTO constants have been moved to a separate package [go-iproto](https://githu * The method `Code() uint32` replaced by the `Type() iproto.Type`. +#### Connect function + +`connection.Connect` no longer return non-working connection objects. This function +now does not attempt to reconnect and tries to establish a connection only once +based on the context object. Context accepted as first argument, and user may +cancel it in process. + ## Contributing See [the contributing guide](CONTRIBUTING.md) for detailed instructions on how diff --git a/connection.go b/connection.go index 9bb42626a..f304a12ba 100644 --- a/connection.go +++ b/connection.go @@ -375,16 +375,7 @@ func (opts Opts) Clone() Opts { // - Unix socket, first '/' or '.' indicates Unix socket // (unix:///abs/path/tnt.sock, unix:path/tnt.sock, /abs/path/tnt.sock, // ./rel/path/tnt.sock, unix/:path/tnt.sock) -// -// Notes: -// -// - If opts.Reconnect is zero (default), then connection either already connected -// or error is returned. -// -// - If opts.Reconnect is non-zero, then error will be returned only if authorization -// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later -// and will not finish to make attempts on authorization failures. -func Connect(addr string, opts Opts) (conn *Connection, err error) { +func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) { conn = &Connection{ addr: addr, requestId: 0, @@ -432,25 +423,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { conn.cond = sync.NewCond(&conn.mutex) - if err = conn.createConnection(false); err != nil { - ter, ok := err.(Error) - if conn.opts.Reconnect <= 0 { - return nil, err - } else if ok && (ter.Code == iproto.ER_NO_SUCH_USER || - ter.Code == iproto.ER_CREDS_MISMATCH) { - // Reported auth errors immediately. - return nil, err - } else { - // Without SkipSchema it is useless. - go func(conn *Connection) { - conn.mutex.Lock() - defer conn.mutex.Unlock() - if err := conn.createConnection(true); err != nil { - conn.closeConnection(err, true) - } - }(conn) - err = nil - } + if err = conn.createConnection(ctx); err != nil { + return nil, err } go conn.pinger() @@ -534,18 +508,11 @@ func (conn *Connection) cancelFuture(fut *Future, err error) { } } -func (conn *Connection) dial() (err error) { +func (conn *Connection) dial(ctx context.Context) error { opts := conn.opts - dialTimeout := opts.Reconnect / 2 - if dialTimeout == 0 { - dialTimeout = 500 * time.Millisecond - } else if dialTimeout > 5*time.Second { - dialTimeout = 5 * time.Second - } var c Conn - c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{ - DialTimeout: dialTimeout, + c, err := conn.opts.Dialer.Dial(ctx, conn.addr, DialOpts{ IoTimeout: opts.Timeout, Transport: opts.Transport, Ssl: opts.Ssl, @@ -555,7 +522,7 @@ func (conn *Connection) dial() (err error) { Password: opts.Pass, }) if err != nil { - return + return err } conn.Greeting.Version = c.Greeting().Version @@ -605,7 +572,7 @@ func (conn *Connection) dial() (err error) { conn.shutdownWatcher = watcher } - return + return nil } func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, @@ -658,34 +625,18 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32, return } -func (conn *Connection) createConnection(reconnect bool) (err error) { - var reconnects uint - for conn.c == nil && conn.state == connDisconnected { - now := time.Now() - err = conn.dial() - if err == nil || !reconnect { - if err == nil { - conn.notify(Connected) - } - return - } - if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects { - conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) - err = ClientError{ErrConnectionClosed, "last reconnect failed"} - // mark connection as closed to avoid reopening by another goroutine - return +func (conn *Connection) createConnection(ctx context.Context) error { + var err error + if conn.c == nil && conn.state == connDisconnected { + if err = conn.dial(ctx); err == nil { + conn.notify(Connected) + return nil } - conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) - conn.notify(ReconnectFailed) - reconnects++ - conn.mutex.Unlock() - time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) - conn.mutex.Lock() } if conn.state == connClosed { err = ClientError{ErrConnectionClosed, "using closed connection"} } - return + return err } func (conn *Connection) closeConnection(neterr error, forever bool) (err error) { @@ -727,11 +678,57 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) return } +func (conn *Connection) getDialTimeout() time.Duration { + dialTimeout := conn.opts.Reconnect / 2 + if dialTimeout == 0 { + dialTimeout = 500 * time.Millisecond + } else if dialTimeout > 5*time.Second { + dialTimeout = 5 * time.Second + } + return dialTimeout +} + +func (conn *Connection) runReconnects() error { + dialTimeout := conn.getDialTimeout() + var reconnects uint + var err error + + for conn.opts.MaxReconnects == 0 || reconnects <= conn.opts.MaxReconnects { + now := time.Now() + + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + err = conn.createConnection(ctx) + cancel() + + if err != nil { + if clientErr, ok := err.(ClientError); ok && + clientErr.Code == ErrConnectionClosed { + return err + } + } else { + return nil + } + + conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err) + conn.notify(ReconnectFailed) + reconnects++ + conn.mutex.Unlock() + + time.Sleep(time.Until(now.Add(conn.opts.Reconnect))) + + conn.mutex.Lock() + } + + conn.opts.Logger.Report(LogLastReconnectFailed, conn, err) + // mark connection as closed to avoid reopening by another goroutine + return ClientError{ErrConnectionClosed, "last reconnect failed"} +} + func (conn *Connection) reconnectImpl(neterr error, c Conn) { if conn.opts.Reconnect > 0 { if c == conn.c { conn.closeConnection(neterr, false) - if err := conn.createConnection(true); err != nil { + if err := conn.runReconnects(); err != nil { conn.closeConnection(err, true) } } diff --git a/crud/example_test.go b/crud/example_test.go index 363d0570d..79e9f42a1 100644 --- a/crud/example_test.go +++ b/crud/example_test.go @@ -1,6 +1,7 @@ package crud_test import ( + "context" "fmt" "reflect" "time" @@ -21,7 +22,9 @@ var exampleOpts = tarantool.Opts{ } func exampleConnect() *tarantool.Connection { - conn, err := tarantool.Connect(exampleServer, exampleOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, exampleServer, exampleOpts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/crud/tarantool_test.go b/crud/tarantool_test.go index 5cf29f66a..17519faad 100644 --- a/crud/tarantool_test.go +++ b/crud/tarantool_test.go @@ -108,7 +108,9 @@ var object = crud.MapObject{ func connect(t testing.TB) *tarantool.Connection { for i := 0; i < 10; i++ { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := test_helpers.GetConnectContext() + conn, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { t.Fatalf("Failed to connect: %s", err) } diff --git a/datetime/example_test.go b/datetime/example_test.go index 346551629..954f43548 100644 --- a/datetime/example_test.go +++ b/datetime/example_test.go @@ -9,6 +9,7 @@ package datetime_test import ( + "context" "fmt" "time" @@ -23,7 +24,9 @@ func Example() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { fmt.Printf("Error in connect is %v", err) return diff --git a/decimal/example_test.go b/decimal/example_test.go index a355767f1..f1984283c 100644 --- a/decimal/example_test.go +++ b/decimal/example_test.go @@ -9,6 +9,7 @@ package decimal_test import ( + "context" "log" "time" @@ -22,13 +23,13 @@ import ( func Example() { server := "127.0.0.1:3013" opts := tarantool.Opts{ - Timeout: 5 * time.Second, - Reconnect: 1 * time.Second, - MaxReconnects: 3, - User: "test", - Pass: "test", + Timeout: 5 * time.Second, + User: "test", + Pass: "test", } - client, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/dial.go b/dial.go index 3ba493ac7..5b17c0534 100644 --- a/dial.go +++ b/dial.go @@ -3,6 +3,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -56,8 +57,6 @@ type Conn interface { // DialOpts is a way to configure a Dial method to create a new Conn. type DialOpts struct { - // DialTimeout is a timeout for an initial network dial. - DialTimeout time.Duration // IoTimeout is a timeout per a network read/write. IoTimeout time.Duration // Transport is a connect transport type. @@ -86,7 +85,7 @@ type DialOpts struct { type Dialer interface { // Dial connects to a Tarantool instance to the address with specified // options. - Dial(address string, opts DialOpts) (Conn, error) + Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) } type tntConn struct { @@ -104,11 +103,11 @@ type TtDialer struct { // Dial connects to a Tarantool instance to the address with specified // options. -func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) { +func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) { var err error conn := new(tntConn) - if conn.net, err = dial(address, opts); err != nil { + if conn.net, err = dial(ctx, address, opts); err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } @@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo { } // dial connects to a Tarantool instance. -func dial(address string, opts DialOpts) (net.Conn, error) { +func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) { network, address := parseAddress(address) switch opts.Transport { case dialTransportNone: - return net.DialTimeout(network, address, opts.DialTimeout) + dialer := net.Dialer{} + return dialer.DialContext(ctx, network, address) case dialTransportSsl: - return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl) + return sslDialContext(ctx, network, address, opts.Ssl) default: return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport) } diff --git a/dial_test.go b/dial_test.go index ff8a50aab..acd6737c5 100644 --- a/dial_test.go +++ b/dial_test.go @@ -2,8 +2,11 @@ package tarantool_test import ( "bytes" + "context" "errors" + "fmt" "net" + "strings" "sync" "testing" "time" @@ -12,13 +15,14 @@ import ( "github.com/stretchr/testify/require" "github.com/tarantool/go-tarantool/v2" + "github.com/tarantool/go-tarantool/v2/test_helpers" ) type mockErrorDialer struct { err error } -func (m mockErrorDialer) Dial(address string, +func (m mockErrorDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { return nil, m.err } @@ -29,7 +33,9 @@ func TestDialer_Dial_error(t *testing.T) { err: errors.New(errMsg), } - conn, err := tarantool.Connect("any", tarantool.Opts{ + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, "any", tarantool.Opts{ Dialer: dialer, }) assert.Nil(t, conn) @@ -37,23 +43,26 @@ func TestDialer_Dial_error(t *testing.T) { } type mockPassedDialer struct { + ctx context.Context address string opts tarantool.DialOpts } -func (m *mockPassedDialer) Dial(address string, +func (m *mockPassedDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.address = address m.opts = opts + if ctx != m.ctx { + return nil, errors.New("wrong context") + } return nil, errors.New("does not matter") } func TestDialer_Dial_passedOpts(t *testing.T) { const addr = "127.0.0.1:8080" opts := tarantool.DialOpts{ - DialTimeout: 500 * time.Millisecond, - IoTimeout: 2, - Transport: "any", + IoTimeout: 2, + Transport: "any", Ssl: tarantool.SslOpts{ KeyFile: "a", CertFile: "b", @@ -73,7 +82,12 @@ func TestDialer_Dial_passedOpts(t *testing.T) { } dialer := &mockPassedDialer{} - conn, err := tarantool.Connect(addr, tarantool.Opts{ + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + + dialer.ctx = ctx + + conn, err := tarantool.Connect(ctx, addr, tarantool.Opts{ Dialer: dialer, Timeout: opts.IoTimeout, Transport: opts.Transport, @@ -86,6 +100,7 @@ func TestDialer_Dial_passedOpts(t *testing.T) { assert.Nil(t, conn) assert.NotNil(t, err) + assert.NotEqual(t, err.Error(), "wrong context") assert.Equal(t, addr, dialer.address) assert.Equal(t, opts, dialer.opts) } @@ -187,7 +202,7 @@ func newMockIoConn() *mockIoConn { return conn } -func (m *mockIoDialer) Dial(address string, +func (m *mockIoDialer) Dial(ctx context.Context, address string, opts tarantool.DialOpts) (tarantool.Conn, error) { m.conn = newMockIoConn() if m.init != nil { @@ -203,11 +218,14 @@ func dialIo(t *testing.T, dialer := mockIoDialer{ init: init, } - conn, err := tarantool.Connect("any", tarantool.Opts{ - Dialer: &dialer, - Timeout: 1000 * time.Second, // Avoid pings. - SkipSchema: true, - }) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, "any", + tarantool.Opts{ + Dialer: &dialer, + Timeout: 1000 * time.Second, // Avoid pings. + SkipSchema: true, + }) require.Nil(t, err) require.NotNil(t, conn) @@ -338,3 +356,19 @@ func TestConn_ReadWrite(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, resp) } + +func TestConn_ContextCancel(t *testing.T) { + const addr = "127.0.0.1:8080" + + dialer := tarantool.TtDialer{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + conn, err := dialer.Dial(ctx, addr, tarantool.DialOpts{}) + + assert.Nil(t, conn) + assert.NotNil(t, err) + assert.Truef(t, strings.Contains(err.Error(), "operation was canceled"), + fmt.Sprintf("unexpected error, expected to contain %s, got %v", + "operation was canceled", err)) +} diff --git a/example_custom_unpacking_test.go b/example_custom_unpacking_test.go index 1189e16a3..8fb243db0 100644 --- a/example_custom_unpacking_test.go +++ b/example_custom_unpacking_test.go @@ -1,6 +1,7 @@ package tarantool_test import ( + "context" "fmt" "log" "time" @@ -78,13 +79,13 @@ func Example_customUnpacking() { // Establish a connection. server := "127.0.0.1:3013" opts := tarantool.Opts{ - Timeout: 500 * time.Millisecond, - Reconnect: 1 * time.Second, - MaxReconnects: 3, - User: "test", - Pass: "test", + Timeout: 500 * time.Millisecond, + User: "test", + Pass: "test", } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, server, opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/example_test.go b/example_test.go index d871d578a..77b2cff24 100644 --- a/example_test.go +++ b/example_test.go @@ -19,7 +19,9 @@ type Tuple struct { } func exampleConnect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -38,7 +40,9 @@ func ExampleSslOpts() { CaFile: "testdata/ca.crt", }, } - _, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { panic("Connection is not established: " + err.Error()) } @@ -913,7 +917,10 @@ func ExampleFuture_GetIterator() { } func ExampleConnect() { - conn, err := tarantool.Connect("127.0.0.1:3013", tarantool.Opts{ + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", tarantool.Opts{ Timeout: 5 * time.Second, User: "test", Pass: "test", @@ -931,6 +938,40 @@ func ExampleConnect() { // Connection is ready } +func ExampleConnect_reconnects() { + opts := tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Concurrency: 32, + Reconnect: time.Second, + MaxReconnects: 10, + } + + var conn *tarantool.Connection + var err error + + for i := uint(0); i < opts.MaxReconnects; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err = tarantool.Connect(ctx, "127.0.0.1:3013", opts) + cancel() + if err == nil { + break + } + time.Sleep(opts.Reconnect) + } + if err != nil { + fmt.Println("No connection available") + return + } + defer conn.Close() + if conn != nil { + fmt.Println("Connection is ready") + } + // Output: + // Connection is ready +} + // Example demonstrates how to retrieve information with space schema. func ExampleSchema() { conn := exampleConnect(opts) @@ -1081,7 +1122,9 @@ func ExampleConnection_NewPrepared() { User: "test", Pass: "test", } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { fmt.Printf("Failed to connect: %s", err.Error()) } @@ -1127,7 +1170,9 @@ func ExampleConnection_NewWatcher() { Features: []tarantool.ProtocolFeature{tarantool.WatchersFeature}, }, } - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { fmt.Printf("Failed to connect: %s\n", err) return diff --git a/export_test.go b/export_test.go index 10d194840..fc5d90c34 100644 --- a/export_test.go +++ b/export_test.go @@ -1,15 +1,16 @@ package tarantool import ( + "context" "net" "time" "github.com/vmihailenco/msgpack/v5" ) -func SslDialTimeout(network, address string, timeout time.Duration, +func SslDialContext(ctx context.Context, network, address string, opts SslOpts) (connection net.Conn, err error) { - return sslDialTimeout(network, address, timeout, opts) + return sslDialContext(ctx, network, address, opts) } func SslCreateContext(opts SslOpts) (ctx interface{}, err error) { diff --git a/go.mod b/go.mod index bd848308c..22cb7aee3 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.7.1 github.com/tarantool/go-iproto v0.1.0 - github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a + github.com/tarantool/go-openssl v0.0.8-0.20231004103608-336ca939d2ca github.com/vmihailenco/msgpack/v5 v5.3.5 golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect diff --git a/go.sum b/go.sum index 1810c2b3a..44d00984f 100644 --- a/go.sum +++ b/go.sum @@ -21,8 +21,8 @@ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMT github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarantool/go-iproto v0.1.0 h1:zHN9AA8LDawT+JBD0/Nxgr/bIsWkkpDzpcMuaNPSIAQ= github.com/tarantool/go-iproto v0.1.0/go.mod h1:LNCtdyZxojUed8SbOiYHoc3v9NvaZTB7p96hUySMlIo= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a h1:eeElglRXJ3xWKkHmDbeXrQWlZyQ4t3Ca1YlZsrfdXFU= -github.com/tarantool/go-openssl v0.0.8-0.20230801114713-b452431f934a/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= +github.com/tarantool/go-openssl v0.0.8-0.20231004103608-336ca939d2ca h1:oOrBh73tDDyooIXajfr+0pfnM+89404ClAhJpTTHI7E= +github.com/tarantool/go-openssl v0.0.8-0.20231004103608-336ca939d2ca/go.mod h1:M7H4xYSbzqpW/ZRBMyH0eyqQBsnhAMfsYk5mv0yid7A= github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= diff --git a/pool/connection_pool.go b/pool/connection_pool.go index 26e2199e9..9e479681f 100644 --- a/pool/connection_pool.go +++ b/pool/connection_pool.go @@ -11,6 +11,7 @@ package pool import ( + "context" "errors" "log" "sync" @@ -33,6 +34,7 @@ var ( ErrClosed = errors.New("pool is closed") ErrUnknownRequest = errors.New("the passed connected request doesn't belong to " + "the current connection pool") + ErrContextCanceled = errors.New("operation was canceled") ) // ConnectionHandler provides callbacks for components interested in handling @@ -116,6 +118,7 @@ type endpoint struct { shutdown chan struct{} close chan struct{} closed chan struct{} + cancel context.CancelFunc closeErr error } @@ -128,12 +131,14 @@ func newEndpoint(addr string) *endpoint { shutdown: make(chan struct{}), close: make(chan struct{}), closed: make(chan struct{}), + cancel: nil, } } // ConnectWithOpts creates pool for instances with addresses addrs // with options opts. -func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { +func ConnectWithOpts(ctx context.Context, addrs []string, + connOpts tarantool.Opts, opts Opts) (*ConnectionPool, error) { if len(addrs) == 0 { return nil, ErrEmptyAddrs } @@ -161,16 +166,21 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne connPool.addrs[addr] = nil } - somebodyAlive := connPool.fillPools() + somebodyAlive, ctxCanceled := connPool.fillPools(ctx) if !somebodyAlive { connPool.state.set(closedState) + if ctxCanceled { + return nil, ErrContextCanceled + } return nil, ErrNoConnection } connPool.state.set(connectedState) for _, s := range connPool.addrs { - go connPool.controller(s) + endpointCtx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + go connPool.controller(endpointCtx, s) } return connPool, nil @@ -181,11 +191,12 @@ func ConnectWithOpts(addrs []string, connOpts tarantool.Opts, opts Opts) (*Conne // It is useless to set up tarantool.Opts.Reconnect value for a connection. // The connection pool has its own reconnection logic. See // Opts.CheckTimeout description. -func Connect(addrs []string, connOpts tarantool.Opts) (*ConnectionPool, error) { +func Connect(ctx context.Context, addrs []string, + connOpts tarantool.Opts) (*ConnectionPool, error) { opts := Opts{ CheckTimeout: 1 * time.Second, } - return ConnectWithOpts(addrs, connOpts, opts) + return ConnectWithOpts(ctx, addrs, connOpts, opts) } // ConnectedNow gets connected status of pool. @@ -224,9 +235,12 @@ func (p *ConnectionPool) ConfiguredTimeout(mode Mode) (time.Duration, error) { // Add adds a new endpoint with the address into the pool. This function // adds the endpoint only after successful connection. -func (p *ConnectionPool) Add(addr string) error { +func (p *ConnectionPool) Add(ctx context.Context, addr string) error { e := newEndpoint(addr) + endpointCtx, cancel := context.WithCancel(context.Background()) + e.cancel = cancel + p.addrsMutex.Lock() // Ensure that Close()/CloseGraceful() not in progress/done. if p.state.get() != connectedState { @@ -240,7 +254,7 @@ func (p *ConnectionPool) Add(addr string) error { p.addrs[addr] = e p.addrsMutex.Unlock() - if err := p.tryConnect(e); err != nil { + if err := p.tryConnect(ctx, e); err != nil { p.addrsMutex.Lock() delete(p.addrs, addr) p.addrsMutex.Unlock() @@ -248,7 +262,7 @@ func (p *ConnectionPool) Add(addr string) error { return err } - go p.controller(e) + go p.controller(endpointCtx, e) return nil } @@ -268,6 +282,7 @@ func (p *ConnectionPool) Remove(addr string) error { case <-endpoint.shutdown: // CloseGraceful()/Remove() in progress/done. default: + endpoint.cancel() close(endpoint.shutdown) } @@ -302,6 +317,7 @@ func (p *ConnectionPool) Close() []error { p.state.cas(shutdownState, closedState) { p.addrsMutex.RLock() for _, s := range p.addrs { + s.cancel() close(s.close) } p.addrsMutex.RUnlock() @@ -316,6 +332,7 @@ func (p *ConnectionPool) CloseGraceful() []error { if p.state.cas(connectedState, shutdownState) { p.addrsMutex.RLock() for _, s := range p.addrs { + s.cancel() close(s.shutdown) } p.addrsMutex.RUnlock() @@ -1109,8 +1126,48 @@ func (p *ConnectionPool) handlerDeactivated(conn *tarantool.Connection, } } -func (p *ConnectionPool) fillPools() bool { +func (p *ConnectionPool) deactivateConnection(addr string, + conn *tarantool.Connection, role Role) { + p.deleteConnection(addr) + conn.Close() + p.handlerDeactivated(conn, role) +} + +func (p *ConnectionPool) deactivateConnections() { + for address, endpoint := range p.addrs { + if endpoint != nil && endpoint.conn != nil { + p.deactivateConnection(address, endpoint.conn, endpoint.role) + } + } +} + +func (p *ConnectionPool) processConnection(conn *tarantool.Connection, + addr string, end *endpoint) bool { + role, err := p.getConnectionRole(conn) + if err != nil { + conn.Close() + log.Printf("tarantool: storing connection to %s failed: %s\n", addr, err) + return false + } + + if !p.handlerDiscovered(conn, role) { + conn.Close() + return false + } + if p.addConnection(addr, conn, role) != nil { + conn.Close() + p.handlerDeactivated(conn, role) + return false + } + + end.conn = conn + end.role = role + return true +} + +func (p *ConnectionPool) fillPools(ctx context.Context) (bool, bool) { somebodyAlive := false + ctxCanceled := false // It is called before controller() goroutines so we don't expect // concurrency issues here. @@ -1120,39 +1177,27 @@ func (p *ConnectionPool) fillPools() bool { connOpts := p.connOpts connOpts.Notify = end.notify - conn, err := tarantool.Connect(addr, connOpts) + conn, err := tarantool.Connect(ctx, addr, connOpts) if err != nil { log.Printf("tarantool: connect to %s failed: %s\n", addr, err.Error()) - } else if conn != nil { - role, err := p.getConnectionRole(conn) - if err != nil { - conn.Close() - log.Printf("tarantool: storing connection to %s failed: %s\n", addr, err) - continue - } + select { + case <-ctx.Done(): + ctxCanceled = true - if p.handlerDiscovered(conn, role) { - if p.addConnection(addr, conn, role) != nil { - conn.Close() - p.handlerDeactivated(conn, role) - } + p.addrs[addr] = nil + log.Printf("tarantool: operation was canceled") - if conn.ConnectedNow() { - end.conn = conn - end.role = role - somebodyAlive = true - } else { - p.deleteConnection(addr) - conn.Close() - p.handlerDeactivated(conn, role) - } - } else { - conn.Close() + p.deactivateConnections() + + return false, ctxCanceled + default: } + } else if p.processConnection(conn, addr, end) { + somebodyAlive = true } } - return somebodyAlive + return somebodyAlive, ctxCanceled } func (p *ConnectionPool) updateConnection(e *endpoint) { @@ -1213,7 +1258,7 @@ func (p *ConnectionPool) updateConnection(e *endpoint) { } } -func (p *ConnectionPool) tryConnect(e *endpoint) error { +func (p *ConnectionPool) tryConnect(ctx context.Context, e *endpoint) error { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1226,7 +1271,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { connOpts := p.connOpts connOpts.Notify = e.notify - conn, err := tarantool.Connect(e.addr, connOpts) + conn, err := tarantool.Connect(ctx, e.addr, connOpts) if err == nil { role, err := p.getConnectionRole(conn) p.poolsMutex.Unlock() @@ -1265,7 +1310,7 @@ func (p *ConnectionPool) tryConnect(e *endpoint) error { return err } -func (p *ConnectionPool) reconnect(e *endpoint) { +func (p *ConnectionPool) reconnect(ctx context.Context, e *endpoint) { p.poolsMutex.Lock() if p.state.get() != connectedState { @@ -1280,10 +1325,10 @@ func (p *ConnectionPool) reconnect(e *endpoint) { e.conn = nil e.role = UnknownRole - p.tryConnect(e) + p.tryConnect(ctx, e) } -func (p *ConnectionPool) controller(e *endpoint) { +func (p *ConnectionPool) controller(ctx context.Context, e *endpoint) { timer := time.NewTicker(p.opts.CheckTimeout) defer timer.Stop() @@ -1367,11 +1412,11 @@ func (p *ConnectionPool) controller(e *endpoint) { // Relocate connection between subpools // if ro/rw was updated. if e.conn == nil { - p.tryConnect(e) + p.tryConnect(ctx, e) } else if !e.conn.ClosedNow() { p.updateConnection(e) } else { - p.reconnect(e) + p.reconnect(ctx, e) } } } diff --git a/pool/connection_pool_test.go b/pool/connection_pool_test.go index dd0210d62..b944d0c6c 100644 --- a/pool/connection_pool_test.go +++ b/pool/connection_pool_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "log" "os" @@ -46,17 +47,23 @@ var defaultTimeoutRetry = 500 * time.Millisecond var instances []test_helpers.TarantoolInstance func TestConnError_IncorrectParams(t *testing.T) { - connPool, err := pool.Connect([]string{}, tarantool.Opts{}) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{}, tarantool.Opts{}) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "addrs (first argument) should not be empty", err.Error()) - connPool, err = pool.Connect([]string{"err1", "err2"}, connOpts) + ctx, cancel = test_helpers.GetPoolConnectContext() + connPool, err = pool.Connect(ctx, []string{"err1", "err2"}, connOpts) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "no active connections", err.Error()) - connPool, err = pool.ConnectWithOpts(servers, tarantool.Opts{}, pool.Opts{}) + ctx, cancel = test_helpers.GetPoolConnectContext() + connPool, err = pool.ConnectWithOpts(ctx, servers, tarantool.Opts{}, pool.Opts{}) + cancel() require.Nilf(t, connPool, "conn is not nil with incorrect param") require.NotNilf(t, err, "err is nil with incorrect params") require.Equal(t, "wrong check timeout, must be greater than 0", err.Error()) @@ -64,7 +71,9 @@ func TestConnError_IncorrectParams(t *testing.T) { func TestConnSuccessfully(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{"err", server}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{"err", server}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -84,9 +93,77 @@ func TestConnSuccessfully(t *testing.T) { require.Nil(t, err) } +func TestConnErrorAfterCtxCancel(t *testing.T) { + var connLongReconnectOpts = tarantool.Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var connPool *pool.ConnectionPool + var err error + + cancel() + connPool, err = pool.Connect(ctx, servers, connLongReconnectOpts) + + if connPool != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + +type mockClosingDialer struct { + cnt int + ctx context.Context + ctxCancel context.CancelFunc +} + +func (m *mockClosingDialer) Dial(ctx context.Context, address string, + opts tarantool.DialOpts) (tarantool.Conn, error) { + + dialer := tarantool.TtDialer{} + conn, err := dialer.Dial(m.ctx, address, tarantool.DialOpts{ + User: "test", + Password: "test", + }) + + if m.cnt == 0 { + m.ctxCancel() + } + m.cnt++ + + return conn, err +} + +func TestContextCancelInProgress(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dialer := &mockClosingDialer{0, ctx, cancel} + + connPool, err := pool.Connect(ctx, servers, tarantool.Opts{ + Dialer: dialer, + }) + require.NotNilf(t, err, "expected err after ctx cancel") + assert.Truef(t, strings.Contains(err.Error(), "operation was canceled"), + fmt.Sprintf("unexpected error, expected to contain %s, got %v", + "operation was canceled", err)) + require.Nilf(t, connPool, "conn is not nil after ctx cancel") +} + func TestConnSuccessfullyDuplicates(t *testing.T) { server := servers[0] - connPool, err := pool.Connect([]string{server, server, server, server}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{server, server, server, server}, + connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -112,7 +189,9 @@ func TestConnSuccessfullyDuplicates(t *testing.T) { func TestReconnect(t *testing.T) { server := servers[0] - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -158,7 +237,9 @@ func TestDisconnect_withReconnect(t *testing.T) { opts := connOpts opts.Reconnect = 10 * time.Second - connPool, err := pool.Connect([]string{servers[serverId]}, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[serverId]}, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -202,7 +283,9 @@ func TestDisconnectAll(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -249,14 +332,18 @@ func TestDisconnectAll(t *testing.T) { } func TestAdd(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() for _, server := range servers[1:] { - err = connPool.Add(server) + ctx, cancel := test_helpers.GetConnectContext() + err = connPool.Add(ctx, server) + cancel() require.Nil(t, err) } @@ -280,13 +367,17 @@ func TestAdd(t *testing.T) { } func TestAdd_exist(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add(servers[0]) + ctx, cancel = test_helpers.GetConnectContext() + err = connPool.Add(ctx, servers[0]) + cancel() require.Equal(t, pool.ErrExists, err) args := test_helpers.CheckStatusesArgs{ @@ -305,13 +396,17 @@ func TestAdd_exist(t *testing.T) { } func TestAdd_unreachable(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() - err = connPool.Add("127.0.0.2:6667") + ctx, cancel = test_helpers.GetConnectContext() + err = connPool.Add(ctx, "127.0.0.2:6667") + cancel() // The OS-dependent error so we just check for existence. require.NotNil(t, err) @@ -331,17 +426,23 @@ func TestAdd_unreachable(t *testing.T) { } func TestAdd_afterClose(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") connPool.Close() - err = connPool.Add(servers[0]) + ctx, cancel = test_helpers.GetConnectContext() + err = connPool.Add(ctx, servers[0]) + cancel() assert.Equal(t, err, pool.ErrClosed) } func TestAdd_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -350,7 +451,9 @@ func TestAdd_Close_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + ctx, cancel := test_helpers.GetConnectContext() + err = connPool.Add(ctx, servers[1]) + cancel() if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -362,7 +465,9 @@ func TestAdd_Close_concurrent(t *testing.T) { } func TestAdd_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + connPool, err := pool.Connect(ctx, []string{servers[0]}, connOpts) + cancel() require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -371,7 +476,9 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { go func() { defer wg.Done() - err = connPool.Add(servers[1]) + ctx, cancel := test_helpers.GetConnectContext() + err = connPool.Add(ctx, servers[1]) + cancel() if err != nil { assert.Equal(t, pool.ErrClosed, err) } @@ -383,7 +490,9 @@ func TestAdd_CloseGraceful_concurrent(t *testing.T) { } func TestRemove(t *testing.T) { - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -410,7 +519,9 @@ func TestRemove(t *testing.T) { } func TestRemove_double(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -437,7 +548,9 @@ func TestRemove_double(t *testing.T) { } func TestRemove_unknown(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -463,7 +576,9 @@ func TestRemove_unknown(t *testing.T) { } func TestRemove_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -510,7 +625,9 @@ func TestRemove_concurrent(t *testing.T) { } func TestRemove_Close_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -529,7 +646,9 @@ func TestRemove_Close_concurrent(t *testing.T) { } func TestRemove_CloseGraceful_concurrent(t *testing.T) { - connPool, err := pool.Connect([]string{servers[0], servers[1]}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{servers[0], servers[1]}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -551,7 +670,9 @@ func TestClose(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -591,7 +712,9 @@ func TestCloseGraceful(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -730,7 +853,9 @@ func TestConnectionHandlerOpenUpdateClose(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -804,7 +929,9 @@ func TestConnectionHandlerOpenError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) if err == nil { defer connPool.Close() } @@ -846,7 +973,9 @@ func TestConnectionHandlerUpdateError(t *testing.T) { CheckTimeout: 100 * time.Microsecond, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(poolServers, connOpts, poolOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, poolServers, connOpts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -891,7 +1020,9 @@ func TestRequestOnClosed(t *testing.T) { server1 := servers[0] server2 := servers[1] - connPool, err := pool.Connect([]string{server1, server2}, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, []string{server1, server2}, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -930,7 +1061,9 @@ func TestGetPoolInfo(t *testing.T) { srvs := []string{server1, server2} expected := []string{server1, server2} - connPool, err := pool.Connect(srvs, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, srvs, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -948,7 +1081,9 @@ func TestCall(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1005,7 +1140,9 @@ func TestCall16(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1062,7 +1199,9 @@ func TestCall17(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1119,7 +1258,9 @@ func TestEval(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1197,7 +1338,9 @@ func TestExecute(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1254,7 +1397,9 @@ func TestRoundRobinStrategy(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1331,7 +1476,9 @@ func TestRoundRobinStrategy_NoReplica(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1402,7 +1549,9 @@ func TestRoundRobinStrategy_NoMaster(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1485,7 +1634,9 @@ func TestUpdateInstancesRoles(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1629,7 +1780,9 @@ func TestInsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1728,7 +1881,9 @@ func TestDelete(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1792,7 +1947,9 @@ func TestUpsert(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1864,7 +2021,9 @@ func TestUpdate(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -1954,7 +2113,9 @@ func TestReplace(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2040,7 +2201,9 @@ func TestSelect(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2160,7 +2323,9 @@ func TestPing(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2198,7 +2363,9 @@ func TestDo(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2237,7 +2404,9 @@ func TestDo_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2268,7 +2437,9 @@ func TestNewPrepared(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2336,7 +2507,9 @@ func TestDoWithStrangerConn(t *testing.T) { err := test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") @@ -2365,7 +2538,9 @@ func TestStream_Commit(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2464,7 +2639,9 @@ func TestStream_Rollback(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2563,7 +2740,9 @@ func TestStream_TxnIsolationLevel(t *testing.T) { err = test_helpers.SetClusterRO(servers, connOpts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2657,7 +2836,9 @@ func TestConnectionPool_NewWatcher_noWatchersFeature(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2685,7 +2866,9 @@ func TestConnectionPool_NewWatcher_modes(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2767,7 +2950,9 @@ func TestConnectionPool_NewWatcher_update(t *testing.T) { poolOpts := pool.Opts{ CheckTimeout: 500 * time.Millisecond, } - pool, err := pool.ConnectWithOpts(servers, opts, poolOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + pool, err := pool.ConnectWithOpts(ctx, servers, opts, poolOpts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") @@ -2851,7 +3036,9 @@ func TestWatcher_Unregister(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - pool, err := pool.Connect(servers, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + pool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, pool, "conn is nil after Connect") defer pool.Close() @@ -2910,7 +3097,9 @@ func TestConnectionPool_NewWatcher_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() @@ -2950,7 +3139,9 @@ func TestWatcher_Unregister_concurrent(t *testing.T) { err := test_helpers.SetClusterRO(servers, opts, roles) require.Nilf(t, err, "fail to set roles for cluster") - connPool, err := pool.Connect(servers, opts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, opts) require.Nilf(t, err, "failed to connect") require.NotNilf(t, connPool, "conn is nil after Connect") defer connPool.Close() diff --git a/pool/example_test.go b/pool/example_test.go index 84a41ff7b..dae28de90 100644 --- a/pool/example_test.go +++ b/pool/example_test.go @@ -24,7 +24,9 @@ func examplePool(roles []bool, connOpts tarantool.Opts) (*pool.ConnectionPool, e if err != nil { return nil, fmt.Errorf("ConnectionPool is not established") } - connPool, err := pool.Connect(servers, connOpts) + ctx, cancel := test_helpers.GetPoolConnectContext() + defer cancel() + connPool, err := pool.Connect(ctx, servers, connOpts) if err != nil || connPool == nil { return nil, fmt.Errorf("ConnectionPool is not established") } diff --git a/queue/example_connection_pool_test.go b/queue/example_connection_pool_test.go index 51fb967a5..a3bfaf0a4 100644 --- a/queue/example_connection_pool_test.go +++ b/queue/example_connection_pool_test.go @@ -1,6 +1,7 @@ package queue_test import ( + "context" "fmt" "sync" "sync/atomic" @@ -164,7 +165,9 @@ func Example_connectionPool() { CheckTimeout: 5 * time.Second, ConnectionHandler: h, } - connPool, err := pool.ConnectWithOpts(servers, connOpts, poolOpts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + connPool, err := pool.ConnectWithOpts(ctx, servers, connOpts, poolOpts) if err != nil { fmt.Printf("Unable to connect to the pool: %s", err) return diff --git a/queue/example_msgpack_test.go b/queue/example_msgpack_test.go index 6fd101e09..6d3637417 100644 --- a/queue/example_msgpack_test.go +++ b/queue/example_msgpack_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "log" "time" @@ -55,7 +56,9 @@ func Example_simpleQueueCustomMsgPack() { User: "test", Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) + cancel() if err != nil { log.Fatalf("connection: %s", err) return diff --git a/queue/example_test.go b/queue/example_test.go index 711ee31d4..e81acca40 100644 --- a/queue/example_test.go +++ b/queue/example_test.go @@ -9,6 +9,7 @@ package queue_test import ( + "context" "fmt" "time" @@ -31,7 +32,9 @@ func Example_simpleQueue() { Pass: "test", } - conn, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) if err != nil { fmt.Printf("error in prepare is %v", err) return diff --git a/settings/example_test.go b/settings/example_test.go index b1d0e5d4f..29be33bfe 100644 --- a/settings/example_test.go +++ b/settings/example_test.go @@ -1,7 +1,9 @@ package settings_test import ( + "context" "fmt" + "time" "github.com/tarantool/go-tarantool/v2" "github.com/tarantool/go-tarantool/v2/settings" @@ -9,7 +11,9 @@ import ( ) func example_connect(opts tarantool.Opts) *tarantool.Connection { - conn, err := tarantool.Connect(server, opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { panic("Connection is not established: " + err.Error()) } diff --git a/shutdown_test.go b/shutdown_test.go index bb4cfa099..412d27ea4 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -460,9 +460,12 @@ func TestGracefulShutdownCloseConcurrent(t *testing.T) { go func(i int) { defer caseWg.Done() + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + // Do not wait till Tarantool register out watcher, // test everything is ok even on async. - conn, err := Connect(shtdnServer, shtdnClntOpts) + conn, err := Connect(ctx, shtdnServer, shtdnClntOpts) if err != nil { t.Errorf("Failed to connect: %s", err) } else { diff --git a/ssl.go b/ssl.go index a23238849..8ca430559 100644 --- a/ssl.go +++ b/ssl.go @@ -5,24 +5,24 @@ package tarantool import ( "bufio" + "context" "errors" "io/ioutil" "net" "os" "strings" - "time" "github.com/tarantool/go-openssl" ) -func sslDialTimeout(network, address string, timeout time.Duration, +func sslDialContext(ctx context.Context, network, address string, opts SslOpts) (connection net.Conn, err error) { - var ctx interface{} - if ctx, err = sslCreateContext(opts); err != nil { + var sslCtx interface{} + if sslCtx, err = sslCreateContext(opts); err != nil { return } - return openssl.DialTimeout(network, address, timeout, ctx.(*openssl.Ctx), 0) + return openssl.DialContext(ctx, network, address, sslCtx.(*openssl.Ctx), 0) } // interface{} is a hack. It helps to avoid dependency of go-openssl in build diff --git a/ssl_disable.go b/ssl_disable.go index 8d0ab406b..6a2aa2163 100644 --- a/ssl_disable.go +++ b/ssl_disable.go @@ -4,12 +4,12 @@ package tarantool import ( + "context" "errors" "net" - "time" ) -func sslDialTimeout(network, address string, timeout time.Duration, +func sslDialContext(ctx context.Context, network, address string, opts SslOpts) (connection net.Conn, err error) { return nil, errors.New("SSL support is disabled.") } diff --git a/ssl_test.go b/ssl_test.go index 30078703c..65be85504 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -4,6 +4,7 @@ package tarantool_test import ( + "context" "errors" "fmt" "io/ioutil" @@ -60,12 +61,12 @@ func serverSslRecv(msgs <-chan string, errs <-chan error) (string, error) { return <-msgs, <-errs } -func clientSsl(network, address string, opts SslOpts) (net.Conn, error) { - timeout := 5 * time.Second - return SslDialTimeout(network, address, timeout, opts) +func clientSsl(ctx context.Context, network, address string, + opts SslOpts) (net.Conn, error) { + return SslDialContext(ctx, network, address, opts) } -func createClientServerSsl(t testing.TB, serverOpts, +func createClientServerSsl(ctx context.Context, t testing.TB, serverOpts, clientOpts SslOpts) (net.Listener, net.Conn, <-chan string, <-chan error, error) { t.Helper() @@ -77,16 +78,16 @@ func createClientServerSsl(t testing.TB, serverOpts, msgs, errs := serverSslAccept(l) port := l.Addr().(*net.TCPAddr).Port - c, err := clientSsl("tcp", sslHost+":"+strconv.Itoa(port), clientOpts) + c, err := clientSsl(ctx, "tcp", sslHost+":"+strconv.Itoa(port), clientOpts) return l, c, msgs, errs, err } -func createClientServerSslOk(t testing.TB, serverOpts, +func createClientServerSslOk(ctx context.Context, t testing.TB, serverOpts, clientOpts SslOpts) (net.Listener, net.Conn, <-chan string, <-chan error) { t.Helper() - l, c, msgs, errs, err := createClientServerSsl(t, serverOpts, clientOpts) + l, c, msgs, errs, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) if err != nil { t.Fatalf("Unable to create client, error %q", err.Error()) } @@ -150,7 +151,9 @@ func serverTntStop(inst test_helpers.TarantoolInstance) { } func checkTntConn(clientOpts SslOpts) error { - conn, err := Connect(tntHost, Opts{ + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, tntHost, Opts{ Auth: AutoAuth, Timeout: 500 * time.Millisecond, User: "test", @@ -166,10 +169,11 @@ func checkTntConn(clientOpts SslOpts) error { return nil } -func assertConnectionSslFail(t testing.TB, serverOpts, clientOpts SslOpts) { +func assertConnectionSslFail(ctx context.Context, t testing.TB, serverOpts, + clientOpts SslOpts) { t.Helper() - l, c, _, _, err := createClientServerSsl(t, serverOpts, clientOpts) + l, c, _, _, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) l.Close() if err == nil { c.Close() @@ -177,10 +181,11 @@ func assertConnectionSslFail(t testing.TB, serverOpts, clientOpts SslOpts) { } } -func assertConnectionSslOk(t testing.TB, serverOpts, clientOpts SslOpts) { +func assertConnectionSslOk(ctx context.Context, t testing.TB, serverOpts, + clientOpts SslOpts) { t.Helper() - l, c, msgs, errs := createClientServerSslOk(t, serverOpts, clientOpts) + l, c, msgs, errs := createClientServerSslOk(ctx, t, serverOpts, clientOpts) const message = "any test string" c.Write([]byte(message)) c.Close() @@ -621,15 +626,19 @@ func TestSslOpts(t *testing.T) { isTntSsl := isTestTntSsl() for _, test := range tests { + var ctx context.Context + var cancel context.CancelFunc + ctx, cancel = test_helpers.GetConnectContext() if test.ok { t.Run("ok_ssl_"+test.name, func(t *testing.T) { - assertConnectionSslOk(t, test.serverOpts, test.clientOpts) + assertConnectionSslOk(ctx, t, test.serverOpts, test.clientOpts) }) } else { t.Run("fail_ssl_"+test.name, func(t *testing.T) { - assertConnectionSslFail(t, test.serverOpts, test.clientOpts) + assertConnectionSslFail(ctx, t, test.serverOpts, test.clientOpts) }) } + cancel() if !isTntSsl { continue } @@ -645,6 +654,35 @@ func TestSslOpts(t *testing.T) { } } +func TestSslDialContextCancel(t *testing.T) { + serverOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + clientOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + CaFile: "testdata/ca.crt", + Ciphers: "ECDHE-RSA-AES256-GCM-SHA384", + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + l, c, _, _, err := createClientServerSsl(ctx, t, serverOpts, clientOpts) + l.Close() + + if err == nil { + c.Close() + t.Fatalf("Expected error, dial was not canceled") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + func TestOpts_PapSha256Auth(t *testing.T) { isTntSsl := isTestTntSsl() if !isTntSsl { diff --git a/tarantool_test.go b/tarantool_test.go index 6339164f1..fe6368a66 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -708,7 +708,9 @@ func TestTtDialer(t *testing.T) { assert := assert.New(t) require := require.New(t) - conn, err := TtDialer{}.Dial(server, DialOpts{}) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := TtDialer{}.Dial(ctx, server, DialOpts{}) require.Nil(err) require.NotNil(conn) defer conn.Close() @@ -778,7 +780,9 @@ func TestOptsAuth_PapSha256AuthForbit(t *testing.T) { papSha256Opts := opts papSha256Opts.Auth = PapSha256Auth - conn, err := Connect(server, papSha256Opts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, papSha256Opts) if err == nil { t.Error("An error expected.") conn.Close() @@ -3408,7 +3412,9 @@ func TestConnectionProtocolVersionRequirementSuccess(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, err, "No errors on connect") require.NotNilf(t, conn, "Connect success") @@ -3424,7 +3430,9 @@ func TestConnectionProtocolVersionRequirementFail(t *testing.T) { Version: ProtocolVersion(3), } - conn, err := Connect(server, connOpts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3439,7 +3447,9 @@ func TestConnectionProtocolFeatureRequirementSuccess(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.NotNilf(t, conn, "Connect success") require.Nilf(t, err, "No errors on connect") @@ -3455,7 +3465,9 @@ func TestConnectionProtocolFeatureRequirementFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature}, } - conn, err := Connect(server, connOpts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -3471,7 +3483,9 @@ func TestConnectionProtocolFeatureRequirementManyFail(t *testing.T) { Features: []ProtocolFeature{TransactionsFeature, ProtocolFeature(15532)}, } - conn, err := Connect(server, connOpts) + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + conn, err := Connect(ctx, server, connOpts) require.Nilf(t, conn, "Connect fail") require.NotNilf(t, err, "Got error on connect") @@ -4003,7 +4017,9 @@ func TestConnect_schema_update(t *testing.T) { for i := 0; i < 100; i++ { fut := conn.Do(NewCallRequest("create_spaces")) - if conn, err := Connect(server, opts); err != nil { + ctx, cancel := test_helpers.GetConnectContext() + defer cancel() + if conn, err := Connect(ctx, server, opts); err != nil { if err.Error() != "concurrent schema update" { t.Errorf("unexpected error: %s", err) } @@ -4019,6 +4035,32 @@ func TestConnect_schema_update(t *testing.T) { } } +func TestConnect_context_cancel(t *testing.T) { + var connLongReconnectOpts = Opts{ + Timeout: 5 * time.Second, + User: "test", + Pass: "test", + Reconnect: time.Second, + MaxReconnects: 100, + } + + ctx, cancel := context.WithCancel(context.Background()) + + var conn *Connection + var err error + + cancel() + conn, err = Connect(ctx, server, connLongReconnectOpts) + + if conn != nil || err == nil { + t.Fatalf("ConnectionPool was created after cancel") + } + if !strings.Contains(err.Error(), "operation was canceled") { + t.Fatalf("Unexpected error, expected to contain %s, got %v", + "operation was canceled", err) + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/test_helpers/main.go b/test_helpers/main.go index 894ebb653..cc806a7d2 100644 --- a/test_helpers/main.go +++ b/test_helpers/main.go @@ -11,6 +11,7 @@ package test_helpers import ( + "context" "errors" "fmt" "io" @@ -97,7 +98,9 @@ func isReady(server string, opts *tarantool.Opts) error { var conn *tarantool.Connection var resp *tarantool.Response - conn, err = tarantool.Connect(server, *opts) + ctx, cancel := GetConnectContext() + defer cancel() + conn, err = tarantool.Connect(ctx, server, *opts) if err != nil { return err } @@ -402,3 +405,7 @@ func ConvertUint64(v interface{}) (result uint64, err error) { } return } + +func GetConnectContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 500*time.Millisecond) +} diff --git a/test_helpers/pool_helper.go b/test_helpers/pool_helper.go index c44df2f6a..b2340ccb8 100644 --- a/test_helpers/pool_helper.go +++ b/test_helpers/pool_helper.go @@ -1,6 +1,7 @@ package test_helpers import ( + "context" "fmt" "reflect" "time" @@ -130,9 +131,9 @@ func Retry(f func(interface{}) error, args interface{}, count int, timeout time. return err } -func InsertOnInstance(server string, connOpts tarantool.Opts, space interface{}, - tuple interface{}) error { - conn, err := tarantool.Connect(server, connOpts) +func InsertOnInstance(ctx context.Context, server string, connOpts tarantool.Opts, + space interface{}, tuple interface{}) error { + conn, err := tarantool.Connect(ctx, server, connOpts) if err != nil { return fmt.Errorf("fail to connect to %s: %s", server, err.Error()) } @@ -182,7 +183,9 @@ func InsertOnInstances(servers []string, connOpts tarantool.Opts, space interfac } for _, server := range servers { - err := InsertOnInstance(server, connOpts, space, tuple) + ctx, cancel := GetConnectContext() + err := InsertOnInstance(ctx, server, connOpts, space, tuple) + cancel() if err != nil { return err } @@ -191,8 +194,9 @@ func InsertOnInstances(servers []string, connOpts tarantool.Opts, space interfac return nil } -func SetInstanceRO(server string, connOpts tarantool.Opts, isReplica bool) error { - conn, err := tarantool.Connect(server, connOpts) +func SetInstanceRO(ctx context.Context, server string, connOpts tarantool.Opts, + isReplica bool) error { + conn, err := tarantool.Connect(ctx, server, connOpts) if err != nil { return err } @@ -214,7 +218,9 @@ func SetClusterRO(servers []string, connOpts tarantool.Opts, roles []bool) error } for i, server := range servers { - err := SetInstanceRO(server, connOpts, roles[i]) + ctx, cancel := GetConnectContext() + err := SetInstanceRO(ctx, server, connOpts, roles[i]) + cancel() if err != nil { return err } @@ -257,3 +263,7 @@ func StopTarantoolInstances(instances []TarantoolInstance) { StopTarantoolWithCleanup(instance) } } + +func GetPoolConnectContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 500*time.Millisecond) +} diff --git a/test_helpers/utils.go b/test_helpers/utils.go index 3771a5f9e..898ae84e3 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -17,7 +17,9 @@ func ConnectWithValidation(t testing.TB, opts tarantool.Opts) *tarantool.Connection { t.Helper() - conn, err := tarantool.Connect(server, opts) + ctx, cancel := GetConnectContext() + defer cancel() + conn, err := tarantool.Connect(ctx, server, opts) if err != nil { t.Fatalf("Failed to connect: %s", err.Error()) } diff --git a/uuid/example_test.go b/uuid/example_test.go index 632f620be..08bd64aae 100644 --- a/uuid/example_test.go +++ b/uuid/example_test.go @@ -9,8 +9,10 @@ package uuid_test import ( + "context" "fmt" "log" + "time" "github.com/google/uuid" "github.com/tarantool/go-tarantool/v2" @@ -25,7 +27,9 @@ func Example() { User: "test", Pass: "test", } - client, err := tarantool.Connect("127.0.0.1:3013", opts) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + client, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts) + cancel() if err != nil { log.Fatalf("Failed to connect: %s", err.Error()) }