From 94b6652fbab537272587b4614983c3b8c16bb9d4 Mon Sep 17 00:00:00 2001 From: Randy Reddig Date: Tue, 20 Jun 2023 17:34:55 -0700 Subject: [PATCH] ssh: add (*Client).DialContext method Fixes golang/go#20288. --- ssh/tcpip.go | 38 +++++++++++++++++++++++++++++++++++--- ssh/tcpip_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 80d35f5ec1..86746534d1 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -5,6 +5,7 @@ package ssh import ( + "context" "errors" "fmt" "io" @@ -332,6 +333,37 @@ func (l *tcpListener) Addr() net.Addr { return l.laddr } +// DialContext initiates a connection to the addr from the remote host. +// If the supplied context is cancelled before the connection can be opened, +// ctx.Err() will be returned. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). func (c *Client) Dial(n, addr string) (net.Conn, error) { @@ -347,7 +379,7 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) { if err != nil { return nil, err } - ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + ch, err = c.dialTCP(net.IPv4zero.String(), 0, host, int(port)) if err != nil { return nil, err } @@ -393,7 +425,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) Port: 0, } } - ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + ch, err := c.dialTCP(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) if err != nil { return nil, err } @@ -412,7 +444,7 @@ type channelOpenDirectMsg struct { lport uint32 } -func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { +func (c *Client) dialTCP(laddr string, lport int, raddr string, rport int) (Channel, error) { msg := channelOpenDirectMsg{ raddr: raddr, rport: uint32(rport), diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index f1265cb496..8fafeb5e26 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -5,7 +5,10 @@ package ssh import ( + "context" + "net" "testing" + "time" ) func TestAutoPortListenBroken(t *testing.T) { @@ -18,3 +21,40 @@ func TestAutoPortListenBroken(t *testing.T) { t.Errorf("version %q marked as broken", works) } } + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} + +func TestClientDialContextWithTimeout(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +}