diff --git a/net.go b/net.go index b2293c7c..552f76ad 100644 --- a/net.go +++ b/net.go @@ -15,6 +15,7 @@ package openssl import ( + "context" "errors" "net" "time" @@ -89,8 +90,55 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { // parameters. func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx, flags DialFlags) (*Conn, error) { - d := net.Dialer{Timeout: timeout} - return dialSession(d, network, addr, ctx, flags, nil) + host, err := parseHost(addr) + if err != nil { + return nil, err + } + + conn, err := net.DialTimeout(network, addr, timeout) + if err != nil { + return nil, err + } + ctx, err = prepareCtx(ctx) + if err != nil { + conn.Close() + return nil, err + } + client, err := createSession(conn, flags, host, ctx, nil) + if err != nil { + conn.Close() + } + return client, err +} + +// DialContext acts like Dial but takes a context for network dial. +// +// The context includes only network dial. It does not include OpenSSL calls. +// +// See func Dial for a description of the network, addr, ctx and flags +// parameters. +func DialContext(context context.Context, network, addr string, + ctx *Ctx, flags DialFlags) (*Conn, error) { + host, err := parseHost(addr) + if err != nil { + return nil, err + } + + dialer := net.Dialer{} + conn, err := dialer.DialContext(context, network, addr) + if err != nil { + return nil, err + } + ctx, err = prepareCtx(ctx) + if err != nil { + conn.Close() + return nil, err + } + client, err := createSession(conn, flags, host, ctx, nil) + if err != nil { + conn.Close() + } + return client, err } // DialSession will connect to network/address and then wrap the corresponding @@ -108,59 +156,76 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx, // can be retrieved from the GetSession method on the Conn. func DialSession(network, addr string, ctx *Ctx, flags DialFlags, session []byte) (*Conn, error) { - var d net.Dialer - return dialSession(d, network, addr, ctx, flags, session) -} - -func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags, - session []byte) (*Conn, error) { - host, _, err := net.SplitHostPort(addr) + host, err := parseHost(addr) if err != nil { return nil, err } - if ctx == nil { - var err error - ctx, err = NewCtx() - if err != nil { - return nil, err - } - // TODO: use operating system default certificate chain? - } - c, err := d.Dial(network, addr) + conn, err := net.Dial(network, addr) if err != nil { return nil, err } - conn, err := Client(c, ctx) + ctx, err = prepareCtx(ctx) if err != nil { - c.Close() + conn.Close() return nil, err } - if session != nil { - err := conn.setSession(session) - if err != nil { - c.Close() - return nil, err - } + client, err := createSession(conn, flags, host, ctx, session) + if err != nil { + conn.Close() + } + return client, err +} + +func prepareCtx(ctx *Ctx) (*Ctx, error) { + if ctx == nil { + return NewCtx() } + return ctx, nil +} + +func parseHost(addr string) (string, error) { + host, _, err := net.SplitHostPort(addr) + return host, err +} + +func handshake(conn *Conn, host string, flags DialFlags) error { + var err error if flags&DisableSNI == 0 { err = conn.SetTlsExtHostName(host) if err != nil { - conn.Close() - return nil, err + return err } } err = conn.Handshake() if err != nil { - conn.Close() - return nil, err + return err } if flags&InsecureSkipHostVerification == 0 { err = conn.VerifyHostname(host) + if err != nil { + return err + } + } + return nil +} + +func createSession(c net.Conn, flags DialFlags, host string, ctx *Ctx, + session []byte) (*Conn, error) { + conn, err := Client(c, ctx) + if err != nil { + return nil, err + } + if session != nil { + err := conn.setSession(session) if err != nil { conn.Close() return nil, err } } + if err := handshake(conn, host, flags); err != nil { + conn.Close() + return nil, err + } return conn, nil } diff --git a/net_test.go b/net_test.go new file mode 100644 index 00000000..15bca3c1 --- /dev/null +++ b/net_test.go @@ -0,0 +1,98 @@ +package openssl_test + +import ( + "context" + "crypto/rand" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/tarantool/go-openssl" +) + +func sslConnect(t *testing.T, ssl_listener net.Listener) { + for { + var err error + conn, err := ssl_listener.Accept() + if err != nil { + t.Errorf("failed accept: %s", err) + continue + } + io.Copy(conn, io.LimitReader(rand.Reader, 1024)) + break + } +} + +func TestDial(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + sslConnect(t, ssl_listener) + wg.Done() + }() + + client, err := openssl.Dial(ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), ctx, openssl.InsecureSkipHostVerification) + + wg.Wait() + + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + n, err := io.Copy(io.Discard, io.LimitReader(client, 1024)) + if n != 1024 { + if n == 0 { + t.Fatal("client is closed after creation") + } + t.Fatalf("client lost some bytes, expected %d, got %d", 1024, n) + } +} + +func TestDialTimeout(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + client, err := openssl.DialTimeout(ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), time.Nanosecond, ctx, 0) + + if client != nil || err == nil { + t.Fatalf("expected error") + } +} + +func TestDialContext(t *testing.T) { + ctx := openssl.GetCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { + t.Fatal(err) + } + ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx) + if err != nil { + t.Fatal(err) + } + + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + client, err := openssl.DialContext(cancelCtx, ssl_listener.Addr().Network(), + ssl_listener.Addr().String(), ctx, 0) + + if client != nil || err == nil { + t.Fatalf("expected error") + } +} diff --git a/ssl_test.go b/ssl_test.go index b99e57ec..5eb0c514 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -738,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) { }) } -func getCtx(t *testing.T) *Ctx { +func GetCtx(t *testing.T) *Ctx { ctx, err := NewCtx() if err != nil { t.Fatal(err) @@ -761,7 +761,7 @@ func getCtx(t *testing.T) *Ctx { } func TestOpenSSLLotsOfConns(t *testing.T) { - ctx := getCtx(t) + ctx := GetCtx(t) if err := ctx.SetCipherList("AES128-SHA"); err != nil { t.Fatal(err) } @@ -928,7 +928,7 @@ func TestOpenSSLLotsOfConnsWithFail(t *testing.T) { t.Run(name, func(t *testing.T) { LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, func(l net.Listener) net.Listener { - return NewListener(l, getCtx(t)) + return NewListener(l, GetCtx(t)) }, func(c net.Conn) (net.Conn, error) { return Client(c, getClientCtx(t)) })