Skip to content

Commit

Permalink
Stash the connection attributes on the conn struct
Browse files Browse the repository at this point in the history
  • Loading branch information
jscheinblum committed Nov 2, 2023
1 parent 9ed7a3f commit db3262e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
2 changes: 2 additions & 0 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ type Conn struct {
// It is set during the initial handshake.
UserData Getter

ConnectionAttributes map[string]string

bufferedReader *bufio.Reader
flushTimer *time.Timer
header [packetHeaderSize]byte
Expand Down
42 changes: 24 additions & 18 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
return
}
user, clientAuthMethod, clientAuthResponse, err := l.parseClientHandshakePacket(c, true, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err := l.parseClientHandshakePacket(c, true, response)
if err != nil {
log.Errorf("Cannot parse client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes

c.recycleReadPacket()

Expand All @@ -371,11 +372,12 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}

// Returns copies of the data, so we can recycle the buffer.
user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response)
user, clientAuthMethod, clientAuthResponse, clientAttributes, err = l.parseClientHandshakePacket(c, false, response)
if err != nil {
log.Errorf("Cannot parse post-SSL client handshake response from %s: %v", c, err)
return
}
c.ConnectionAttributes = clientAttributes
c.recycleReadPacket()

if con, ok := c.conn.(*tls.Conn); ok {
Expand Down Expand Up @@ -638,16 +640,16 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en
// parseClientHandshakePacket parses the handshake sent by the client.
// Returns the username, auth method, auth data, error.
// The original data is not pointed at, and can be freed.
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, error) {
func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []byte) (string, AuthMethodDescription, []byte, map[string]string, error) {
pos := 0

// Client flags, 4 bytes.
clientFlags, pos, ok := readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read client flags")
}
if clientFlags&CapabilityClientProtocol41 == 0 {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: only support protocol 4.1")
}

// Remember a subset of the capabilities, so we can use them
Expand All @@ -666,13 +668,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
// See doc.go for more information.
_, pos, ok = readUint32(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read maxPacketSize")
}

// Character set. Need to handle it.
characterSet, pos, ok := readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read characterSet")
}
c.CharacterSet = collations.ID(characterSet)

Expand All @@ -686,13 +688,13 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
c.conn = conn
c.bufferedReader.Reset(conn)
c.Capabilities |= CapabilityClientSSL
return "", "", nil, nil
return "", "", nil, nil, nil
}

// username
username, pos, ok := readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read username")
}

// auth-response can have three forms.
Expand All @@ -701,29 +703,29 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var l uint64
l, pos, ok = readLenEncInt(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response variable length")
}
authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}

} else if clientFlags&CapabilityClientSecureConnection != 0 {
var l byte
l, pos, ok = readByte(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response length")
}

authResponse, pos, ok = readBytesCopy(data, pos, int(l))
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
} else {
a := ""
a, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read auth-response")
}
authResponse = []byte(a)
}
Expand All @@ -733,7 +735,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
dbname := ""
dbname, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read dbname")
}
c.schemaName = dbname
}
Expand All @@ -744,7 +746,7 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
var authMethodStr string
authMethodStr, pos, ok = readNullString(data, pos)
if !ok {
return "", "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
return "", "", nil, nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "parseClientHandshakePacket: can't read authMethod")
}
// The JDBC driver sometimes sends an empty string as the auth method when it wants to use mysql_native_password
if authMethodStr != "" {
Expand All @@ -753,13 +755,17 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by
}

// Decode connection attributes send by the client
var client_attibutes map[string]string
if clientFlags&CapabilityClientConnAttr != 0 {
if _, _, err := parseConnAttrs(data, pos); err != nil {
ca, _, err := parseConnAttrs(data, pos)
if err != nil {
log.Warningf("Decode connection attributes send by the client: %v", err)
}

client_attibutes = ca
}

return username, AuthMethodDescription(authMethod), authResponse, nil
return username, AuthMethodDescription(authMethod), authResponse, client_attibutes, nil
}

func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgateproxy/mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (ph *proxyHandler) session(c *mysql.Conn) *vtgateconn.VTGateSession {
}

var err error
session, err = ph.proxy.NewSession(options)
session, err = ph.proxy.NewSession(options, c.ConnectionAttributes)
if err != nil {
log.Errorf("error creating new session for %s: %v", c.GetRawConn().RemoteAddr().String(), err)
}
Expand Down
8 changes: 7 additions & 1 deletion go/vt/vtgateproxy/vtgateproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package vtgateproxy
import (
"context"
"flag"
"fmt"
"io"
"time"

Expand Down Expand Up @@ -68,11 +69,16 @@ func (proxy *VTGateProxy) connect(ctx context.Context) error {
return nil
}

func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions) (*vtgateconn.VTGateSession, error) {
func (proxy *VTGateProxy) NewSession(options *querypb.ExecuteOptions, connectionAttributes map[string]string) (*vtgateconn.VTGateSession, error) {
if proxy.conn == nil {
return nil, vterrors.Errorf(vtrpcpb.Code_UNAVAILABLE, "not connnected")
}

target, ok := connectionAttributes["target"]
if ok {
fmt.Printf("Creating new session from upstream provided target string: %v\n", target)
}

// XXX/demmer handle schemaName?
return proxy.conn.Session("", options), nil
}
Expand Down

0 comments on commit db3262e

Please sign in to comment.