Skip to content

Commit

Permalink
api: add context to connection create
Browse files Browse the repository at this point in the history
`connection.Connect` and `pool.Connect` no longer return non-working
connection objects. Those functions now accept context as their first
arguments, which user may cancel in process.

`connection.Connect` will block until either the working connection
created (and returned), `opts.MaxReconnects` creation attempts
were made (returns error) or the context is canceled by user
(returns error too).

Closes #136
  • Loading branch information
DerekBum committed Oct 6, 2023
1 parent d8df65d commit 48559f3
Show file tree
Hide file tree
Showing 29 changed files with 776 additions and 237 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
decoded to a varbinary object (#313).
- Use objects of the Decimal type instead of pointers (#238)
- Use objects of the Datetime type instead of pointers (#238)
- `connection.Connect` and `pool.Connect` no longer return non-working
connection objects (#136). Those functions now accept context as their first
arguments, which user may cancel in process.

### Deprecated

Expand Down
21 changes: 18 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,19 @@ about what it does.
package tarantool

import (
"context"
"fmt"
"time"

"github.com/tarantool/go-tarantool/v2"
)

func main() {
opts := tarantool.Opts{User: "guest"}
conn, err := tarantool.Connect("127.0.0.1:3301", opts)
ctx, cancel := context.WithTimeout(context.Background(),
500 * time.Millisecond)
defer cancel()
conn, err := tarantool.Connect(ctx, "127.0.0.1:3301", opts)
if err != nil {
fmt.Println("Connection refused:", err)
}
Expand Down Expand Up @@ -139,6 +145,10 @@ starting a session. There are two parameters:
* a string with `host:port` format, and
* the option structure that was set up earlier.

There will be only one attempt to connect. If multiple attempts needed,
"`tarantool.Connect`" could be placed inside the loop with some timeout
between each try. Example could be found in [example_test](./example_test.go).

**Observation 4:** The `err` structure will be `nil` if there is no error,
otherwise it will have a description which can be retrieved with `err.Error()`.

Expand Down Expand Up @@ -167,10 +177,15 @@ The subpackage has been deleted. You could use `pool` instead.

#### pool package

The logic has not changed, but there are a few renames:

* The `connection_pool` subpackage has been renamed to `pool`.
* The type `PoolOpts` has been renamed to `Opts`.
* `pool.Connect` no longer return non-working connection objects. This function
now accept context as first argument, which user may cancel in process.

#### connection package

`connection.Connect` no longer return non-working connection objects. This
function now accept context as first argument, which user may cancel in process.

#### msgpack.v5

Expand Down
122 changes: 63 additions & 59 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,9 @@ func (opts Opts) Clone() Opts {
//
// Notes:
//
// - If opts.Reconnect is zero (default), then connection either already connected
// or error is returned.
//
// - If opts.Reconnect is non-zero, then error will be returned only if authorization
// fails. But if Tarantool is not reachable, then it will make an attempt to reconnect later
// and will not finish to make attempts on authorization failures.
func Connect(addr string, opts Opts) (conn *Connection, err error) {
// - There will be only one attempt to connect. If multiple attempts needed,
// Connect could be placed inside the loop with some timeout between each try.
func Connect(ctx context.Context, addr string, opts Opts) (conn *Connection, err error) {
conn = &Connection{
addr: addr,
requestId: 0,
Expand Down Expand Up @@ -432,25 +428,8 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {

conn.cond = sync.NewCond(&conn.mutex)

if err = conn.createConnection(false); err != nil {
ter, ok := err.(Error)
if conn.opts.Reconnect <= 0 {
return nil, err
} else if ok && (ter.Code == iproto.ER_NO_SUCH_USER ||
ter.Code == iproto.ER_CREDS_MISMATCH) {
// Reported auth errors immediately.
return nil, err
} else {
// Without SkipSchema it is useless.
go func(conn *Connection) {
conn.mutex.Lock()
defer conn.mutex.Unlock()
if err := conn.createConnection(true); err != nil {
conn.closeConnection(err, true)
}
}(conn)
err = nil
}
if err = conn.createConnection(ctx); err != nil {
return nil, err
}

go conn.pinger()
Expand Down Expand Up @@ -534,18 +513,11 @@ func (conn *Connection) cancelFuture(fut *Future, err error) {
}
}

func (conn *Connection) dial() (err error) {
func (conn *Connection) dial(ctx context.Context) (err error) {
opts := conn.opts
dialTimeout := opts.Reconnect / 2
if dialTimeout == 0 {
dialTimeout = 500 * time.Millisecond
} else if dialTimeout > 5*time.Second {
dialTimeout = 5 * time.Second
}

var c Conn
c, err = conn.opts.Dialer.Dial(conn.addr, DialOpts{
DialTimeout: dialTimeout,
c, err = conn.opts.Dialer.Dial(ctx, conn.addr, DialOpts{
IoTimeout: opts.Timeout,
Transport: opts.Transport,
Ssl: opts.Ssl,
Expand Down Expand Up @@ -658,34 +630,19 @@ func pack(h *smallWBuf, enc *msgpack.Encoder, reqid uint32,
return
}

func (conn *Connection) createConnection(reconnect bool) (err error) {
var reconnects uint
for conn.c == nil && conn.state == connDisconnected {
now := time.Now()
err = conn.dial()
if err == nil || !reconnect {
if err == nil {
conn.notify(Connected)
}
return
}
if conn.opts.MaxReconnects > 0 && reconnects > conn.opts.MaxReconnects {
conn.opts.Logger.Report(LogLastReconnectFailed, conn, err)
err = ClientError{ErrConnectionClosed, "last reconnect failed"}
// mark connection as closed to avoid reopening by another goroutine
return
func (conn *Connection) createConnection(ctx context.Context) error {
var err error
if conn.c == nil && conn.state == connDisconnected {
err = conn.dial(ctx)
if err == nil {
conn.notify(Connected)
return nil
}
conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err)
conn.notify(ReconnectFailed)
reconnects++
conn.mutex.Unlock()
time.Sleep(time.Until(now.Add(conn.opts.Reconnect)))
conn.mutex.Lock()
}
if conn.state == connClosed {
err = ClientError{ErrConnectionClosed, "using closed connection"}
}
return
return err
}

func (conn *Connection) closeConnection(neterr error, forever bool) (err error) {
Expand Down Expand Up @@ -727,11 +684,58 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
return
}

func (conn *Connection) getDialTimeout() time.Duration {
dialTimeout := conn.opts.Reconnect / 2
if dialTimeout == 0 {
dialTimeout = 500 * time.Millisecond
} else if dialTimeout > 5*time.Second {
dialTimeout = 5 * time.Second
}
return dialTimeout
}

func (conn *Connection) runReconnects() error {
dialTimeout := conn.getDialTimeout()
var reconnects uint
var err error

for reconnects <= conn.opts.MaxReconnects {
now := time.Now()

ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
err = conn.createConnection(ctx)
cancel()

if err == nil {
return nil
}
if clientErr, ok := err.(ClientError); ok &&
clientErr.Code == ErrConnectionClosed {
return err
}

conn.opts.Logger.Report(LogReconnectFailed, conn, reconnects, err)
conn.notify(ReconnectFailed)
if conn.opts.MaxReconnects != 0 {
reconnects++
}
conn.mutex.Unlock()

time.Sleep(time.Until(now.Add(conn.opts.Reconnect)))

conn.mutex.Lock()
}

conn.opts.Logger.Report(LogLastReconnectFailed, conn, err)
// mark connection as closed to avoid reopening by another goroutine
return ClientError{ErrConnectionClosed, "last reconnect failed"}
}

func (conn *Connection) reconnectImpl(neterr error, c Conn) {
if conn.opts.Reconnect > 0 {
if c == conn.c {
conn.closeConnection(neterr, false)
if err := conn.createConnection(true); err != nil {
if err := conn.runReconnects(); err != nil {
conn.closeConnection(err, true)
}
}
Expand Down
5 changes: 4 additions & 1 deletion crud/example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crud_test

import (
"context"
"fmt"
"reflect"
"time"
Expand All @@ -21,7 +22,9 @@ var exampleOpts = tarantool.Opts{
}

func exampleConnect() *tarantool.Connection {
conn, err := tarantool.Connect(exampleServer, exampleOpts)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
conn, err := tarantool.Connect(ctx, exampleServer, exampleOpts)
if err != nil {
panic("Connection is not established: " + err.Error())
}
Expand Down
6 changes: 5 additions & 1 deletion crud/tarantool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crud_test

import (
"context"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -108,7 +109,10 @@ var object = crud.MapObject{

func connect(t testing.TB) *tarantool.Connection {
for i := 0; i < 10; i++ {
conn, err := tarantool.Connect(server, opts)
ctx, cancel := context.WithTimeout(context.Background(),
500*time.Millisecond)
conn, err := tarantool.Connect(ctx, server, opts)
cancel()
if err != nil {
t.Fatalf("Failed to connect: %s", err)
}
Expand Down
5 changes: 4 additions & 1 deletion datetime/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package datetime_test

import (
"context"
"fmt"
"time"

Expand All @@ -23,7 +24,9 @@ func Example() {
User: "test",
Pass: "test",
}
conn, err := tarantool.Connect("127.0.0.1:3013", opts)
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
conn, err := tarantool.Connect(ctx, "127.0.0.1:3013", opts)
if err != nil {
fmt.Printf("Error in connect is %v", err)
return
Expand Down
5 changes: 4 additions & 1 deletion decimal/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package decimal_test

import (
"context"
"log"
"time"

Expand All @@ -28,7 +29,9 @@ func Example() {
User: "test",
Pass: "test",
}
client, err := tarantool.Connect(server, opts)
ctx, cancel := context.WithTimeout(context.Background(), opts.Reconnect/2)
client, err := tarantool.Connect(ctx, server, opts)
cancel()
if err != nil {
log.Fatalf("Failed to connect: %s", err.Error())
}
Expand Down
16 changes: 8 additions & 8 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tarantool
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -56,8 +57,6 @@ type Conn interface {

// DialOpts is a way to configure a Dial method to create a new Conn.
type DialOpts struct {
// DialTimeout is a timeout for an initial network dial.
DialTimeout time.Duration
// IoTimeout is a timeout per a network read/write.
IoTimeout time.Duration
// Transport is a connect transport type.
Expand Down Expand Up @@ -86,7 +85,7 @@ type DialOpts struct {
type Dialer interface {
// Dial connects to a Tarantool instance to the address with specified
// options.
Dial(address string, opts DialOpts) (Conn, error)
Dial(ctx context.Context, address string, opts DialOpts) (Conn, error)
}

type tntConn struct {
Expand All @@ -104,11 +103,11 @@ type TtDialer struct {

// Dial connects to a Tarantool instance to the address with specified
// options.
func (t TtDialer) Dial(address string, opts DialOpts) (Conn, error) {
func (t TtDialer) Dial(ctx context.Context, address string, opts DialOpts) (Conn, error) {
var err error
conn := new(tntConn)

if conn.net, err = dial(address, opts); err != nil {
if conn.net, err = dial(ctx, address, opts); err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}

Expand Down Expand Up @@ -199,13 +198,14 @@ func (c *tntConn) ProtocolInfo() ProtocolInfo {
}

// dial connects to a Tarantool instance.
func dial(address string, opts DialOpts) (net.Conn, error) {
func dial(ctx context.Context, address string, opts DialOpts) (net.Conn, error) {
network, address := parseAddress(address)
switch opts.Transport {
case dialTransportNone:
return net.DialTimeout(network, address, opts.DialTimeout)
dialer := net.Dialer{}
return dialer.DialContext(ctx, network, address)
case dialTransportSsl:
return sslDialTimeout(network, address, opts.DialTimeout, opts.Ssl)
return sslDialContext(ctx, network, address, opts.Ssl)
default:
return nil, fmt.Errorf("unsupported transport type: %s", opts.Transport)
}
Expand Down
Loading

0 comments on commit 48559f3

Please sign in to comment.