Skip to content

Commit

Permalink
x/crypto/ssh: interpret disconnect message as error in the transport …
Browse files Browse the repository at this point in the history
…layer.

This ensures that higher level parts (e.g. the client authentication
loop) never have to deal with disconnect messages.

Fixes coreos/fleet#565.

Change-Id: Ie164b6c4b0982c7ed9af6d3bf91697a78a911a20
Reviewed-on: https://go-review.googlesource.com/20801
Reviewed-by: Anton Khramov <[email protected]>
Reviewed-by: Adam Langley <[email protected]>
  • Loading branch information
desdeel2d0m authored and agl committed Mar 29, 2016
1 parent 91c9575 commit e85dbb4
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 57 deletions.
2 changes: 0 additions & 2 deletions ssh/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
return false, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
Expand Down
43 changes: 43 additions & 0 deletions ssh/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"net"
"reflect"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -413,3 +414,45 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {

wg.Wait()
}

func TestDisconnect(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}

defer trC.Close()
defer trS.Close()

trC.writePacket([]byte{msgRequestSuccess, 0, 0})
errMsg := &disconnectMsg{
Reason: 42,
Message: "such is life",
}
trC.writePacket(Marshal(errMsg))
trC.writePacket([]byte{msgRequestSuccess, 0, 0})

packet, err := trS.readPacket()
if err != nil {
t.Fatalf("readPacket 1: %v", err)
}
if packet[0] != msgRequestSuccess {
t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
}

_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 2 succeeded")
} else if !reflect.DeepEqual(err, errMsg) {
t.Errorf("got error %#v, want %#v", err, errMsg)
}

_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 3 succeeded")
}
}
2 changes: 1 addition & 1 deletion ssh/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type disconnectMsg struct {
}

func (d *disconnectMsg) Error() string {
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
}

// See RFC 4253, section 7.1.
Expand Down
26 changes: 0 additions & 26 deletions ssh/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,6 @@ func (m *mux) ackRequest(ok bool, data []byte) error {
return m.sendMessage(globalRequestFailureMsg{Data: data})
}

// TODO(hanwen): Disconnect is a transport layer message. We should
// probably send and receive Disconnect somewhere in the transport
// code.

// Disconnect sends a disconnect message.
func (m *mux) Disconnect(reason uint32, message string) error {
return m.sendMessage(disconnectMsg{
Reason: reason,
Message: message,
})
}

func (m *mux) Close() error {
return m.conn.Close()
}
Expand Down Expand Up @@ -239,8 +227,6 @@ func (m *mux) onePacket() error {
case msgNewKeys:
// Ignore notification of key change.
return nil
case msgDisconnect:
return m.handleDisconnect(packet)
case msgChannelOpen:
return m.handleChannelOpen(packet)
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
Expand All @@ -260,18 +246,6 @@ func (m *mux) onePacket() error {
return ch.handlePacket(packet)
}

func (m *mux) handleDisconnect(packet []byte) error {
var d disconnectMsg
if err := Unmarshal(packet, &d); err != nil {
return err
}

if debugMux {
log.Printf("caught disconnect: %v", d)
}
return &d
}

func (m *mux) handleGlobalPacket(packet []byte) error {
msg, err := decode(packet)
if err != nil {
Expand Down
23 changes: 0 additions & 23 deletions ssh/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ func TestMuxGlobalRequest(t *testing.T) {
ok, data, err)
}

clientMux.Disconnect(0, "")
if !seen {
t.Errorf("never saw 'peek' request")
}
Expand Down Expand Up @@ -378,28 +377,6 @@ func TestMuxChannelRequestUnblock(t *testing.T) {
}
}

func TestMuxDisconnect(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()

go func() {
for r := range b.incomingRequests {
r.Reply(true, nil)
}
}()

a.Disconnect(42, "whatever")
ok, _, err := a.SendRequest("hello", true, nil)
if ok || err == nil {
t.Errorf("got reply after disconnecting")
}
err = b.Wait()
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
}
}

func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
Expand Down
25 changes: 20 additions & 5 deletions ssh/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,27 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
err = errors.New("ssh: zero length packet")
}

if len(packet) > 0 && packet[0] == msgNewKeys {
select {
case cipher := <-s.pendingKeyChange:
if len(packet) > 0 {
switch packet[0] {
case msgNewKeys:
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}

case msgDisconnect:
// Transform a disconnect message into an
// error. Since this is lowest level at which
// we interpret message types, doing it here
// ensures that we don't have to handle it
// elsewhere.
var msg disconnectMsg
if err := Unmarshal(packet, &msg); err != nil {
return nil, err
}
return nil, &msg
}
}

Expand Down

0 comments on commit e85dbb4

Please sign in to comment.