diff --git a/server/conn.go b/server/conn.go index 0f541376f8781..d4608f9c48adb 100644 --- a/server/conn.go +++ b/server/conn.go @@ -171,6 +171,41 @@ func (cc *clientConn) String() string { ) } +// authSwitchRequest is used when the client asked to speak something +// other than mysql_native_password. The server is allowed to ask +// the client to switch, so lets ask for mysql_native_password +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest +func (cc *clientConn) authSwitchRequest(ctx context.Context) ([]byte, error) { + enclen := 1 + len("mysql_native_password") + 1 + len(cc.salt) + 1 + data := cc.alloc.AllocWithLen(4, enclen) + data = append(data, 0xfe) // switch request + data = append(data, []byte("mysql_native_password")...) + data = append(data, byte(0x00)) // requires null + data = append(data, cc.salt...) + data = append(data, 0) + err := cc.writePacket(data) + if err != nil { + logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) + return nil, err + } + resp, err := cc.readPacket() + if err != nil { + err = errors.SuspendStack(err) + if errors.Cause(err) == io.EOF { + logutil.Logger(ctx).Warn("authSwitchRequest response fail due to connection has be closed by client-side") + } else { + logutil.Logger(ctx).Warn("authSwitchRequest response fail", zap.Error(err)) + } + return nil, err + } + return resp, nil +} + // handshake works like TCP handshake, but in a higher level, it first writes initial packet to client, // during handshake, client and server negotiate compatible features and do authentication. // After handshake, client can send sql query to server. @@ -317,6 +352,7 @@ type handshakeResponse41 struct { User string DBName string Auth []byte + AuthPlugin string Attrs map[string]string } @@ -438,14 +474,18 @@ func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41 if len(data[offset:]) > 0 { idx := bytes.IndexByte(data[offset:], 0) packet.DBName = string(data[offset : offset+idx]) - offset = offset + idx + 1 + offset += idx + 1 } } if packet.Capability&mysql.ClientPluginAuth > 0 { - // TODO: Support mysql.ClientPluginAuth, skip it now idx := bytes.IndexByte(data[offset:], 0) - offset = offset + idx + 1 + s := offset + f := offset + idx + if s < f { // handle unexpected bad packets + packet.AuthPlugin = string(data[s:f]) + } + offset += idx + 1 } if packet.Capability&mysql.ClientConnectAtts > 0 { @@ -564,6 +604,14 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } + // switching from other methods should work, but not tested + if resp.AuthPlugin == "caching_sha2_password" { + resp.Auth, err = cc.authSwitchRequest(ctx) + if err != nil { + logutil.Logger(ctx).Warn("attempt to send auth switch request packet failed", zap.Error(err)) + return err + } + } cc.capability = resp.Capability & cc.server.capability cc.user = resp.User cc.dbname = resp.DBName diff --git a/server/conn_test.go b/server/conn_test.go index ab0a02905504f..0b65a6be9de28 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -163,6 +163,32 @@ func (ts *ConnTestSuite) TestIssue1768(c *C) { c.Assert(len(p.Auth) > 0, IsTrue) } +func (ts *ConnTestSuite) TestAuthSwitchRequest(c *C) { + c.Parallel() + // this data is from a MySQL 8.0 client + data := []byte{ + 0x85, 0xa6, 0xff, 0x1, 0x0, 0x0, 0x0, 0x1, 0x21, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x72, 0x6f, + 0x6f, 0x74, 0x0, 0x0, 0x63, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x68, 0x61, + 0x32, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x0, 0x79, 0x4, 0x5f, 0x70, + 0x69, 0x64, 0x5, 0x37, 0x37, 0x30, 0x38, 0x36, 0x9, 0x5f, 0x70, 0x6c, 0x61, 0x74, 0x66, + 0x6f, 0x72, 0x6d, 0x6, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x3, 0x5f, 0x6f, 0x73, 0x5, + 0x4c, 0x69, 0x6e, 0x75, 0x78, 0xc, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, + 0x61, 0x6d, 0x65, 0x8, 0x6c, 0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x7, 0x6f, 0x73, + 0x5f, 0x75, 0x73, 0x65, 0x72, 0xa, 0x6e, 0x75, 0x6c, 0x6c, 0x6e, 0x6f, 0x74, 0x6e, 0x69, + 0x6c, 0xf, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x6, 0x38, 0x2e, 0x30, 0x2e, 0x32, 0x31, 0xc, 0x70, 0x72, 0x6f, 0x67, 0x72, + 0x61, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x5, 0x6d, 0x79, 0x73, 0x71, 0x6c, + } + + var resp handshakeResponse41 + pos, err := parseHandshakeResponseHeader(context.Background(), &resp, data) + c.Assert(err, IsNil) + err = parseHandshakeResponseBody(context.Background(), &resp, data, pos) + c.Assert(err, IsNil) + c.Assert(resp.AuthPlugin == "caching_sha2_password", IsTrue) +} + func (ts *ConnTestSuite) TestInitialHandshake(c *C) { c.Parallel() var outBuffer bytes.Buffer