diff --git a/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache.go b/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache.go new file mode 100644 index 000000000000..6faa2a4b2f10 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache.go @@ -0,0 +1,229 @@ +package common + +import ( + "sync" + "time" +) + +// Key type used in the cache. +type Key interface{} + +// Value type held in the cache. Cannot be nil. +type Value interface{} + +// RemovalListener is the callback function type that can be registered with +// the cache to receive notification of the removal of expired elements. +type RemovalListener func(k Key, v Value) + +// Clock is the function type used to get the current time. +type clock func() time.Time + +// An element stored in the cache. +type element struct { + expiration time.Time + value Value +} + +// IsExpired returns true if the element is expired (current time is greater +// than the expiration time). +func (e *element) IsExpired(now time.Time) bool { + return now.After(e.expiration) +} + +// UpdateLastAccessTime updates the expiration time of the element. This +// should be called each time the element is accessed. +func (e *element) UpdateLastAccessTime(now time.Time, expiration time.Duration) { + e.expiration = now.Add(expiration) +} + +// Cache is a semi-persistent mapping of keys to values. Elements added to the +// cache are store until they are explicitly deleted or are expired due time- +// based eviction based on last access time. +// +// Expired elements are not visible through classes methods, but they do remain +// stored in the cache until CleanUp() is invoked. Therefore CleanUp() must be +// invoked periodically to prevent the cache from becoming a memory leak. If +// you want to start a goroutine to perform periodic clean-up then see +// StartJanitor(). +// +// Cache does not support storing nil values. Any attempt to put nil into +// the cache will cause a panic. +type Cache struct { + sync.RWMutex + timeout time.Duration // Length of time before cache elements expire. + elements map[Key]*element // Data stored by the cache. + clock clock // Function used to get the current time. + listener RemovalListener // Callback listen to notify of evictions. + janitorQuit chan struct{} // Closing this channel stop the janitor. +} + +// NewCache creates and returns a new Cache. d is the length of time after last +// access that cache elements expire. initialSize is the initial allocation size +// used for the Cache's underlying map. +func NewCache(d time.Duration, initialSize int) *Cache { + return newCache(d, initialSize, nil, time.Now) +} + +// NewCacheWithRemovalListener creates and returns a new Cache and register a +// RemovalListener callback function. d is the length of time after last access +// that cache elements expire. initialSize is the initial allocation size used +// for the Cache's underlying map. l is the callback function that will be +// invoked when cache elements are removed from the map on CleanUp. +func NewCacheWithRemovalListener(d time.Duration, initialSize int, l RemovalListener) *Cache { + return newCache(d, initialSize, l, time.Now) +} + +func newCache(d time.Duration, initialSize int, l RemovalListener, t clock) *Cache { + return &Cache{ + timeout: d, + elements: make(map[Key]*element, initialSize), + listener: l, + clock: t, + } +} + +// PutIfAbsent writes the given key and value to the cache only if the key is +// absent from the cache. Nil is returned if the key-value pair were written, +// otherwise the old value is returned. +func (c *Cache) PutIfAbsent(k Key, v Value) Value { + c.Lock() + defer c.Unlock() + oldValue, exists := c.get(k) + if exists { + return oldValue + } + + c.put(k, v) + return nil +} + +// Put writes the given key and value to the map replacing any existing value +// if it exists. The previous value associated with the key returned or nil +// if the key was not present. +func (c *Cache) Put(k Key, v Value) Value { + c.Lock() + defer c.Unlock() + oldValue, _ := c.get(k) + c.put(k, v) + return oldValue +} + +// Replace overwrites the value for a key only if the key exists. The old +// value is returned if the value is updated, otherwise nil is returned. +func (c *Cache) Replace(k Key, v Value) Value { + c.Lock() + defer c.Unlock() + oldValue, exists := c.get(k) + if !exists { + return nil + } + + c.put(k, v) + return oldValue +} + +// Get the current value associated with a key or nil if the key is not +// present. The last access time of the element is updated. +func (c *Cache) Get(k Key) Value { + c.RLock() + defer c.RUnlock() + v, _ := c.get(k) + return v +} + +// Delete a key from the map and return the value or nil if the key does +// not exist. The RemovalListener is not notified for explicit deletions. +func (c *Cache) Delete(k Key) Value { + c.Lock() + defer c.Unlock() + v, _ := c.get(k) + delete(c.elements, k) + return v +} + +// CleanUp performs maintenance on the cache by removing expired elements from +// the cache. If a RemoveListener is registered it will be invoked for each +// element that is removed during this clean up operation. The RemovalListener +// is invoked on the caller's goroutine. +func (c *Cache) CleanUp() int { + c.Lock() + defer c.Unlock() + count := 0 + for k, v := range c.elements { + if v.IsExpired(c.clock()) { + delete(c.elements, k) + count++ + if c.listener != nil { + c.listener(k, v.value) + } + } + } + return count +} + +// Entries returns a copy of the non-expired elements in the cache. +func (c *Cache) Entries() map[Key]Value { + c.RLock() + defer c.RUnlock() + copy := make(map[Key]Value, len(c.elements)) + for k, v := range c.elements { + if !v.IsExpired(c.clock()) { + copy[k] = v.value + } + } + return copy +} + +// Size returns the number of elements in the cache. The number includes both +// active elements and expired elements that have not been cleaned up. +func (c *Cache) Size() int { + c.RLock() + defer c.RUnlock() + return len(c.elements) +} + +// StartJanitor starts a goroutine that will periodically invoke the cache's +// CleanUp() method. +func (c *Cache) StartJanitor(interval time.Duration) { + ticker := time.NewTicker(interval) + c.janitorQuit = make(chan struct{}) + go func() { + for { + select { + case <-ticker.C: + c.CleanUp() + case <-c.janitorQuit: + ticker.Stop() + return + } + } + }() +} + +// StopJanitor stops the goroutine created by StartJanitor. +func (c *Cache) StopJanitor() { + close(c.janitorQuit) +} + +// get returns the non-expired values from the cache. +func (c *Cache) get(k Key) (Value, bool) { + elem, exists := c.elements[k] + now := c.clock() + if exists && !elem.IsExpired(now) { + elem.UpdateLastAccessTime(now, c.timeout) + return elem.value, true + } + return nil, false +} + +// put writes a key-value to the cache replacing any existing mapping. +func (c *Cache) put(k Key, v Value) { + if v == nil { + panic("Cache does not support storing nil values.") + } + + c.elements[k] = &element{ + expiration: c.clock().Add(c.timeout), + value: v, + } +} diff --git a/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache_test.go b/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache_test.go new file mode 100644 index 000000000000..bc4e77cb87e3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/elastic/libbeat/common/cache_test.go @@ -0,0 +1,167 @@ +package common + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + Timeout time.Duration = 1 * time.Minute + InitalSize int = 10 +) + +const ( + alphaKey = "alphaKey" + alphaValue = "a" + bravoKey = "bravoKey" + bravoValue = "b" +) + +// Current time as simulated by the fakeClock function. +var ( + currentTime time.Time + fakeClock clock = func() time.Time { + return currentTime + } +) + +// RemovalListener callback. +var ( + callbackKey Key + callbackValue Value + removalListener RemovalListener = func(k Key, v Value) { + callbackKey = k + callbackValue = v + } +) + +// Test that the removal listener is invoked with the expired key/value. +func TestExpireWithRemovalListener(t *testing.T) { + callbackKey = nil + callbackValue = nil + c := newCache(Timeout, InitalSize, removalListener, fakeClock) + c.Put(alphaKey, alphaValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + assert.Equal(t, 1, c.CleanUp()) + assert.Equal(t, alphaKey, callbackKey) + assert.Equal(t, alphaValue, callbackValue) +} + +// Test that the number of removed elements is returned by Expire. +func TestExpireWithoutRemovalListener(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + c.Put(bravoKey, bravoValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + assert.Equal(t, 2, c.CleanUp()) +} + +func TestPutIfAbsent(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + oldValue := c.PutIfAbsent(alphaKey, alphaValue) + assert.Nil(t, oldValue) + oldValue = c.PutIfAbsent(alphaKey, bravoValue) + assert.Equal(t, alphaValue, oldValue) +} + +func TestPut(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + oldValue := c.Put(alphaKey, alphaValue) + assert.Nil(t, oldValue) + oldValue = c.Put(bravoKey, bravoValue) + assert.Nil(t, oldValue) + + oldValue = c.Put(alphaKey, bravoValue) + assert.Equal(t, alphaValue, oldValue) + oldValue = c.Put(bravoKey, alphaValue) + assert.Equal(t, bravoValue, oldValue) +} + +func TestReplace(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + + // Nil is returned when the value does not exist and no element is added. + assert.Nil(t, c.Replace(alphaKey, alphaValue)) + assert.Equal(t, 0, c.Size()) + + // alphaKey is replaced with the new value. + assert.Nil(t, c.Put(alphaKey, alphaValue)) + assert.Equal(t, alphaValue, c.Replace(alphaKey, bravoValue)) + assert.Equal(t, 1, c.Size()) +} + +func TestGetUpdatesLastAccessTime(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + + currentTime = currentTime.Add(Timeout / 2) + assert.Equal(t, alphaValue, c.Get(alphaKey)) + currentTime = currentTime.Add(Timeout / 2) + assert.Equal(t, alphaValue, c.Get(alphaKey)) +} + +func TestDeleteNonExistentKey(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + assert.Nil(t, c.Delete(alphaKey)) +} + +func TestDeleteExistingKey(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + assert.Equal(t, alphaValue, c.Delete(alphaKey)) +} + +func TestDeleteExpiredKey(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + assert.Nil(t, c.Delete(alphaKey)) +} + +// Test that Entries returns the non-expired map entries. +func TestEntries(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + c.Put(bravoKey, bravoValue) + m := c.Entries() + assert.Equal(t, 1, len(m)) + assert.Equal(t, bravoValue, m[bravoKey]) +} + +// Test that Size returns a count of both expired and non-expired elements. +func TestSize(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + c.Put(bravoKey, bravoValue) + assert.Equal(t, 2, c.Size()) +} + +func TestGetExpiredValue(t *testing.T) { + c := newCache(Timeout, InitalSize, nil, fakeClock) + c.Put(alphaKey, alphaValue) + v := c.Get(alphaKey) + assert.Equal(t, alphaValue, v) + + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + v = c.Get(alphaKey) + assert.Nil(t, v) +} + +// Test that the janitor invokes CleanUp on the cache and that the +// RemovalListener is invoked during clean up. +func TestJanitor(t *testing.T) { + keyChan := make(chan Key) + c := newCache(Timeout, InitalSize, func(k Key, v Value) { + keyChan <- k + }, fakeClock) + c.Put(alphaKey, alphaValue) + currentTime = currentTime.Add(Timeout).Add(time.Nanosecond) + c.StartJanitor(time.Millisecond) + key := <-keyChan + c.StopJanitor() + assert.Equal(t, alphaKey, key) +} diff --git a/Makefile b/Makefile index 3c6a95c62b41..2c153aaf5af7 100644 --- a/Makefile +++ b/Makefile @@ -107,7 +107,7 @@ coverage: .PHONY: benchmark benchmark: - $(GODEP) go test -short -bench=. ./... + $(GODEP) go test -short -bench=. ./... -cpu=2 .PHONY: env env: env/bin/activate diff --git a/packetbeat.go b/packetbeat.go index bf0262e3370b..a35e958d9a83 100644 --- a/packetbeat.go +++ b/packetbeat.go @@ -55,7 +55,6 @@ type Packetbeat struct { CmdLineArgs CmdLineArgs Sniff *sniffer.SnifferSetup over chan bool - tcpProc *tcp.Tcp } type CmdLineArgs struct { @@ -154,7 +153,7 @@ func (pb *Packetbeat) Setup(b *beat.Beat) error { var err error - pb.tcpProc, err = tcp.NewTcp(&protos.Protos) + tcpProc, err := tcp.NewTcp(&protos.Protos) if err != nil { logp.Critical(err.Error()) os.Exit(1) @@ -180,7 +179,7 @@ func (pb *Packetbeat) Setup(b *beat.Beat) error { } logp.Debug("main", "Initializing sniffer") - err = pb.Sniff.Init(false, afterInputsQueue, pb.tcpProc, udpProc) + err = pb.Sniff.Init(false, afterInputsQueue, tcpProc, udpProc) if err != nil { logp.Critical("Initializing sniffer failed: %v", err) os.Exit(1) @@ -231,9 +230,9 @@ func (pb *Packetbeat) Run(b *beat.Beat) error { func (pb *Packetbeat) Cleanup(b *beat.Beat) error { if service.WithMemProfile() { - // wait for all TCP streams to expire - time.Sleep(tcp.TCP_STREAM_EXPIRY * 1.2) - pb.tcpProc.PrintTcpMap() + logp.Debug("main", "Waiting for streams and transactions to expire...") + time.Sleep(time.Duration(float64(protos.DefaultTransactionExpiration) * 1.2)) + logp.Debug("main", "Streams and transactions should all be expired now.") } return nil } diff --git a/pre-commit b/pre-commit index 23a93d4f0d94..111f5d10e51d 100755 --- a/pre-commit +++ b/pre-commit @@ -2,5 +2,5 @@ echo "Running pre-commit hook..." go fmt ./... go vet ./... -go test -short ./... +godep go test -short ./... flake8 tests/*.py tests/pbtests/*.py diff --git a/protos/dns/dns.go b/protos/dns/dns.go index 4dcc951fea6a..44666853ddd9 100644 --- a/protos/dns/dns.go +++ b/protos/dns/dns.go @@ -21,7 +21,6 @@ import ( "fmt" "net" "strings" - "sync" "time" "github.com/elastic/libbeat/common" @@ -35,11 +34,7 @@ import ( "github.com/tsg/gopacket/layers" ) -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeoutNanos = 10 * 1e9 - MaxDnsTupleRawSize = 16 + 16 + 2 + 2 + 4 + 1 -) +const MaxDnsTupleRawSize = 16 + 16 + 2 + 2 + 4 + 1 // Constants used to associate the DNS QR flag with a meaningful value. const ( @@ -197,8 +192,6 @@ type DnsTransaction struct { Request *DnsMessage Response *DnsMessage - - timer *time.Timer } func newTransaction(ts time.Time, tuple DnsTuple, cmd common.CmdlineTuple) *DnsTransaction { @@ -228,43 +221,33 @@ type Dns struct { Include_authorities bool Include_additionals bool - // Map of active DNS transactions. The map key is the HashableDnsTuple - // associated with the request. Use the put, lookup, and deleteTransaction - // methods to make map access concurrency-safe. - transactionsMap map[HashableDnsTuple]*DnsTransaction - transactionsMutex sync.Mutex + // Cache of active DNS transactions. The map key is the HashableDnsTuple + // associated with the request. + transactions *common.Cache results chan common.MapStr // Channel where results are pushed. } -// putTransaction puts a transaction into the transaction map. If the -// key already exists then the exiting entry will be overridden. The -// key should be the HashableDnsTuple associated with the request (src -// is the requestor). -func (dns *Dns) putTransaction(h HashableDnsTuple, trans *DnsTransaction) { - dns.transactionsMutex.Lock() - defer dns.transactionsMutex.Unlock() - dns.transactionsMap[h] = trans -} - -// lookupTransaction returns the transaction associated with the given +// getTransaction returns the transaction associated with the given // HashableDnsTuple. The lookup key should be the HashableDnsTuple associated // with the request (src is the requestor). Nil is returned if the entry // does not exist. -func (dns *Dns) lookupTransaction(h HashableDnsTuple) *DnsTransaction { - dns.transactionsMutex.Lock() - defer dns.transactionsMutex.Unlock() - return dns.transactionsMap[h] +func (dns *Dns) getTransaction(k HashableDnsTuple) *DnsTransaction { + v := dns.transactions.Get(k) + if v != nil { + return v.(*DnsTransaction) + } + return nil } // deleteTransaction deletes an entry from the transaction map and returns // the deleted element. If the key does not exist then nil is returned. -func (dns *Dns) deleteTransaction(h HashableDnsTuple) *DnsTransaction { - dns.transactionsMutex.Lock() - defer dns.transactionsMutex.Unlock() - t := dns.transactionsMap[h] - delete(dns.transactionsMap, h) - return t +func (dns *Dns) deleteTransaction(k HashableDnsTuple) *DnsTransaction { + v := dns.transactions.Delete(k) + if v != nil { + return v.(*DnsTransaction) + } + return nil } func (dns *Dns) initDefaults() { @@ -299,8 +282,9 @@ func (dns *Dns) Init(test_mode bool, results chan common.MapStr) error { dns.setFromConfig(config.ConfigSingleton.Protocols.Dns) } - dns.transactionsMap = make(map[HashableDnsTuple]*DnsTransaction, TransactionsHashSize) - dns.transactionsMutex = sync.Mutex{} + dns.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + dns.transactions.StartJanitor(protos.DefaultTransactionExpiration) dns.results = results @@ -356,17 +340,14 @@ func (dns *Dns) receivedDnsRequest(tuple *DnsTuple, msg *DnsMessage) { } trans = newTransaction(msg.Ts, *tuple, *msg.CmdlineTuple) - dns.putTransaction(tuple.Hashable(), trans) + dns.transactions.Put(tuple.Hashable(), trans) trans.Request = msg - - trans.timer = time.AfterFunc(TransactionTimeoutNanos, - func() { dns.expireTransaction(trans) }) } func (dns *Dns) receivedDnsResponse(tuple *DnsTuple, msg *DnsMessage) { logp.Debug("dns", "Processing response. %s", tuple) - trans := dns.lookupTransaction(tuple.RevHashable()) + trans := dns.getTransaction(tuple.RevHashable()) if trans == nil { trans = newTransaction(msg.Ts, tuple.Reverse(), common.CmdlineTuple{ Src: msg.CmdlineTuple.Dst, Dst: msg.CmdlineTuple.Src}) @@ -378,10 +359,6 @@ func (dns *Dns) receivedDnsResponse(tuple *DnsTuple, msg *DnsMessage) { dns.publishTransaction(trans) } -func (dns *Dns) expireTransaction(t *DnsTransaction) { - dns.deleteTransaction(t.tuple.Hashable()) -} - func (dns *Dns) publishTransaction(t *DnsTransaction) { if dns.results == nil { return @@ -390,10 +367,6 @@ func (dns *Dns) publishTransaction(t *DnsTransaction) { logp.Debug("dns", "Publishing transaction. %s", t.tuple.String()) dns.deleteTransaction(t.tuple.Hashable()) - if t.timer != nil { - t.timer.Stop() - } - event := common.MapStr{} event["timestamp"] = common.Time(t.ts) event["type"] = "dns" @@ -642,26 +615,18 @@ func dnsToString(dns *layers.DNS) string { } if len(dns.Answers) > 0 { - t = []string{} - for _, rr := range dns.Answers { - t = append(t, dnsResourceRecordToString(&rr)) - } - a = append(a, fmt.Sprintf("ANSWER %s", strings.Join(t, "; "))) + a = append(a, fmt.Sprintf("ANSWER %s", + dnsResourceRecordsToString(dns.Answers))) } if len(dns.Authorities) > 0 { - t = []string{} - for _, rr := range dns.Authorities { - t = append(t, dnsResourceRecordToString(&rr)) - } - a = append(a, fmt.Sprintf("AUTHORITY %s", strings.Join(t, "; "))) + a = append(a, fmt.Sprintf("AUTHORITY %s", + dnsResourceRecordsToString(dns.Authorities))) } if len(dns.Additionals) > 0 { - for _, rr := range dns.Additionals { - t = append(t, dnsResourceRecordToString(&rr)) - } - a = append(a, fmt.Sprintf("ADDITIONAL %s", strings.Join(t, "; "))) + a = append(a, fmt.Sprintf("ADDITIONAL %s", + dnsResourceRecordsToString(dns.Additionals))) } return strings.Join(a, "; ") diff --git a/protos/dns/dns_test.go b/protos/dns/dns_test.go index 8888e21ba234..99e38594f5b0 100644 --- a/protos/dns/dns_test.go +++ b/protos/dns/dns_test.go @@ -268,7 +268,7 @@ func TestParseUdp_emptyPacket(t *testing.T) { dns := newDns(testing.Verbose()) packet := newPacket(forward, []byte{}) dns.ParseUdp(packet) - assert.Empty(t, dns.transactionsMap, "There should be no transactions.") + assert.Empty(t, dns.transactions.Size(), "There should be no transactions.") close(dns.results) assert.Nil(t, <-dns.results, "No result should have been published.") } @@ -279,7 +279,7 @@ func TestParseUdp_malformedPacket(t *testing.T) { garbage := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13} packet := newPacket(forward, garbage) dns.ParseUdp(packet) - assert.Empty(t, dns.transactionsMap, "There should be no transactions.") + assert.Empty(t, dns.transactions.Size(), "There should be no transactions.") // As a future addition, a malformed message should publish a result. } @@ -289,7 +289,7 @@ func TestParseUdp_requestPacket(t *testing.T) { dns := newDns(testing.Verbose()) packet := newPacket(forward, elasticA.request) dns.ParseUdp(packet) - assert.Len(t, dns.transactionsMap, 1, "There should be one transaction.") + assert.Equal(t, 1, dns.transactions.Size(), "There should be one transaction.") close(dns.results) assert.Nil(t, <-dns.results, "No result should have been published.") } @@ -320,10 +320,10 @@ func TestParseUdp_duplicateRequests(t *testing.T) { q := elasticA packet := newPacket(forward, q.request) dns.ParseUdp(packet) - assert.Len(t, dns.transactionsMap, 1, "There should be one transaction.") + assert.Equal(t, 1, dns.transactions.Size(), "There should be one transaction.") packet = newPacket(forward, q.request) dns.ParseUdp(packet) - assert.Len(t, dns.transactionsMap, 1, "There should be one transaction.") + assert.Equal(t, 1, dns.transactions.Size(), "There should be one transaction.") m := expectResult(t, dns) assert.Equal(t, "udp", mapValue(t, m, "transport")) @@ -411,7 +411,7 @@ func parseUdpRequestResponse(t testing.TB, dns *Dns, q DnsTestMessage) { dns.ParseUdp(packet) packet = newPacket(reverse, q.response) dns.ParseUdp(packet) - assert.Empty(t, dns.transactionsMap, "There should be no transactions.") + assert.Empty(t, dns.transactions.Size(), "There should be no transactions.") m := expectResult(t, dns) assert.Equal(t, "udp", mapValue(t, m, "transport")) diff --git a/protos/http/http.go b/protos/http/http.go index 9cf6937c1b4a..450c7bbc5c33 100644 --- a/protos/http/http.go +++ b/protos/http/http.go @@ -103,8 +103,6 @@ type HttpTransaction struct { Request_raw string Response_raw string - - timer *time.Timer } type Http struct { @@ -120,11 +118,19 @@ type Http struct { Hide_keywords []string Strip_authorization bool - transactionsMap map[common.HashableTcpTuple]*HttpTransaction + transactions *common.Cache results chan common.MapStr } +func (http *Http) getTransaction(k common.HashableTcpTuple) *HttpTransaction { + v := http.transactions.Get(k) + if v != nil { + return v.(*HttpTransaction) + } + return nil +} + func (http *Http) InitDefaults() { http.Send_request = false http.Send_response = false @@ -175,11 +181,6 @@ func (http *Http) GetPorts() []int { return http.Ports } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - func (http *Http) Init(test_mode bool, results chan common.MapStr) error { http.InitDefaults() @@ -191,10 +192,9 @@ func (http *Http) Init(test_mode bool, results chan common.MapStr) error { } } - http.transactionsMap = make(map[common.HashableTcpTuple]*HttpTransaction, TransactionsHashSize) - - logp.Debug("http", "transactionsMap: %p http: %p", http.transactionsMap, &http) - + http.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + http.transactions.StartJanitor(protos.DefaultTransactionExpiration) http.results = results return nil @@ -730,15 +730,14 @@ func (http *Http) handleHttp(m *HttpMessage, tcptuple *common.TcpTuple, func (http *Http) receivedHttpRequest(msg *HttpMessage) { - trans := http.transactionsMap[msg.TcpTuple.Hashable()] + trans := http.getTransaction(msg.TcpTuple.Hashable()) if trans != nil { if len(trans.Http) != 0 { logp.Warn("Two requests without a response. Dropping old request") } } else { trans = &HttpTransaction{Type: "http", tuple: msg.TcpTuple} - logp.Debug("http", "transactionsMap %p http %p", http.transactionsMap, http) - http.transactionsMap[msg.TcpTuple.Hashable()] = trans + http.transactions.Put(msg.TcpTuple.Hashable(), trans) } logp.Debug("http", "Received request with tuple: %s", msg.TcpTuple) @@ -796,17 +795,6 @@ func (http *Http) receivedHttpRequest(msg *HttpMessage) { if err != nil { logp.Warn("http", "Fail to parse HTTP parameters: %v", err) } - - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { http.expireTransaction(trans) }) - -} - -func (http *Http) expireTransaction(trans *HttpTransaction) { - // remove from map - delete(http.transactionsMap, trans.tuple.Hashable()) } func (http *Http) receivedHttpResponse(msg *HttpMessage) { @@ -816,7 +804,7 @@ func (http *Http) receivedHttpResponse(msg *HttpMessage) { logp.Debug("http", "Received response with tuple: %s", tuple) - trans := http.transactionsMap[tuple.Hashable()] + trans := http.getTransaction(tuple.Hashable()) if trans == nil { logp.Warn("Response from unknown transaction. Ignoring: %v", tuple) return @@ -862,14 +850,9 @@ func (http *Http) receivedHttpResponse(msg *HttpMessage) { } http.publishTransaction(trans) + http.transactions.Delete(trans.tuple.Hashable()) logp.Debug("http", "HTTP transaction completed: %s\n", trans.Http) - - // remove from map - delete(http.transactionsMap, trans.tuple.Hashable()) - if trans.timer != nil { - trans.timer.Stop() - } } func (http *Http) publishTransaction(t *HttpTransaction) { diff --git a/protos/mongodb/mongodb.go b/protos/mongodb/mongodb.go index ddc074f20ffc..17d70c47169c 100644 --- a/protos/mongodb/mongodb.go +++ b/protos/mongodb/mongodb.go @@ -3,7 +3,6 @@ package mongodb import ( "fmt" "strings" - "time" "github.com/elastic/libbeat/common" "github.com/elastic/libbeat/logp" @@ -21,11 +20,19 @@ type Mongodb struct { Max_docs int Max_doc_length int - transactionsMap map[common.HashableTcpTuple]*MongodbTransaction + transactions *common.Cache results chan common.MapStr } +func (mongodb *Mongodb) getTransaction(k common.HashableTcpTuple) *MongodbTransaction { + v := mongodb.transactions.Get(k) + if v != nil { + return v.(*MongodbTransaction) + } + return nil +} + func (mongodb *Mongodb) InitDefaults() { mongodb.Send_request = false mongodb.Send_response = false @@ -66,7 +73,9 @@ func (mongodb *Mongodb) Init(test_mode bool, results chan common.MapStr) error { } } - mongodb.transactionsMap = make(map[common.HashableTcpTuple]*MongodbTransaction, TransactionsHashSize) + mongodb.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + mongodb.transactions.StartJanitor(protos.DefaultTransactionExpiration) mongodb.results = results return nil @@ -160,7 +169,7 @@ func (mongodb *Mongodb) receivedMongodbRequest(msg *MongodbMessage) { // Add it to the HT tuple := msg.TcpTuple - trans := mongodb.transactionsMap[tuple.Hashable()] + trans := mongodb.getTransaction(tuple.Hashable()) if trans != nil { if trans.Mongodb != nil { logp.Warn("Two requests without a Response. Dropping old request") @@ -168,7 +177,7 @@ func (mongodb *Mongodb) receivedMongodbRequest(msg *MongodbMessage) { } else { logp.Debug("mongodb", "Initialize new transaction from request") trans = &MongodbTransaction{Type: "mongodb", tuple: tuple} - mongodb.transactionsMap[tuple.Hashable()] = trans + mongodb.transactions.Put(tuple.Hashable(), trans) } trans.Mongodb = common.MapStr{} @@ -197,24 +206,11 @@ func (mongodb *Mongodb) receivedMongodbRequest(msg *MongodbMessage) { trans.params = msg.params trans.resource = msg.resource trans.BytesIn = msg.messageLength - - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { mongodb.expireTransaction(trans) }) - -} - -func (mongodb *Mongodb) expireTransaction(trans *MongodbTransaction) { - logp.Debug("mongodb", "Expire transaction") - // remove from map - delete(mongodb.transactionsMap, trans.tuple.Hashable()) } func (mongodb *Mongodb) receivedMongodbResponse(msg *MongodbMessage) { - tuple := msg.TcpTuple - trans := mongodb.transactionsMap[tuple.Hashable()] + trans := mongodb.getTransaction(msg.TcpTuple.Hashable()) if trans == nil { logp.Warn("Response from unknown transaction. Ignoring.") return @@ -238,14 +234,9 @@ func (mongodb *Mongodb) receivedMongodbResponse(msg *MongodbMessage) { trans.BytesOut = msg.messageLength mongodb.publishTransaction(trans) + mongodb.transactions.Delete(trans.tuple.Hashable()) logp.Debug("mongodb", "Mongodb transaction completed: %s", trans.Mongodb) - - // remove from map - delete(mongodb.transactionsMap, trans.tuple.Hashable()) - if trans.timer != nil { - trans.timer.Stop() - } } func (mongodb *Mongodb) GapInStream(tcptuple *common.TcpTuple, dir uint8, diff --git a/protos/mongodb/mongodb_structs.go b/protos/mongodb/mongodb_structs.go index e945d0628fe4..142505c56df0 100644 --- a/protos/mongodb/mongodb_structs.go +++ b/protos/mongodb/mongodb_structs.go @@ -73,7 +73,6 @@ type MongodbTransaction struct { ts time.Time BytesOut int BytesIn int - timer *time.Timer Mongodb common.MapStr @@ -85,11 +84,6 @@ type MongodbTransaction struct { documents []interface{} } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - // List of valid mongodb wire protocol operation codes // see http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#request-opcodes var OpCodes = map[int]string{ diff --git a/protos/mysql/mysql.go b/protos/mysql/mysql.go index b77aa4a43170..da07bb6657f8 100644 --- a/protos/mysql/mysql.go +++ b/protos/mysql/mysql.go @@ -74,8 +74,6 @@ type MysqlTransaction struct { Request_raw string Response_raw string - - timer *time.Timer } type MysqlStream struct { @@ -90,11 +88,6 @@ type MysqlStream struct { message *MysqlMessage } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - type parseState int const ( @@ -126,7 +119,7 @@ type Mysql struct { Send_request bool Send_response bool - transactionsMap map[common.HashableTcpTuple]*MysqlTransaction + transactions *common.Cache results chan common.MapStr @@ -135,6 +128,14 @@ type Mysql struct { dir uint8, raw_msg []byte) } +func (mysql *Mysql) getTransaction(k common.HashableTcpTuple) *MysqlTransaction { + v := mysql.transactions.Get(k) + if v != nil { + return v.(*MysqlTransaction) + } + return nil +} + func (mysql *Mysql) InitDefaults() { mysql.maxRowLength = 1024 mysql.maxStoreRows = 10 @@ -175,7 +176,9 @@ func (mysql *Mysql) Init(test_mode bool, results chan common.MapStr) error { } } - mysql.transactionsMap = make(map[common.HashableTcpTuple]*MysqlTransaction, TransactionsHashSize) + mysql.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + mysql.transactions.StartJanitor(protos.DefaultTransactionExpiration) mysql.handleMysql = handleMysql mysql.results = results @@ -573,18 +576,15 @@ func handleMysql(mysql *Mysql, m *MysqlMessage, tcptuple *common.TcpTuple, } func (mysql *Mysql) receivedMysqlRequest(msg *MysqlMessage) { - - // Add it to the HT tuple := msg.TcpTuple - - trans := mysql.transactionsMap[tuple.Hashable()] + trans := mysql.getTransaction(tuple.Hashable()) if trans != nil { if trans.Mysql != nil { logp.Debug("mysql", "Two requests without a Response. Dropping old request: %s", trans.Mysql) } } else { trans = &MysqlTransaction{Type: "mysql", tuple: tuple} - mysql.transactionsMap[tuple.Hashable()] = trans + mysql.transactions.Put(tuple.Hashable(), trans) } trans.ts = msg.Ts @@ -625,16 +625,10 @@ func (mysql *Mysql) receivedMysqlRequest(msg *MysqlMessage) { // save Raw message trans.Request_raw = msg.Query trans.BytesIn = msg.Size - - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { mysql.expireTransaction(trans) }) } func (mysql *Mysql) receivedMysqlResponse(msg *MysqlMessage) { - tuple := msg.TcpTuple - trans := mysql.transactionsMap[tuple.Hashable()] + trans := mysql.getTransaction(msg.TcpTuple.Hashable()) if trans == nil { logp.Warn("Response from unknown transaction. Ignoring.") return @@ -670,23 +664,10 @@ func (mysql *Mysql) receivedMysqlResponse(msg *MysqlMessage) { trans.Notes = append(trans.Notes, msg.Notes...) mysql.publishTransaction(trans) + mysql.transactions.Delete(trans.tuple.Hashable()) logp.Debug("mysql", "Mysql transaction completed: %s", trans.Mysql) logp.Debug("mysql", "%s", trans.Response_raw) - - trans.Notes = append(trans.Notes, msg.Notes...) - - // remove from map - delete(mysql.transactionsMap, trans.tuple.Hashable()) - if trans.timer != nil { - trans.timer.Stop() - } -} - -func (mysql *Mysql) expireTransaction(trans *MysqlTransaction) { - // TODO: Here we need to PUBLISH an incomplete/timeout transaction - // remove from map - delete(mysql.transactionsMap, trans.tuple.Hashable()) } func (mysql *Mysql) parseMysqlResponse(data []byte) ([]string, [][]string) { diff --git a/protos/pgsql/pgsql.go b/protos/pgsql/pgsql.go index 14ccc3ec8bbb..87e2672d4815 100644 --- a/protos/pgsql/pgsql.go +++ b/protos/pgsql/pgsql.go @@ -60,8 +60,6 @@ type PgsqlTransaction struct { Request_raw string Response_raw string - - timer *time.Timer } type PgsqlStream struct { @@ -77,11 +75,6 @@ type PgsqlStream struct { message *PgsqlMessage } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - const ( PgsqlStartState = iota PgsqlGetDataState @@ -102,14 +95,22 @@ type Pgsql struct { Send_request bool Send_response bool - transactionsMap map[common.HashableTcpTuple][]*PgsqlTransaction - results chan common.MapStr + transactions *common.Cache + results chan common.MapStr // function pointer for mocking handlePgsql func(pgsql *Pgsql, m *PgsqlMessage, tcp *common.TcpTuple, dir uint8, raw_msg []byte) } +func (pgsql *Pgsql) getTransaction(k common.HashableTcpTuple) []*PgsqlTransaction { + v := pgsql.transactions.Get(k) + if v != nil { + return v.([]*PgsqlTransaction) + } + return nil +} + func (pgsql *Pgsql) InitDefaults() { pgsql.maxRowLength = 1024 pgsql.maxStoreRows = 10 @@ -150,7 +151,9 @@ func (pgsql *Pgsql) Init(test_mode bool, results chan common.MapStr) error { } } - pgsql.transactionsMap = make(map[common.HashableTcpTuple][]*PgsqlTransaction, TransactionsHashSize) + pgsql.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + pgsql.transactions.StartJanitor(protos.DefaultTransactionExpiration) pgsql.handlePgsql = handlePgsql pgsql.results = results @@ -812,8 +815,9 @@ func (pgsql *Pgsql) receivedPgsqlRequest(msg *PgsqlMessage) { logp.Debug("pgsqldetailed", "Queries (%d) :%s", len(queries), queries) - if pgsql.transactionsMap[tuple.Hashable()] == nil { - pgsql.transactionsMap[tuple.Hashable()] = []*PgsqlTransaction{} + transList := pgsql.getTransaction(tuple.Hashable()) + if transList == nil { + transList = []*PgsqlTransaction{} } for _, query := range queries { @@ -846,27 +850,22 @@ func (pgsql *Pgsql) receivedPgsqlRequest(msg *PgsqlMessage) { trans.Request_raw = query - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { pgsql.expireTransaction(trans) }) - - pgsql.transactionsMap[tuple.Hashable()] = append(pgsql.transactionsMap[tuple.Hashable()], trans) + transList = append(transList, trans) } + pgsql.transactions.Put(tuple.Hashable(), transList) } func (pgsql *Pgsql) receivedPgsqlResponse(msg *PgsqlMessage) { tuple := msg.TcpTuple - trans_list := pgsql.transactionsMap[tuple.Hashable()] - - if trans_list == nil || len(trans_list) == 0 { + transList := pgsql.getTransaction(tuple.Hashable()) + if transList == nil || len(transList) == 0 { logp.Warn("Response from unknown transaction. Ignoring.") return } // extract the first transaction from the array - trans := pgsql.removeTransaction(tuple, 0) + trans := pgsql.removeTransaction(transList, tuple, 0) // check if the request was received if trans.Pgsql == nil { @@ -892,10 +891,6 @@ func (pgsql *Pgsql) receivedPgsqlResponse(msg *PgsqlMessage) { pgsql.publishTransaction(trans) logp.Debug("pgsql", "Postgres transaction completed: %s\n%s", trans.Pgsql, trans.Response_raw) - - if trans.timer != nil { - trans.timer.Stop() - } } func (pgsql *Pgsql) publishTransaction(t *PgsqlTransaction) { @@ -936,29 +931,15 @@ func (pgsql *Pgsql) publishTransaction(t *PgsqlTransaction) { pgsql.results <- event } -func (pgsql *Pgsql) expireTransaction(trans *PgsqlTransaction) { - // TODO: Here we need to PUBLISH an incomplete/timeout transaction - // remove from map - for i, t := range pgsql.transactionsMap[trans.tuple.Hashable()] { - if t == trans { - pgsql.removeTransaction(trans.tuple, i) - break - } - } - if len(pgsql.transactionsMap[trans.tuple.Hashable()]) == 0 { - delete(pgsql.transactionsMap, trans.tuple.Hashable()) - } -} - -func (pgsql *Pgsql) removeTransaction(tuple common.TcpTuple, index int) *PgsqlTransaction { +func (pgsql *Pgsql) removeTransaction(transList []*PgsqlTransaction, + tuple common.TcpTuple, index int) *PgsqlTransaction { - trans_list := pgsql.transactionsMap[tuple.Hashable()] - trans := trans_list[index] - trans_list = append(trans_list[:index], trans_list[index+1:]...) - if len(trans_list) == 0 { - delete(pgsql.transactionsMap, trans.tuple.Hashable()) + trans := transList[index] + transList = append(transList[:index], transList[index+1:]...) + if len(transList) == 0 { + pgsql.transactions.Delete(trans.tuple.Hashable()) } else { - pgsql.transactionsMap[tuple.Hashable()] = trans_list + pgsql.transactions.Put(tuple.Hashable(), transList) } return trans diff --git a/protos/protos.go b/protos/protos.go index dec7e2d1b5c6..b2c232bfc502 100644 --- a/protos/protos.go +++ b/protos/protos.go @@ -11,6 +11,11 @@ import ( "github.com/elastic/libbeat/logp" ) +const ( + DefaultTransactionHashSize = 2 ^ 16 + DefaultTransactionExpiration time.Duration = 10 * time.Second +) + // ProtocolData interface to represent an upper // protocol private data. Used with types like // HttpStream, MysqlStream, etc. diff --git a/protos/redis/redis.go b/protos/redis/redis.go index 334b93bff6f1..08f15a341beb 100644 --- a/protos/redis/redis.go +++ b/protos/redis/redis.go @@ -74,8 +74,6 @@ type RedisTransaction struct { Request_raw string Response_raw string - - timer *time.Timer } // Keep sorted for future command addition @@ -240,22 +238,25 @@ var RedisCommands = map[string]struct{}{ "ZUNIONSTORE": struct{}{}, } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - type Redis struct { // config Ports []int Send_request bool Send_response bool - transactionsMap map[common.HashableTcpTuple]*RedisTransaction + transactions *common.Cache results chan common.MapStr } +func (redis *Redis) getTransaction(k common.HashableTcpTuple) *RedisTransaction { + v := redis.transactions.Get(k) + if v != nil { + return v.(*RedisTransaction) + } + return nil +} + func (redis *Redis) InitDefaults() { redis.Send_request = false redis.Send_response = false @@ -284,7 +285,9 @@ func (redis *Redis) Init(test_mode bool, results chan common.MapStr) error { redis.setFromConfig(config.ConfigSingleton.Protocols.Redis) } - redis.transactionsMap = make(map[common.HashableTcpTuple]*RedisTransaction, TransactionsHashSize) + redis.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + redis.transactions.StartJanitor(protos.DefaultTransactionExpiration) redis.results = results return nil @@ -592,17 +595,15 @@ func (redis *Redis) handleRedis(m *RedisMessage, tcptuple *common.TcpTuple, } func (redis *Redis) receivedRedisRequest(msg *RedisMessage) { - // Add it to the HT tuple := msg.TcpTuple - - trans := redis.transactionsMap[tuple.Hashable()] + trans := redis.getTransaction(tuple.Hashable()) if trans != nil { if trans.Redis != nil { logp.Warn("Two requests without a Response. Dropping old request") } } else { trans = &RedisTransaction{Type: "redis", tuple: tuple} - redis.transactionsMap[tuple.Hashable()] = trans + redis.transactions.Put(tuple.Hashable(), trans) } trans.Redis = common.MapStr{} @@ -629,24 +630,11 @@ func (redis *Redis) receivedRedisRequest(msg *RedisMessage) { if msg.Direction == tcp.TcpDirectionReverse { trans.Src, trans.Dst = trans.Dst, trans.Src } - - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { redis.expireTransaction(trans) }) - -} - -func (redis *Redis) expireTransaction(trans *RedisTransaction) { - - // remove from map - delete(redis.transactionsMap, trans.tuple.Hashable()) } func (redis *Redis) receivedRedisResponse(msg *RedisMessage) { - tuple := msg.TcpTuple - trans := redis.transactionsMap[tuple.Hashable()] + trans := redis.getTransaction(tuple.Hashable()) if trans == nil { logp.Warn("Response from unknown transaction. Ignoring.") return @@ -671,15 +659,9 @@ func (redis *Redis) receivedRedisResponse(msg *RedisMessage) { trans.ResponseTime = int32(msg.Ts.Sub(trans.ts).Nanoseconds() / 1e6) // resp_time in milliseconds redis.publishTransaction(trans) + redis.transactions.Delete(trans.tuple.Hashable()) logp.Debug("redis", "Redis transaction completed: %s", trans.Redis) - - // remove from map - delete(redis.transactionsMap, trans.tuple.Hashable()) - if trans.timer != nil { - trans.timer.Stop() - } - } func (redis *Redis) GapInStream(tcptuple *common.TcpTuple, dir uint8, diff --git a/protos/tcp/tcp.go b/protos/tcp/tcp.go index 47ad2fbdb2df..f6ef1ae8fa80 100644 --- a/protos/tcp/tcp.go +++ b/protos/tcp/tcp.go @@ -2,7 +2,6 @@ package tcp import ( "fmt" - "time" "github.com/elastic/libbeat/common" "github.com/elastic/libbeat/logp" @@ -12,8 +11,6 @@ import ( "github.com/tsg/gopacket/layers" ) -const TCP_STREAM_EXPIRY = 10 * 1e9 -const TCP_STREAM_HASH_SIZE = 2 ^ 16 const TCP_MAX_DATA_IN_STREAM = 10 * 1e6 const ( @@ -22,10 +19,10 @@ const ( ) type Tcp struct { - id uint32 - streamsMap map[common.HashableIpPortTuple]*TcpStream - portMap map[uint16]protos.Protocol - protocols protos.Protocols + id uint32 + streams *common.Cache + portMap map[uint16]protos.Protocol + protocols protos.Protocols } type Processor interface { @@ -51,10 +48,17 @@ func (tcp *Tcp) decideProtocol(tuple *common.IpPortTuple) protos.Protocol { return protos.UnknownProtocol } +func (tcp *Tcp) getStream(k common.HashableIpPortTuple) *TcpStream { + v := tcp.streams.Get(k) + if v != nil { + return v.(*TcpStream) + } + return nil +} + type TcpStream struct { id uint32 tuple *common.IpPortTuple - timer *time.Timer protocol protos.Protocol tcptuple common.TcpTuple tcp *Tcp @@ -62,54 +66,42 @@ type TcpStream struct { lastSeq [2]uint32 // protocols private data - Data protos.ProtocolData + data protos.ProtocolData } -func (stream *TcpStream) AddPacket(pkt *protos.Packet, tcphdr *layers.TCP, original_dir uint8) { - - // create/reset timer - if stream.timer != nil { - stream.timer.Stop() - } - stream.timer = time.AfterFunc(TCP_STREAM_EXPIRY, func() { stream.Expire() }) +func (stream *TcpStream) String() string { + return fmt.Sprintf("TcpStream id[%d] tuple[%s] protocol[%s] lastSeq[%d %d]", + stream.id, stream.tuple, stream.protocol, stream.lastSeq[0], stream.lastSeq[1]) +} +func (stream *TcpStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP, original_dir uint8) { mod := stream.tcp.protocols.GetTcp(stream.protocol) if mod == nil { - logp.Debug("tcp", "Ignoring protocol for which we have no module loaded: %s", stream.protocol) + logp.Debug("tcp", "Ignoring protocol for which we have no module "+ + "loaded: %s", stream.protocol) return } if len(pkt.Payload) > 0 { - stream.Data = mod.Parse(pkt, &stream.tcptuple, original_dir, stream.Data) + stream.data = mod.Parse(pkt, &stream.tcptuple, original_dir, stream.data) } if tcphdr.FIN { - stream.Data = mod.ReceivedFin(&stream.tcptuple, original_dir, stream.Data) + stream.data = mod.ReceivedFin(&stream.tcptuple, original_dir, stream.data) } } -func (stream *TcpStream) GapInStream(original_dir uint8, nbytes int) (drop bool) { +func (stream *TcpStream) gapInStream(original_dir uint8, nbytes int) (drop bool) { mod := stream.tcp.protocols.GetTcp(stream.protocol) - stream.Data, drop = mod.GapInStream(&stream.tcptuple, original_dir, nbytes, stream.Data) + stream.data, drop = mod.GapInStream(&stream.tcptuple, original_dir, nbytes, stream.data) return drop } -func (stream *TcpStream) Expire() { - - logp.Debug("mem", "Tcp stream expired") - - // de-register from dict - delete(stream.tcp.streamsMap, stream.tuple.Hashable()) - - // nullify to help the GC - stream.Data = nil -} - -func TcpSeqBefore(seq1 uint32, seq2 uint32) bool { +func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) < 0 } -func TcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { +func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) <= 0 } @@ -119,12 +111,12 @@ func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { // protocol modules. defer logp.Recover("Process tcp exception") - stream, exists := tcp.streamsMap[pkt.Tuple.Hashable()] + stream := tcp.getStream(pkt.Tuple.Hashable()) var original_dir uint8 = TcpDirectionOriginal created := false - if !exists { - stream, exists = tcp.streamsMap[pkt.Tuple.RevHashable()] - if !exists { + if stream == nil { + stream = tcp.getStream(pkt.Tuple.RevHashable()) + if stream == nil { protocol := tcp.decideProtocol(&pkt.Tuple) if protocol == protos.UnknownProtocol { // don't follow @@ -135,7 +127,7 @@ func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { // create stream = &TcpStream{id: tcp.getId(), tuple: &pkt.Tuple, protocol: protocol, tcp: tcp} stream.tcptuple = common.TcpTupleFromIpPort(stream.tuple, stream.id) - tcp.streamsMap[pkt.Tuple.Hashable()] = stream + tcp.streams.Put(pkt.Tuple.Hashable(), stream) created = true } else { original_dir = TcpDirectionReverse @@ -150,38 +142,28 @@ func (tcp *Tcp) Process(tcphdr *layers.TCP, pkt *protos.Packet) { if len(pkt.Payload) > 0 && stream.lastSeq[original_dir] != 0 { - if TcpSeqBeforeEq(tcp_seq, stream.lastSeq[original_dir]) { + if tcpSeqBeforeEq(tcp_seq, stream.lastSeq[original_dir]) { logp.Debug("tcp", "Ignoring what looks like a retrasmitted segment. pkt.seq=%v len=%v stream.seq=%v", tcphdr.Seq, len(pkt.Payload), stream.lastSeq[original_dir]) return } - if TcpSeqBefore(stream.lastSeq[original_dir], tcp_start_seq) { + if tcpSeqBefore(stream.lastSeq[original_dir], tcp_start_seq) { if !created { logp.Debug("tcp", "Gap in tcp stream. last_seq: %d, seq: %d", stream.lastSeq[original_dir], tcp_start_seq) - drop := stream.GapInStream(original_dir, + drop := stream.gapInStream(original_dir, int(tcp_start_seq-stream.lastSeq[original_dir])) if drop { logp.Debug("tcp", "Dropping stream because of gap") - stream.Expire() + tcp.streams.Delete(stream.tuple.Hashable()) } } } } stream.lastSeq[original_dir] = tcp_seq - stream.AddPacket(pkt, tcphdr, original_dir) -} - -func (tcp *Tcp) PrintTcpMap() { - fmt.Printf("Streams in memory:") - for _, stream := range tcp.streamsMap { - fmt.Printf(" %d", stream.id) - } - fmt.Printf("\n") - - fmt.Printf("Streams dict: %v", tcp.streamsMap) + stream.addPacket(pkt, tcphdr, original_dir) } func buildPortsMap(plugins map[protos.Protocol]protos.TcpProtocolPlugin) (map[uint16]protos.Protocol, error) { @@ -211,8 +193,13 @@ func NewTcp(p protos.Protocols) (*Tcp, error) { return nil, err } - tcp := &Tcp{protocols: p, portMap: portMap} - tcp.streamsMap = make(map[common.HashableIpPortTuple]*TcpStream, TCP_STREAM_HASH_SIZE) + tcp := &Tcp{ + protocols: p, + portMap: portMap, + streams: common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize), + } + tcp.streams.StartJanitor(protos.DefaultTransactionExpiration) logp.Debug("tcp", "Port map: %v", portMap) return tcp, nil diff --git a/protos/tcp/tcp_test.go b/protos/tcp/tcp_test.go index 18e8c9e8b6a4..4b75d3b6559f 100644 --- a/protos/tcp/tcp_test.go +++ b/protos/tcp/tcp_test.go @@ -1,37 +1,50 @@ package tcp import ( + "math/rand" + "net" "testing" + "time" "github.com/elastic/libbeat/common" "github.com/elastic/packetbeat/protos" "github.com/stretchr/testify/assert" + "github.com/tsg/gopacket/layers" +) + +// Test Constants +const ( + ServerIp = "192.168.0.1" + ServerPort = 12345 + ClientIp = "10.0.0.1" ) type TestProtocol struct { Ports []int } -func (proto *TestProtocol) Init(test_mode bool, results chan common.MapStr) error { +var _ protos.ProtocolPlugin = &TestProtocol{} + +func (proto TestProtocol) Init(test_mode bool, results chan common.MapStr) error { return nil } -func (proto *TestProtocol) GetPorts() []int { +func (proto TestProtocol) GetPorts() []int { return proto.Ports } -func (proto *TestProtocol) Parse(pkt *protos.Packet, tcptuple *common.TcpTuple, +func (proto TestProtocol) Parse(pkt *protos.Packet, tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { return private } -func (proto *TestProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, +func (proto TestProtocol) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { return private } -func (proto *TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, +func (proto TestProtocol) GapInStream(tcptuple *common.TcpTuple, dir uint8, nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) { return private, true } @@ -114,3 +127,44 @@ func Test_configToPortsMap_negative(t *testing.T) { assert.Contains(t, err.Error(), test.Err) } } + +// Mock protos.Protocols used for testing the tcp package. +type protocols struct { + tcp map[protos.Protocol]protos.TcpProtocolPlugin +} + +// Verify protocols implements the protos.Protocols interface. +var _ protos.Protocols = &protocols{} + +func (p protocols) BpfFilter(with_vlans bool) string { return "" } +func (p protocols) GetTcp(proto protos.Protocol) protos.TcpProtocolPlugin { return p.tcp[proto] } +func (p protocols) GetUdp(proto protos.Protocol) protos.UdpProtocolPlugin { return nil } +func (p protocols) GetAll() map[protos.Protocol]protos.ProtocolPlugin { return nil } +func (p protocols) GetAllTcp() map[protos.Protocol]protos.TcpProtocolPlugin { return p.tcp } +func (p protocols) GetAllUdp() map[protos.Protocol]protos.UdpProtocolPlugin { return nil } +func (p protocols) Register(proto protos.Protocol, plugin protos.ProtocolPlugin) { return } + +// Benchmark that runs with parallelism to help find concurrency related +// issues. To run with parallelism, the 'go test' cpu flag must be set +// greater than 1, otherwise it just runs concurrently but not in parallel. +func BenchmarkParallelProcess(b *testing.B) { + rand.Seed(18) + p := protocols{} + p.tcp = make(map[protos.Protocol]protos.TcpProtocolPlugin) + p.tcp[1] = TestProtocol{Ports: []int{ServerPort}} + tcp, _ := NewTcp(p) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + pkt := &protos.Packet{ + Ts: time.Now(), + Tuple: common.NewIpPortTuple(4, + net.ParseIP(ServerIp), ServerPort, + net.ParseIP(ClientIp), uint16(rand.Intn(65535))), + Payload: []byte{1, 2, 3, 4}, + } + tcp.Process(&layers.TCP{}, pkt) + } + }) +} diff --git a/protos/thrift/thrift.go b/protos/thrift/thrift.go index d633a2e7b6c2..b270968eb621 100644 --- a/protos/thrift/thrift.go +++ b/protos/thrift/thrift.go @@ -81,15 +81,8 @@ type ThriftTransaction struct { Request *ThriftMessage Reply *ThriftMessage - - timer *time.Timer } -const ( - TransactionsHashSize = 2 ^ 16 - TransactionTimeout = 10 * 1e9 -) - const ( ThriftStartState = iota ThriftFieldState @@ -156,7 +149,7 @@ type Thrift struct { TransportType byte ProtocolType byte - transMap map[common.HashableTcpTuple]*ThriftTransaction + transactions *common.Cache PublishQueue chan *ThriftTransaction results chan common.MapStr @@ -165,6 +158,14 @@ type Thrift struct { var ThriftMod Thrift +func (thrift *Thrift) getTransaction(k common.HashableTcpTuple) *ThriftTransaction { + v := thrift.transactions.Get(k) + if v != nil { + return v.(*ThriftTransaction) + } + return nil +} + func (thrift *Thrift) InitDefaults() { // defaults thrift.StringMaxSize = 200 @@ -248,7 +249,9 @@ func (thrift *Thrift) Init(test_mode bool, results chan common.MapStr) error { } } - thrift.transMap = make(map[common.HashableTcpTuple]*ThriftTransaction, TransactionsHashSize) + thrift.transactions = common.NewCache(protos.DefaultTransactionExpiration, + protos.DefaultTransactionHashSize) + thrift.transactions.StartJanitor(protos.DefaultTransactionExpiration) if !test_mode { thrift.PublishQueue = make(chan *ThriftTransaction, 1000) @@ -954,7 +957,7 @@ func (thrift *Thrift) handleThrift(msg *ThriftMessage) { func (thrift *Thrift) receivedRequest(msg *ThriftMessage) { tuple := msg.TcpTuple - trans := thrift.transMap[tuple.Hashable()] + trans := thrift.getTransaction(tuple.Hashable()) if trans != nil { logp.Debug("thrift", "Two requests without reply, assuming the old one is oneway") thrift.PublishQueue <- trans @@ -964,7 +967,7 @@ func (thrift *Thrift) receivedRequest(msg *ThriftMessage) { Type: "thrift", tuple: tuple, } - thrift.transMap[tuple.Hashable()] = trans + thrift.transactions.Put(tuple.Hashable(), trans) trans.ts = msg.Ts trans.Ts = int64(trans.ts.UnixNano() / 1000) @@ -985,12 +988,6 @@ func (thrift *Thrift) receivedRequest(msg *ThriftMessage) { trans.Request = msg trans.BytesIn = uint64(msg.FrameSize) - - if trans.timer != nil { - trans.timer.Stop() - } - trans.timer = time.AfterFunc(TransactionTimeout, func() { thrift.expireTransaction(trans) }) - } func (thrift *Thrift) receivedReply(msg *ThriftMessage) { @@ -998,7 +995,7 @@ func (thrift *Thrift) receivedReply(msg *ThriftMessage) { // we need to search the request first. tuple := msg.TcpTuple - trans := thrift.transMap[tuple.Hashable()] + trans := thrift.getTransaction(tuple.Hashable()) if trans == nil { logp.Debug("thrift", "Response from unknown transaction. Ignoring: %v", tuple) return @@ -1016,28 +1013,20 @@ func (thrift *Thrift) receivedReply(msg *ThriftMessage) { trans.ResponseTime = int32(msg.Ts.Sub(trans.ts).Nanoseconds() / 1e6) // resp_time in milliseconds thrift.PublishQueue <- trans + thrift.transactions.Delete(tuple.Hashable()) logp.Debug("thrift", "Transaction queued") - - // remove from map - thrift.transMap[tuple.Hashable()] = nil - if trans.timer != nil { - trans.timer.Stop() - } } func (thrift *Thrift) ReceivedFin(tcptuple *common.TcpTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { - trans := thrift.transMap[tcptuple.Hashable()] + trans := thrift.getTransaction(tcptuple.Hashable()) if trans != nil { if trans.Request != nil && trans.Reply == nil { logp.Debug("thrift", "FIN and had only one transaction. Assuming one way") thrift.PublishQueue <- trans - delete(thrift.transMap, trans.tuple.Hashable()) - if trans.timer != nil { - trans.timer.Stop() - } + thrift.transactions.Delete(trans.tuple.Hashable()) } } @@ -1140,9 +1129,3 @@ func (thrift *Thrift) publishTransactions() { logp.Debug("thrift", "Published event") } } - -func (thrift *Thrift) expireTransaction(trans *ThriftTransaction) { - // TODO - also publish? - // remove from map - delete(thrift.transMap, trans.tuple.Hashable()) -}