diff --git a/packetbeat/protos/tcp/tcp_test.go b/packetbeat/protos/tcp/tcp_test.go index ae264060ba69..37d8aee523a2 100644 --- a/packetbeat/protos/tcp/tcp_test.go +++ b/packetbeat/protos/tcp/tcp_test.go @@ -1,6 +1,7 @@ package tcp import ( + "fmt" "math/rand" "net" "testing" @@ -23,12 +24,28 @@ const ( type TestProtocol struct { Ports []int + + init func(testMode bool, results publisher.Client) error + parse func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData + onFin func(*common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData + gap func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) } -var _ protos.ProtocolPlugin = &TestProtocol{} +var _ protos.ProtocolPlugin = &TestProtocol{ + init: func(m bool, r publisher.Client) error { return nil }, + parse: func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { + return priv + }, + onFin: func(t *common.TcpTuple, d uint8, p protos.ProtocolData) protos.ProtocolData { + return p + }, + gap: func(t *common.TcpTuple, d uint8, b int, p protos.ProtocolData) (protos.ProtocolData, bool) { + return p, true + }, +} func (proto *TestProtocol) Init(test_mode bool, results publisher.Client) error { - return nil + return proto.init(test_mode, results) } func (proto TestProtocol) GetPorts() []int { @@ -37,17 +54,17 @@ func (proto TestProtocol) GetPorts() []int { func (proto TestProtocol) Parse(pkt *protos.Packet, tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { - return private + return proto.parse(pkt, tcptuple, dir, private) } func (proto TestProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { - return private + return proto.onFin(tcptuple, dir, private) } func (proto TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) { - return private, true + return proto.gap(tcptuple, dir, nbytes, private) } func (proto TestProtocol) ConnectionTimeout() time.Duration { @@ -149,6 +166,47 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { return nil } func (p protocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { return } +func TestGapInStreamShouldDropState(t *testing.T) { + gap := 0 + var state []byte + + data1 := []byte{1, 2, 3, 4} + data2 := []byte{5, 6, 7, 8} + + tp := &TestProtocol{Ports: []int{ServerPort}} + tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) { + fmt.Println("lost: %v\n", n) + gap += n + return p, true // drop state + } + tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { + if priv == nil { + state = nil + } + state = append(state, p.Payload...) + return state + } + + p := protocols{} + p.tcp = map[protos.Protocol]protos.TcpProtocolPlugin{ + protos.HttpProtocol: tp, + } + tcp, _ := NewTcp(p) + + addr := common.NewIpPortTuple(4, + net.ParseIP(ServerIp), ServerPort, + net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + + hdr := &layers.TCP{} + tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1}) + hdr.Seq += uint32(len(data1) + 10) + tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2}) + + // validate + assert.Equal(t, 10, gap) + assert.Equal(t, data2, state) +} + // Benchmark that runs with parallelism to help find concurrency related // issues. To run with parallelism, the 'go test' cpu flag must be set // greater than 1, otherwise it just runs concurrently but not in parallel.