diff --git a/packetbeat/CHANGELOG.md b/packetbeat/CHANGELOG.md index e1cc0add001f..fd7a2b2014d5 100644 --- a/packetbeat/CHANGELOG.md +++ b/packetbeat/CHANGELOG.md @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file based on the - Fix errors in redis parser when length prefixed strings contain sequences of CRLF. #402 - Fix errors in redis parser when dealing with nested arrays. #402 - Improve MongoDB message correlation. #377 +- Fix TCP connection state being reused after dropping due to gap in stream. #342 ### Added - Added redis pipelining support. #402 diff --git a/packetbeat/protos/tcp/tcp.go b/packetbeat/protos/tcp/tcp.go index f223ff08a731..2a34989b99c1 100644 --- a/packetbeat/protos/tcp/tcp.go +++ b/packetbeat/protos/tcp/tcp.go @@ -54,15 +54,15 @@ func (tcp *Tcp) decideProtocol(tuple *common.IpPortTuple) protos.Protocol { return protos.UnknownProtocol } -func (tcp *Tcp) getStream(k common.HashableIpPortTuple) *TcpStream { +func (tcp *Tcp) findStream(k common.HashableIpPortTuple) *TcpConnection { v := tcp.streams.Get(k) if v != nil { - return v.(*TcpStream) + return v.(*TcpConnection) } return nil } -type TcpStream struct { +type TcpConnection struct { id uint32 tuple *common.IpPortTuple protocol protos.Protocol @@ -75,117 +75,139 @@ type TcpStream struct { data protos.ProtocolData } -func (stream *TcpStream) String() string { +type TcpStream struct { + conn *TcpConnection + dir uint8 +} + +func (conn *TcpConnection) String() string { return fmt.Sprintf("TcpStream id[%d] tuple[%s] protocol[%s] lastSeq[%d %d]", - stream.id, stream.tuple, stream.protocol, stream.lastSeq[0], stream.lastSeq[1]) + conn.id, conn.tuple, conn.protocol, conn.lastSeq[0], conn.lastSeq[1]) } -func (stream *TcpStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP, original_dir uint8) { - mod := stream.tcp.protocols.GetTcp(stream.protocol) +func (stream *TcpStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP) { + conn := stream.conn + mod := conn.tcp.protocols.GetTcp(conn.protocol) if mod == nil { if isDebug { + protocol := conn.protocol debugf("Ignoring protocol for which we have no module loaded: %s", - stream.protocol) + protocol) } return } if len(pkt.Payload) > 0 { - stream.data = mod.Parse(pkt, &stream.tcptuple, original_dir, stream.data) + conn.data = mod.Parse(pkt, &conn.tcptuple, stream.dir, conn.data) } if tcphdr.FIN { - stream.data = mod.ReceivedFin(&stream.tcptuple, original_dir, stream.data) + conn.data = mod.ReceivedFin(&conn.tcptuple, stream.dir, conn.data) } } -func (stream *TcpStream) gapInStream(original_dir uint8, nbytes int) (drop bool) { - mod := stream.tcp.protocols.GetTcp(stream.protocol) - stream.data, drop = mod.GapInStream(&stream.tcptuple, original_dir, nbytes, stream.data) +func (stream *TcpStream) gapInStream(nbytes int) (drop bool) { + conn := stream.conn + mod := conn.tcp.protocols.GetTcp(conn.protocol) + conn.data, drop = mod.GapInStream(&conn.tcptuple, stream.dir, nbytes, conn.data) return drop } -func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { - return int32(seq1-seq2) < 0 -} - -func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { - return int32(seq1-seq2) <= 0 -} - func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { - // This Recover should catch all exceptions in // protocol modules. defer logp.Recover("Process tcp exception") - stream := tcp.getStream(pkt.Tuple.Hashable()) - var original_dir uint8 = TcpDirectionOriginal - created := false - if stream == nil { - stream = tcp.getStream(pkt.Tuple.RevHashable()) - if stream == nil { - protocol := tcp.decideProtocol(&pkt.Tuple) - if protocol == protos.UnknownProtocol { - // don't follow - return - } - - timeout := time.Duration(0) - mod := tcp.protocols.GetTcp(protocol) - if mod != nil { - timeout = mod.ConnectionTimeout() - } - - if isDebug { - debugf("Stream doesn't exist, creating new") - } - - // create - stream = &TcpStream{id: tcp.getId(), tuple: &pkt.Tuple, protocol: protocol, tcp: tcp} - stream.tcptuple = common.TcpTupleFromIpPort(stream.tuple, stream.id) - tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), stream, timeout) - created = true - } else { - original_dir = TcpDirectionReverse - } + stream, created := tcp.getStream(pkt) + if stream.conn == nil { + return } + conn := stream.conn + tcp_start_seq := tcphdr.Seq tcp_seq := tcp_start_seq + uint32(len(pkt.Payload)) - + lastSeq := conn.lastSeq[stream.dir] if isDebug { debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)", - tcp_start_seq, tcp_seq, stream.lastSeq[original_dir], len(pkt.Payload)) + tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload)) } - if len(pkt.Payload) > 0 && - stream.lastSeq[original_dir] != 0 { - - if tcpSeqBeforeEq(tcp_seq, stream.lastSeq[original_dir]) { + if len(pkt.Payload) > 0 && lastSeq != 0 { + if tcpSeqBeforeEq(tcp_seq, lastSeq) { if isDebug { - debugf("Ignoring what looks like a retransmitted segment. pkt.seq=%v len=%v stream.seq=%v", - tcphdr.Seq, len(pkt.Payload), stream.lastSeq[original_dir]) + debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v", + tcphdr.Seq, len(pkt.Payload), lastSeq) } return } - if tcpSeqBefore(stream.lastSeq[original_dir], tcp_start_seq) { + if tcpSeqBefore(lastSeq, tcp_start_seq) { if !created { - logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d", stream.lastSeq[original_dir], tcp_start_seq) - drop := stream.gapInStream(original_dir, - int(tcp_start_seq-stream.lastSeq[original_dir])) + gap := int(tcp_start_seq - lastSeq) + logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap) + drop := stream.gapInStream(gap) if drop { if isDebug { - debugf("Dropping stream because of gap") + debugf("Dropping connection state because of gap") } - tcp.streams.Delete(stream.tuple.Hashable()) + + // drop application layer connection state and + // update stream_id for app layer analysers using stream_id for lookups + conn.id = tcp.getId() + conn.data = nil } } } } - stream.lastSeq[original_dir] = tcp_seq - stream.addPacket(pkt, tcphdr, original_dir) + conn.lastSeq[stream.dir] = tcp_seq + stream.addPacket(pkt, tcphdr) +} + +func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) { + if conn := tcp.findStream(pkt.Tuple.Hashable()); conn != nil { + return TcpStream{conn: conn, dir: TcpDirectionOriginal}, false + } + + if conn := tcp.findStream(pkt.Tuple.RevHashable()); conn != nil { + return TcpStream{conn: conn, dir: TcpDirectionReverse}, false + } + + protocol := tcp.decideProtocol(&pkt.Tuple) + if protocol == protos.UnknownProtocol { + // don't follow + return TcpStream{}, false + } + + var timeout time.Duration + mod := tcp.protocols.GetTcp(protocol) + if mod != nil { + timeout = mod.ConnectionTimeout() + } + + if isDebug { + t := pkt.Tuple + debugf("Connection src[%s:%d] dst[%s:%d] doesn't exist, creating new", + t.Src_ip.String(), t.Src_port, + t.Dst_ip.String(), t.Dst_port) + } + + conn := &TcpConnection{ + id: tcp.getId(), + tuple: &pkt.Tuple, + protocol: protocol, + tcp: tcp} + conn.tcptuple = common.TcpTupleFromIpPort(conn.tuple, conn.id) + tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), conn, timeout) + return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true +} + +func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { + return int32(seq1-seq2) < 0 +} + +func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { + return int32(seq1-seq2) <= 0 } func buildPortsMap(plugins map[protos.Protocol]protos.TcpProtocolPlugin) (map[uint16]protos.Protocol, error) { diff --git a/packetbeat/protos/tcp/tcp_test.go b/packetbeat/protos/tcp/tcp_test.go index ae264060ba69..37d8aee523a2 100644 --- a/packetbeat/protos/tcp/tcp_test.go +++ b/packetbeat/protos/tcp/tcp_test.go @@ -1,6 +1,7 @@ package tcp import ( + "fmt" "math/rand" "net" "testing" @@ -23,12 +24,28 @@ const ( type TestProtocol struct { Ports []int + + init func(testMode bool, results publisher.Client) error + parse func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData + onFin func(*common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData + gap func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool) } -var _ protos.ProtocolPlugin = &TestProtocol{} +var _ protos.ProtocolPlugin = &TestProtocol{ + init: func(m bool, r publisher.Client) error { return nil }, + parse: func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { + return priv + }, + onFin: func(t *common.TcpTuple, d uint8, p protos.ProtocolData) protos.ProtocolData { + return p + }, + gap: func(t *common.TcpTuple, d uint8, b int, p protos.ProtocolData) (protos.ProtocolData, bool) { + return p, true + }, +} func (proto *TestProtocol) Init(test_mode bool, results publisher.Client) error { - return nil + return proto.init(test_mode, results) } func (proto TestProtocol) GetPorts() []int { @@ -37,17 +54,17 @@ func (proto TestProtocol) GetPorts() []int { func (proto TestProtocol) Parse(pkt *protos.Packet, tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { - return private + return proto.parse(pkt, tcptuple, dir, private) } func (proto TestProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { - return private + return proto.onFin(tcptuple, dir, private) } func (proto TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) { - return private, true + return proto.gap(tcptuple, dir, nbytes, private) } func (proto TestProtocol) ConnectionTimeout() time.Duration { @@ -149,6 +166,47 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { return nil } func (p protocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { return } +func TestGapInStreamShouldDropState(t *testing.T) { + gap := 0 + var state []byte + + data1 := []byte{1, 2, 3, 4} + data2 := []byte{5, 6, 7, 8} + + tp := &TestProtocol{Ports: []int{ServerPort}} + tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) { + fmt.Println("lost: %v\n", n) + gap += n + return p, true // drop state + } + tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData { + if priv == nil { + state = nil + } + state = append(state, p.Payload...) + return state + } + + p := protocols{} + p.tcp = map[protos.Protocol]protos.TcpProtocolPlugin{ + protos.HttpProtocol: tp, + } + tcp, _ := NewTcp(p) + + addr := common.NewIpPortTuple(4, + net.ParseIP(ServerIp), ServerPort, + net.ParseIP(ClientIp), uint16(rand.Intn(65535))) + + hdr := &layers.TCP{} + tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1}) + hdr.Seq += uint32(len(data1) + 10) + tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2}) + + // validate + assert.Equal(t, 10, gap) + assert.Equal(t, data2, state) +} + // Benchmark that runs with parallelism to help find concurrency related // issues. To run with parallelism, the 'go test' cpu flag must be set // greater than 1, otherwise it just runs concurrently but not in parallel.