Skip to content

Commit

Permalink
dial: add DialContext function
Browse files Browse the repository at this point in the history
In order to replace timeouts with contexts in `Connect` instance
creation (go-tarantool), I need a `DialContext` function.
It accepts context, and cancels, if context is canceled by user.

Part of tarantool/go-tarantool#136
  • Loading branch information
DerekBum committed Oct 2, 2023
1 parent b452431 commit b3ae863
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 33 deletions.
125 changes: 95 additions & 30 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package openssl

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -89,8 +90,55 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
// parameters.
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
flags DialFlags) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
return dialSession(d, network, addr, ctx, flags, nil)
host, err := parseHost(addr)
if err != nil {
return nil, err
}

conn, err := net.DialTimeout(network, addr, timeout)
if err != nil {
return nil, err
}
ctx, err = prepareCtx(ctx)
if err != nil {
conn.Close()
return nil, err
}
client, err := createSession(conn, flags, host, ctx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialContext acts like Dial but takes a context for network dial.
//
// The context includes only network dial. It does not include OpenSSL calls.
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialContext(context context.Context, network, addr string,
ctx *Ctx, flags DialFlags) (*Conn, error) {
host, err := parseHost(addr)
if err != nil {
return nil, err
}

dialer := net.Dialer{}
conn, err := dialer.DialContext(context, network, addr)
if err != nil {
return nil, err
}
ctx, err = prepareCtx(ctx)
if err != nil {
conn.Close()
return nil, err
}
client, err := createSession(conn, flags, host, ctx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialSession will connect to network/address and then wrap the corresponding
Expand All @@ -108,59 +156,76 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
// can be retrieved from the GetSession method on the Conn.
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
var d net.Dialer
return dialSession(d, network, addr, ctx, flags, session)
}

func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
host, err := parseHost(addr)
if err != nil {
return nil, err
}
if ctx == nil {
var err error
ctx, err = NewCtx()
if err != nil {
return nil, err
}
// TODO: use operating system default certificate chain?
}

c, err := d.Dial(network, addr)
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
conn, err := Client(c, ctx)
ctx, err = prepareCtx(ctx)
if err != nil {
c.Close()
conn.Close()
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
c.Close()
return nil, err
}
client, err := createSession(conn, flags, host, ctx, session)
if err != nil {
conn.Close()
}
return client, err
}

func prepareCtx(ctx *Ctx) (*Ctx, error) {
if ctx == nil {
return NewCtx()
}
return ctx, nil
}

func parseHost(addr string) (string, error) {
host, _, err := net.SplitHostPort(addr)
return host, err
}

func handshake(conn *Conn, host string, flags DialFlags) error {
var err error
if flags&DisableSNI == 0 {
err = conn.SetTlsExtHostName(host)
if err != nil {
conn.Close()
return nil, err
return err
}
}
err = conn.Handshake()
if err != nil {
conn.Close()
return nil, err
return err
}
if flags&InsecureSkipHostVerification == 0 {
err = conn.VerifyHostname(host)
if err != nil {
return err
}
}
return nil
}

func createSession(c net.Conn, flags DialFlags, host string, ctx *Ctx,
session []byte) (*Conn, error) {
conn, err := Client(c, ctx)
if err != nil {
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
conn.Close()
return nil, err
}
}
if err := handshake(conn, host, flags); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
101 changes: 101 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package openssl_test

import (
"context"
"crypto/rand"
"io"
"net"
"sync"
"testing"
"time"

"github.com/tarantool/go-openssl"
)

func sslConnect(t *testing.T, ssl_listener net.Listener) {
for {
var err error
conn, err := ssl_listener.Accept()
if err != nil {
t.Errorf("failed accept: %s", err)
continue
}
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
break
}
}

func TestDial(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
sslConnect(t, ssl_listener)
wg.Done()
}()

client, err := openssl.Dial(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, openssl.InsecureSkipHostVerification)

wg.Wait()

if err != nil {
t.Fatalf("unexpected err: %v", err)
}
n, err := io.Copy(io.Discard, io.LimitReader(client, 1024))
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if n != 1024 {
if n == 0 {
t.Fatal("client is closed after creation")
}
t.Fatalf("client lost some bytes, expected %d, got %d", 1024, n)
}
}

func TestDialTimeout(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

client, err := openssl.DialTimeout(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}

func TestDialContext(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
client, err := openssl.DialContext(cancelCtx, ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}
6 changes: 3 additions & 3 deletions ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) {
})
}

func getCtx(t *testing.T) *Ctx {
func GetCtx(t *testing.T) *Ctx {
ctx, err := NewCtx()
if err != nil {
t.Fatal(err)
Expand All @@ -761,7 +761,7 @@ func getCtx(t *testing.T) *Ctx {
}

func TestOpenSSLLotsOfConns(t *testing.T) {
ctx := getCtx(t)
ctx := GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -928,7 +928,7 @@ func TestOpenSSLLotsOfConnsWithFail(t *testing.T) {
t.Run(name, func(t *testing.T) {
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
func(l net.Listener) net.Listener {
return NewListener(l, getCtx(t))
return NewListener(l, GetCtx(t))
}, func(c net.Conn) (net.Conn, error) {
return Client(c, getClientCtx(t))
})
Expand Down

0 comments on commit b3ae863

Please sign in to comment.