diff --git a/protos/mysql/mysql.go b/protos/mysql/mysql.go index d0d979b56220..312725c59e27 100644 --- a/protos/mysql/mysql.go +++ b/protos/mysql/mysql.go @@ -80,7 +80,7 @@ type MysqlStream struct { data []byte parseOffset int - parseState int + parseState parseState isClient bool message *MysqlMessage @@ -91,13 +91,28 @@ const ( TransactionTimeout = 10 * 1e9 ) +type parseState int + const ( - MysqlStateStart = iota - MysqlStateEatMessage - MysqlStateEatFields - MysqlStateEatRows + mysqlStateStart parseState = iota + mysqlStateEatMessage + mysqlStateEatFields + mysqlStateEatRows + + MysqlStateMax ) +var stateStrings []string = []string{ + "Start", + "EatMessage", + "EatFields", + "EatRows", +} + +func (state parseState) String() string { + return stateStrings[state] +} + type Mysql struct { // config @@ -165,7 +180,7 @@ func (mysql *Mysql) Init(test_mode bool, results chan common.MapStr) error { func (stream *MysqlStream) PrepareForNewMessage() { stream.data = stream.data[stream.message.end:] - stream.parseState = MysqlStateStart + stream.parseState = mysqlStateStart stream.parseOffset = 0 stream.isClient = false stream.message = nil @@ -173,12 +188,12 @@ func (stream *MysqlStream) PrepareForNewMessage() { func mysqlMessageParser(s *MysqlStream) (bool, bool) { - logp.Debug("mysqldetailed", "MySQL parser called. parseState = %d", s.parseState) + logp.Debug("mysqldetailed", "MySQL parser called. parseState = %s", s.parseState) m := s.message for s.parseOffset < len(s.data) { switch s.parseState { - case MysqlStateStart: + case mysqlStateStart: m.start = s.parseOffset if len(s.data[s.parseOffset:]) < 5 { logp.Warn("MySQL Message too short. Ignore it.") @@ -198,12 +213,12 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { // parse request m.IsRequest = true m.start = s.parseOffset - s.parseState = MysqlStateEatMessage + s.parseState = mysqlStateEatMessage } else { // ignore command m.IgnoreMessage = true - s.parseState = MysqlStateEatMessage + s.parseState = mysqlStateEatMessage } if !s.isClient { @@ -217,23 +232,23 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { if uint8(hdr[4]) == 0x00 || uint8(hdr[4]) == 0xfe { logp.Debug("mysqldetailed", "Received OK response") m.start = s.parseOffset - s.parseState = MysqlStateEatMessage + s.parseState = mysqlStateEatMessage m.IsOK = true } else if uint8(hdr[4]) == 0xff { logp.Debug("mysqldetailed", "Received ERR response") m.start = s.parseOffset - s.parseState = MysqlStateEatMessage + s.parseState = mysqlStateEatMessage m.IsError = true } else if m.PacketLength == 1 { logp.Debug("mysqldetailed", "Query response. Number of fields %d", uint8(hdr[4])) m.NumberOfFields = int(hdr[4]) m.start = s.parseOffset s.parseOffset += 5 - s.parseState = MysqlStateEatFields + s.parseState = mysqlStateEatFields } else { // something else. ignore m.IgnoreMessage = true - s.parseState = MysqlStateEatMessage + s.parseState = mysqlStateEatMessage } } else { @@ -243,7 +258,7 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { } break - case MysqlStateEatMessage: + case mysqlStateEatMessage: if len(s.data[s.parseOffset:]) >= int(m.PacketLength)+4 { s.parseOffset += 4 //header s.parseOffset += int(m.PacketLength) @@ -289,7 +304,7 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { return true, false } - case MysqlStateEatFields: + case mysqlStateEatFields: if len(s.data[s.parseOffset:]) < 4 { // wait for more return true, false @@ -307,7 +322,7 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { // EOF marker s.parseOffset += int(m.PacketLength) - s.parseState = MysqlStateEatRows + s.parseState = mysqlStateEatRows } else { _ /* catalog */, off, complete, err := read_lstring(s.data, s.parseOffset) if !complete { @@ -351,7 +366,7 @@ func mysqlMessageParser(s *MysqlStream) (bool, bool) { } break - case MysqlStateEatRows: + case mysqlStateEatRows: if len(s.data[s.parseOffset:]) < 4 { // wait for more return true, false @@ -627,6 +642,8 @@ func (mysql *Mysql) parseMysqlResponse(data []byte) ([]string, [][]string) { } else { offset := 5 + logp.Debug("mysql", "Data len: %d", len(data)) + // Read fields for { length = read_length(data, offset) diff --git a/protos/mysql/mysql_test.go b/protos/mysql/mysql_test.go index 9b951bd71eb3..42bd17c3399b 100644 --- a/protos/mysql/mysql_test.go +++ b/protos/mysql/mysql_test.go @@ -6,6 +6,7 @@ import ( "github.com/elastic/libbeat/common" "github.com/elastic/libbeat/logp" + "github.com/stretchr/testify/assert" "github.com/elastic/packetbeat/protos" @@ -18,6 +19,15 @@ func MysqlModForTests() *Mysql { return &mysql } +func Test_parseStateNames(t *testing.T) { + assert.Equal(t, "Start", mysqlStateStart.String()) + assert.Equal(t, "EatMessage", mysqlStateEatMessage.String()) + assert.Equal(t, "EatFields", mysqlStateEatFields.String()) + assert.Equal(t, "EatRows", mysqlStateEatRows.String()) + + assert.NotNil(t, (MysqlStateMax - 1).String()) +} + func TestMySQLParser_simpleRequest(t *testing.T) { data := []byte(