Skip to content

Commit

Permalink
dial: add DialContext function
Browse files Browse the repository at this point in the history
In order to replace timeouts with contexts in `Connect` instance
creation (go-tarantool), I need a `DialContext` function.
It accepts context, and cancels, if context is canceled by user.

Part of tarantool/go-tarantool#136
  • Loading branch information
DerekBum committed Sep 29, 2023
1 parent b452431 commit db9ed8f
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 25 deletions.
102 changes: 77 additions & 25 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package openssl

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -90,7 +91,37 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
flags DialFlags) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
return dialSession(d, network, addr, ctx, flags, nil)
ctx, err := prepareCtx(ctx)
if err != nil {
return nil, err
}
conn, err := createConnection(context.Background(), d, network, addr)
if err != nil {
return nil, err
}
defer conn.Close()
return createSession(conn, flags, addr, ctx, nil)
}

// DialContext acts like Dial but takes a context for network dial.
//
// The context includes only network dial. It does not include OpenSSL calls.
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialContext(context context.Context, network, addr string,
ctx *Ctx, flags DialFlags) (*Conn, error) {
d := net.Dialer{}
ctx, err := prepareCtx(ctx)
if err != nil {
return nil, err
}
conn, err := createConnection(context, d, network, addr)
if err != nil {
return nil, err
}
defer conn.Close()
return createSession(conn, flags, addr, ctx, nil)
}

// DialSession will connect to network/address and then wrap the corresponding
Expand All @@ -109,15 +140,19 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
var d net.Dialer
return dialSession(d, network, addr, ctx, flags, session)
}

func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
ctx, err := prepareCtx(ctx)
if err != nil {
return nil, err
}
conn, err := createConnection(context.Background(), d, network, addr)
if err != nil {
return nil, err
}
defer conn.Close()
return createSession(conn, flags, addr, ctx, session)
}

func prepareCtx(ctx *Ctx) (*Ctx, error) {
if ctx == nil {
var err error
ctx, err = NewCtx()
Expand All @@ -126,41 +161,58 @@ func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
}
// TODO: use operating system default certificate chain?
}
return ctx, nil
}

c, err := d.Dial(network, addr)
func createConnection(context context.Context, d net.Dialer, network,
addr string) (net.Conn, error) {
c, err := d.DialContext(context, network, addr)
if err != nil {
return nil, err
}
conn, err := Client(c, ctx)
if err != nil {
c.Close()
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
c.Close()
return nil, err
}
}
return c, nil
}

func applyFlags(conn *Conn, host string, flags DialFlags) error {
var err error
if flags&DisableSNI == 0 {
err = conn.SetTlsExtHostName(host)
if err != nil {
conn.Close()
return nil, err
return err
}
}
err = conn.Handshake()
if err != nil {
conn.Close()
return nil, err
return err
}
if flags&InsecureSkipHostVerification == 0 {
err = conn.VerifyHostname(host)
if err != nil {
conn.Close()
return err
}
}
return nil
}

func createSession(c net.Conn, flags DialFlags, addr string, ctx *Ctx,
session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
conn, err := Client(c, ctx)
if err != nil {
return nil, err
}
defer conn.Close()
if session != nil {
err := conn.setSession(session)
if err != nil {
return nil, err
}
}
if err := applyFlags(conn, host, flags); err != nil {
return nil, err
}
return conn, nil
}
36 changes: 36 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package openssl

import (
"context"
"net"
"testing"
"time"
)

func TestDialTimeout(t *testing.T) {
tcp_listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
ctx, _ := NewCtx()
client, err := DialTimeout(tcp_listener.Addr().Network(),
tcp_listener.Addr().String(), time.Nanosecond, ctx, 0)
if client != nil || err == nil {
t.Fatalf("expected error")
}
}

func TestDialContext(t *testing.T) {
tcp_listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
cancelCtx, cancel := context.WithCancel(context.Background())
ctx, _ := NewCtx()
cancel()
client, err := DialContext(cancelCtx, tcp_listener.Addr().Network(),
tcp_listener.Addr().String(), ctx, 0)
if client != nil || err == nil {
t.Fatalf("expected error")
}
}

0 comments on commit db9ed8f

Please sign in to comment.