Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dial: add DialContext function #10

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions .github/actions/go-test-setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@ runs:
shell: bash
run: |
echo 'CGO_ENABLED=1' >> $GITHUB_ENV
- name: Windows setup
shell: bash
if: ${{ runner.os == 'Windows' }}
run: |
pacman -S --noconfirm mingw-w64-x86_64-toolchain mingw-w64-i686-toolchain
echo '/c/msys64/mingw64/bin' >> $GITHUB_PATH
echo 'PATH_386=/c/msys64/mingw32/bin:${{ env.PATH_386 }}' >> $GITHUB_ENV
- name: Linux setup
shell: bash
if: ${{ runner.os == 'Linux' }}
Expand Down
10 changes: 1 addition & 9 deletions .github/workflows/go-test-ubuntu-22.04.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ jobs:
run: |
go version
go env
- name: Use msys2 on windows
if: startsWith(matrix.os, 'windows')
shell: bash
# The executable for msys2 is also called bash.cmd
# https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells
# If we prepend its location to the PATH
# subsequent 'shell: bash' steps will use msys2 instead of gitbash
run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH
- name: Run repo-specific setup
uses: ./.github/actions/go-test-setup
if: hashFiles('./.github/actions/go-test-setup') != ''
Expand All @@ -55,7 +47,7 @@ jobs:
export "PATH=${{ env.PATH_386 }}:$PATH"
go test -v ./...
- name: Run tests with race detector
if: startsWith(matrix.os, 'ubuntu') # speed things up. Windows and OSX VMs are slow
if: startsWith(matrix.os, 'ubuntu') # speed things up. OSX VMs is slow
uses: protocol/[email protected]
with:
run: go test -v -race ./...
Expand Down
12 changes: 2 additions & 10 deletions .github/workflows/go-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ "ubuntu", "windows", "macos" ]
os: [ "ubuntu", "macos" ]
go: [ "1.18.x", "1.19.x" ]
env:
COVERAGES: ""
Expand All @@ -26,14 +26,6 @@ jobs:
run: |
go version
go env
- name: Use msys2 on windows
if: ${{ matrix.os == 'windows' }}
shell: bash
# The executable for msys2 is also called bash.cmd
# https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells
# If we prepend its location to the PATH
# subsequent 'shell: bash' steps will use msys2 instead of gitbash
run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH
- name: Run repo-specific setup
uses: ./.github/actions/go-test-setup
if: hashFiles('./.github/actions/go-test-setup') != ''
Expand All @@ -54,7 +46,7 @@ jobs:
export "PATH=${{ env.PATH_386 }}:$PATH"
go test -v -shuffle=on ./...
- name: Run tests with race detector
if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow
if: ${{ matrix.os == 'ubuntu' }} # speed things up. OSX VMs is slow
uses: protocol/[email protected]
with:
run: go test -v -race ./...
Expand Down
133 changes: 99 additions & 34 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package openssl

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -77,8 +78,8 @@ const (
// some certs to the certificate store of the client context you're using.
// This library is not nice enough to use the system certificate store by
// default for you yet.
func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
return DialSession(network, addr, ctx, flags, nil)
func Dial(network, addr string, sslCtx *Ctx, flags DialFlags) (*Conn, error) {
return DialSession(network, addr, sslCtx, flags, nil)
}

// DialTimeout acts like Dial but takes a timeout for network dial.
Expand All @@ -87,10 +88,57 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) {
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
func DialTimeout(network, addr string, timeout time.Duration, sslCtx *Ctx,
flags DialFlags) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
return dialSession(d, network, addr, ctx, flags, nil)
host, err := parseHost(addr)
if err != nil {
return nil, err
}

conn, err := net.DialTimeout(network, addr, timeout)
if err != nil {
return nil, err
}
sslCtx, err = prepareCtx(sslCtx)
if err != nil {
conn.Close()
return nil, err
}
client, err := createSession(conn, flags, host, sslCtx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialContext acts like Dial but takes a context for network dial.
//
// The context includes only network dial. It does not include OpenSSL calls.
//
// See func Dial for a description of the network, addr, ctx and flags
// parameters.
func DialContext(ctx context.Context, network, addr string,
sslCtx *Ctx, flags DialFlags) (*Conn, error) {
host, err := parseHost(addr)
if err != nil {
return nil, err
}

dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
sslCtx, err = prepareCtx(sslCtx)
if err != nil {
conn.Close()
return nil, err
}
client, err := createSession(conn, flags, host, sslCtx, nil)
if err != nil {
conn.Close()
}
return client, err
}

// DialSession will connect to network/address and then wrap the corresponding
Expand All @@ -106,61 +154,78 @@ func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx,
//
// If session is not nil it will be used to resume the tls state. The session
// can be retrieved from the GetSession method on the Conn.
func DialSession(network, addr string, ctx *Ctx, flags DialFlags,
func DialSession(network, addr string, sslCtx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
var d net.Dialer
return dialSession(d, network, addr, ctx, flags, session)
}

func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags,
session []byte) (*Conn, error) {
host, _, err := net.SplitHostPort(addr)
host, err := parseHost(addr)
if err != nil {
return nil, err
}
if ctx == nil {
var err error
ctx, err = NewCtx()
if err != nil {
return nil, err
}
// TODO: use operating system default certificate chain?
}

c, err := d.Dial(network, addr)
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
conn, err := Client(c, ctx)
sslCtx, err = prepareCtx(sslCtx)
if err != nil {
c.Close()
conn.Close()
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
c.Close()
return nil, err
}
client, err := createSession(conn, flags, host, sslCtx, session)
if err != nil {
conn.Close()
}
return client, err
}

func prepareCtx(sslCtx *Ctx) (*Ctx, error) {
if sslCtx == nil {
return NewCtx()
}
return sslCtx, nil
}

func parseHost(addr string) (string, error) {
host, _, err := net.SplitHostPort(addr)
return host, err
}

func handshake(conn *Conn, host string, flags DialFlags) error {
var err error
if flags&DisableSNI == 0 {
err = conn.SetTlsExtHostName(host)
if err != nil {
conn.Close()
return nil, err
return err
}
}
err = conn.Handshake()
if err != nil {
conn.Close()
return nil, err
return err
}
if flags&InsecureSkipHostVerification == 0 {
err = conn.VerifyHostname(host)
if err != nil {
return err
}
}
return nil
}

func createSession(c net.Conn, flags DialFlags, host string, sslCtx *Ctx,
session []byte) (*Conn, error) {
conn, err := Client(c, sslCtx)
if err != nil {
return nil, err
}
if session != nil {
err := conn.setSession(session)
if err != nil {
conn.Close()
return nil, err
}
}
if err := handshake(conn, host, flags); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
101 changes: 101 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package openssl_test

import (
"context"
"crypto/rand"
"io"
"net"
"sync"
"testing"
"time"

"github.com/tarantool/go-openssl"
)

func sslConnect(t *testing.T, ssl_listener net.Listener) {
for {
var err error
conn, err := ssl_listener.Accept()
if err != nil {
t.Errorf("failed accept: %s", err)
continue
}
io.Copy(conn, io.LimitReader(rand.Reader, 1024))
break
}
}

func TestDial(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
sslConnect(t, ssl_listener)
wg.Done()
}()

client, err := openssl.Dial(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, openssl.InsecureSkipHostVerification)

wg.Wait()

if err != nil {
t.Fatalf("unexpected err: %v", err)
}
n, err := io.Copy(io.Discard, io.LimitReader(client, 1024))
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if n != 1024 {
if n == 0 {
t.Fatal("client is closed after creation")
}
t.Fatalf("client lost some bytes, expected %d, got %d", 1024, n)
}
}

func TestDialTimeout(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

client, err := openssl.DialTimeout(ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), time.Nanosecond, ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}

func TestDialContext(t *testing.T) {
ctx := openssl.GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
ssl_listener, err := openssl.Listen("tcp", "localhost:0", ctx)
if err != nil {
t.Fatal(err)
}

cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
client, err := openssl.DialContext(cancelCtx, ssl_listener.Addr().Network(),
ssl_listener.Addr().String(), ctx, 0)

if client != nil || err == nil {
t.Fatalf("expected error")
}
}
6 changes: 3 additions & 3 deletions ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) {
})
}

func getCtx(t *testing.T) *Ctx {
func GetCtx(t *testing.T) *Ctx {
ctx, err := NewCtx()
if err != nil {
t.Fatal(err)
Expand All @@ -761,7 +761,7 @@ func getCtx(t *testing.T) *Ctx {
}

func TestOpenSSLLotsOfConns(t *testing.T) {
ctx := getCtx(t)
ctx := GetCtx(t)
if err := ctx.SetCipherList("AES128-SHA"); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -928,7 +928,7 @@ func TestOpenSSLLotsOfConnsWithFail(t *testing.T) {
t.Run(name, func(t *testing.T) {
LotsOfConns(t, 1024*64, 10, 100, 0*time.Second,
func(l net.Listener) net.Listener {
return NewListener(l, getCtx(t))
return NewListener(l, GetCtx(t))
}, func(c net.Conn) (net.Conn, error) {
return Client(c, getClientCtx(t))
})
Expand Down