Skip to content

Commit

Permalink
ssh: wrap errors from client handshake
Browse files Browse the repository at this point in the history
When an error is returned by a user defined host key callback,
it is now possible to handle it using standard Go mechanisms
such as errors.Is or errors.As.

Fixes golang/go#61309
  • Loading branch information
paxan committed Jul 12, 2023
1 parent 2e82bdd commit d2a34d5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan

if err := conn.clientHandshake(addr, &fullConf); err != nil {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
}
conn.mux = newMux(conn.transport)
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
Expand Down
25 changes: 23 additions & 2 deletions ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"net"
"strings"
"testing"
)
Expand Down Expand Up @@ -207,9 +210,12 @@ func TestBannerCallback(t *testing.T) {
}

func TestNewClientConn(t *testing.T) {
errHostKeyMismatch := errors.New("host key mismatch")

for _, tt := range []struct {
name string
user string
name string
user string
simulateHostKeyMismatch HostKeyCallback
}{
{
name: "good user field for ConnMetadata",
Expand All @@ -219,6 +225,13 @@ func TestNewClientConn(t *testing.T) {
name: "empty user field for ConnMetadata",
user: "",
},
{
name: "host key mismatch",
user: "testuser",
simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error {
return fmt.Errorf("%w: %s", errHostKeyMismatch, bytes.TrimSpace(MarshalAuthorizedKey(key)))
},
},
} {
t.Run(tt.name, func(t *testing.T) {
c1, c2, err := netPipe()
Expand All @@ -243,8 +256,16 @@ func TestNewClientConn(t *testing.T) {
},
HostKeyCallback: InsecureIgnoreHostKey(),
}

if tt.simulateHostKeyMismatch != nil {
clientConf.HostKeyCallback = tt.simulateHostKeyMismatch
}

clientConn, _, _, err := NewClientConn(c2, "", clientConf)
if err != nil {
if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) {
return
}
t.Fatal(err)
}

Expand Down

0 comments on commit d2a34d5

Please sign in to comment.