Skip to content

Commit

Permalink
Merge pull request #494 from urso/fix/342-pgsql-panic
Browse files Browse the repository at this point in the history
TCP layer drop connection state on gap
  • Loading branch information
andrewkroh committed Dec 10, 2015
2 parents e8ac7bd + 96f3c7b commit ba6a0f8
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 73 deletions.
1 change: 1 addition & 0 deletions packetbeat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file based on the
- Fix errors in redis parser when length prefixed strings contain sequences of CRLF. #402
- Fix errors in redis parser when dealing with nested arrays. #402
- Improve MongoDB message correlation. #377
- Fix TCP connection state being reused after dropping due to gap in stream. #342

### Added
- Added redis pipelining support. #402
Expand Down
158 changes: 90 additions & 68 deletions packetbeat/protos/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ func (tcp *Tcp) decideProtocol(tuple *common.IpPortTuple) protos.Protocol {
return protos.UnknownProtocol
}

func (tcp *Tcp) getStream(k common.HashableIpPortTuple) *TcpStream {
func (tcp *Tcp) findStream(k common.HashableIpPortTuple) *TcpConnection {
v := tcp.streams.Get(k)
if v != nil {
return v.(*TcpStream)
return v.(*TcpConnection)
}
return nil
}

type TcpStream struct {
type TcpConnection struct {
id uint32
tuple *common.IpPortTuple
protocol protos.Protocol
Expand All @@ -75,117 +75,139 @@ type TcpStream struct {
data protos.ProtocolData
}

func (stream *TcpStream) String() string {
type TcpStream struct {
conn *TcpConnection
dir uint8
}

func (conn *TcpConnection) String() string {
return fmt.Sprintf("TcpStream id[%d] tuple[%s] protocol[%s] lastSeq[%d %d]",
stream.id, stream.tuple, stream.protocol, stream.lastSeq[0], stream.lastSeq[1])
conn.id, conn.tuple, conn.protocol, conn.lastSeq[0], conn.lastSeq[1])
}

func (stream *TcpStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP, original_dir uint8) {
mod := stream.tcp.protocols.GetTcp(stream.protocol)
func (stream *TcpStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP) {
conn := stream.conn
mod := conn.tcp.protocols.GetTcp(conn.protocol)
if mod == nil {
if isDebug {
protocol := conn.protocol
debugf("Ignoring protocol for which we have no module loaded: %s",
stream.protocol)
protocol)
}
return
}

if len(pkt.Payload) > 0 {
stream.data = mod.Parse(pkt, &stream.tcptuple, original_dir, stream.data)
conn.data = mod.Parse(pkt, &conn.tcptuple, stream.dir, conn.data)
}

if tcphdr.FIN {
stream.data = mod.ReceivedFin(&stream.tcptuple, original_dir, stream.data)
conn.data = mod.ReceivedFin(&conn.tcptuple, stream.dir, conn.data)
}
}

func (stream *TcpStream) gapInStream(original_dir uint8, nbytes int) (drop bool) {
mod := stream.tcp.protocols.GetTcp(stream.protocol)
stream.data, drop = mod.GapInStream(&stream.tcptuple, original_dir, nbytes, stream.data)
func (stream *TcpStream) gapInStream(nbytes int) (drop bool) {
conn := stream.conn
mod := conn.tcp.protocols.GetTcp(conn.protocol)
conn.data, drop = mod.GapInStream(&conn.tcptuple, stream.dir, nbytes, conn.data)
return drop
}

func tcpSeqBefore(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) < 0
}

func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) <= 0
}

func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) {

// This Recover should catch all exceptions in
// protocol modules.
defer logp.Recover("Process tcp exception")

stream := tcp.getStream(pkt.Tuple.Hashable())
var original_dir uint8 = TcpDirectionOriginal
created := false
if stream == nil {
stream = tcp.getStream(pkt.Tuple.RevHashable())
if stream == nil {
protocol := tcp.decideProtocol(&pkt.Tuple)
if protocol == protos.UnknownProtocol {
// don't follow
return
}

timeout := time.Duration(0)
mod := tcp.protocols.GetTcp(protocol)
if mod != nil {
timeout = mod.ConnectionTimeout()
}

if isDebug {
debugf("Stream doesn't exist, creating new")
}

// create
stream = &TcpStream{id: tcp.getId(), tuple: &pkt.Tuple, protocol: protocol, tcp: tcp}
stream.tcptuple = common.TcpTupleFromIpPort(stream.tuple, stream.id)
tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), stream, timeout)
created = true
} else {
original_dir = TcpDirectionReverse
}
stream, created := tcp.getStream(pkt)
if stream.conn == nil {
return
}
conn := stream.conn

tcp_start_seq := tcphdr.Seq
tcp_seq := tcp_start_seq + uint32(len(pkt.Payload))

lastSeq := conn.lastSeq[stream.dir]
if isDebug {
debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)",
tcp_start_seq, tcp_seq, stream.lastSeq[original_dir], len(pkt.Payload))
tcp_start_seq, tcp_seq, lastSeq, len(pkt.Payload))
}

if len(pkt.Payload) > 0 &&
stream.lastSeq[original_dir] != 0 {

if tcpSeqBeforeEq(tcp_seq, stream.lastSeq[original_dir]) {
if len(pkt.Payload) > 0 && lastSeq != 0 {
if tcpSeqBeforeEq(tcp_seq, lastSeq) {
if isDebug {
debugf("Ignoring what looks like a retransmitted segment. pkt.seq=%v len=%v stream.seq=%v",
tcphdr.Seq, len(pkt.Payload), stream.lastSeq[original_dir])
debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v",
tcphdr.Seq, len(pkt.Payload), lastSeq)
}
return
}

if tcpSeqBefore(stream.lastSeq[original_dir], tcp_start_seq) {
if tcpSeqBefore(lastSeq, tcp_start_seq) {
if !created {
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d", stream.lastSeq[original_dir], tcp_start_seq)
drop := stream.gapInStream(original_dir,
int(tcp_start_seq-stream.lastSeq[original_dir]))
gap := int(tcp_start_seq - lastSeq)
logp.Warn("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcp_start_seq, gap)
drop := stream.gapInStream(gap)
if drop {
if isDebug {
debugf("Dropping stream because of gap")
debugf("Dropping connection state because of gap")
}
tcp.streams.Delete(stream.tuple.Hashable())

// drop application layer connection state and
// update stream_id for app layer analysers using stream_id for lookups
conn.id = tcp.getId()
conn.data = nil
}
}
}
}
stream.lastSeq[original_dir] = tcp_seq

stream.addPacket(pkt, tcphdr, original_dir)
conn.lastSeq[stream.dir] = tcp_seq
stream.addPacket(pkt, tcphdr)
}

func (tcp *Tcp) getStream(pkt *protos.Packet) (stream TcpStream, created bool) {
if conn := tcp.findStream(pkt.Tuple.Hashable()); conn != nil {
return TcpStream{conn: conn, dir: TcpDirectionOriginal}, false
}

if conn := tcp.findStream(pkt.Tuple.RevHashable()); conn != nil {
return TcpStream{conn: conn, dir: TcpDirectionReverse}, false
}

protocol := tcp.decideProtocol(&pkt.Tuple)
if protocol == protos.UnknownProtocol {
// don't follow
return TcpStream{}, false
}

var timeout time.Duration
mod := tcp.protocols.GetTcp(protocol)
if mod != nil {
timeout = mod.ConnectionTimeout()
}

if isDebug {
t := pkt.Tuple
debugf("Connection src[%s:%d] dst[%s:%d] doesn't exist, creating new",
t.Src_ip.String(), t.Src_port,
t.Dst_ip.String(), t.Dst_port)
}

conn := &TcpConnection{
id: tcp.getId(),
tuple: &pkt.Tuple,
protocol: protocol,
tcp: tcp}
conn.tcptuple = common.TcpTupleFromIpPort(conn.tuple, conn.id)
tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), conn, timeout)
return TcpStream{conn: conn, dir: TcpDirectionOriginal}, true
}

func tcpSeqBefore(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) < 0
}

func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool {
return int32(seq1-seq2) <= 0
}

func buildPortsMap(plugins map[protos.Protocol]protos.TcpProtocolPlugin) (map[uint16]protos.Protocol, error) {
Expand Down
68 changes: 63 additions & 5 deletions packetbeat/protos/tcp/tcp_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tcp

import (
"fmt"
"math/rand"
"net"
"testing"
Expand All @@ -23,12 +24,28 @@ const (

type TestProtocol struct {
Ports []int

init func(testMode bool, results publisher.Client) error
parse func(*protos.Packet, *common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData
onFin func(*common.TcpTuple, uint8, protos.ProtocolData) protos.ProtocolData
gap func(*common.TcpTuple, uint8, int, protos.ProtocolData) (protos.ProtocolData, bool)
}

var _ protos.ProtocolPlugin = &TestProtocol{}
var _ protos.ProtocolPlugin = &TestProtocol{
init: func(m bool, r publisher.Client) error { return nil },
parse: func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData {
return priv
},
onFin: func(t *common.TcpTuple, d uint8, p protos.ProtocolData) protos.ProtocolData {
return p
},
gap: func(t *common.TcpTuple, d uint8, b int, p protos.ProtocolData) (protos.ProtocolData, bool) {
return p, true
},
}

func (proto *TestProtocol) Init(test_mode bool, results publisher.Client) error {
return nil
return proto.init(test_mode, results)
}

func (proto TestProtocol) GetPorts() []int {
Expand All @@ -37,17 +54,17 @@ func (proto TestProtocol) GetPorts() []int {

func (proto TestProtocol) Parse(pkt *protos.Packet, tcptuple *common.TcpTuple,
dir uint8, private protos.ProtocolData) protos.ProtocolData {
return private
return proto.parse(pkt, tcptuple, dir, private)
}

func (proto TestProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8,
private protos.ProtocolData) protos.ProtocolData {
return private
return proto.onFin(tcptuple, dir, private)
}

func (proto TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8,
nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) {
return private, true
return proto.gap(tcptuple, dir, nbytes, private)
}

func (proto TestProtocol) ConnectionTimeout() time.Duration {
Expand Down Expand Up @@ -149,6 +166,47 @@ func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin
func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { return nil }
func (p protocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { return }

func TestGapInStreamShouldDropState(t *testing.T) {
gap := 0
var state []byte

data1 := []byte{1, 2, 3, 4}
data2 := []byte{5, 6, 7, 8}

tp := &TestProtocol{Ports: []int{ServerPort}}
tp.gap = func(t *common.TcpTuple, d uint8, n int, p protos.ProtocolData) (protos.ProtocolData, bool) {
fmt.Println("lost: %v\n", n)
gap += n
return p, true // drop state
}
tp.parse = func(p *protos.Packet, t *common.TcpTuple, d uint8, priv protos.ProtocolData) protos.ProtocolData {
if priv == nil {
state = nil
}
state = append(state, p.Payload...)
return state
}

p := protocols{}
p.tcp = map[protos.Protocol]protos.TcpProtocolPlugin{
protos.HttpProtocol: tp,
}
tcp, _ := NewTcp(p)

addr := common.NewIpPortTuple(4,
net.ParseIP(ServerIp), ServerPort,
net.ParseIP(ClientIp), uint16(rand.Intn(65535)))

hdr := &layers.TCP{}
tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data1})
hdr.Seq += uint32(len(data1) + 10)
tcp.Process(hdr, &protos.Packet{Ts: time.Now(), Tuple: addr, Payload: data2})

// validate
assert.Equal(t, 10, gap)
assert.Equal(t, data2, state)
}

// Benchmark that runs with parallelism to help find concurrency related
// issues. To run with parallelism, the 'go test' cpu flag must be set
// greater than 1, otherwise it just runs concurrently but not in parallel.
Expand Down

0 comments on commit ba6a0f8

Please sign in to comment.