Skip to content

Commit

Permalink
Merge pull request #240 from projectdiscovery/introduce_func_trace
Browse files Browse the repository at this point in the history
introduce func trace
  • Loading branch information
Mzack9999 authored Sep 11, 2023
2 parents bf26936 + 1f98168 commit 27c68d9
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 0 deletions.
177 changes: 177 additions & 0 deletions trace/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package trace

import (
"errors"
"math"
"runtime"
"sync"
"time"

"github.com/projectdiscovery/utils/generic"
)

const (
// DefaultMemorySnapshotInterval is the default interval for taking memory snapshots
DefaultMemorySnapshotInterval = 100 * time.Millisecond
)

type MemorySnapshot struct {
Time time.Time
Alloc uint64
}

type Metrics struct {
StartTime time.Time
FinishTime time.Time
ExecutionDuration time.Duration
Snapshots []MemorySnapshot
MinAllocMemory uint64
MaxAllocMemory uint64
AvgAllocMemory uint64
}

type FunctionContext struct {
strategy ActionStrategy
action func()
before func()
after func()
}

func (f *FunctionContext) Execute() {
if f.before != nil {
f.before()
}

f.strategy.Before()
f.action()
f.strategy.After()

if f.after != nil {
f.after()
}
}

type ActionStrategy interface {
Before()
After()
GetMetrics() *Metrics
}

type DefaultStrategy struct {
metrics generic.Lockable[*Metrics]
ticker *time.Ticker
done chan bool
wg sync.WaitGroup
}

func (d *DefaultStrategy) Before() {
d.metrics.Do(func(m *Metrics) {
m.StartTime = time.Now()

d.ticker = time.NewTicker(DefaultMemorySnapshotInterval)
d.done = make(chan bool)
d.wg.Add(1)
go func() {
defer d.wg.Done()
for {
select {
case <-d.done:
return
case t := <-d.ticker.C:
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
m.Snapshots = append(m.Snapshots, MemorySnapshot{
Time: t,
Alloc: mem.Alloc,
})
}
}
}()
})
}

func (d *DefaultStrategy) After() {
close(d.done)
d.wg.Wait()
d.ticker.Stop()
d.metrics.Do(func(m *Metrics) {
m.FinishTime = time.Now()
m.ExecutionDuration = m.FinishTime.Sub(m.StartTime)

var totalMemory uint64 = 0
if len(m.Snapshots) > 0 {
m.MinAllocMemory = m.Snapshots[0].Alloc
m.MaxAllocMemory = m.Snapshots[0].Alloc

for _, s := range m.Snapshots {
if s.Alloc < m.MinAllocMemory {
m.MinAllocMemory = s.Alloc
}
m.MinAllocMemory = uint64(math.Min(float64(m.MinAllocMemory), float64(s.Alloc)))
m.MaxAllocMemory = uint64(math.Max(float64(m.MaxAllocMemory), float64(s.Alloc)))
totalMemory += s.Alloc
}
m.AvgAllocMemory = totalMemory / uint64(len(m.Snapshots))
}
})

}

func (d *DefaultStrategy) GetMetrics() *Metrics {
var metrics *Metrics
d.metrics.Do(func(m *Metrics) {
metrics = m
})
return metrics
}

type TraceOptions struct {
strategy ActionStrategy
before func()
after func()
}

type TraceOptionSetter func(opts *TraceOptions)

func WithStrategy(s ActionStrategy) TraceOptionSetter {
return func(opts *TraceOptions) {
opts.strategy = s
}
}

func WithBefore(b func()) TraceOptionSetter {
return func(opts *TraceOptions) {
opts.before = b
}
}

func WithAfter(a func()) TraceOptionSetter {
return func(opts *TraceOptions) {
opts.after = a
}
}

func Trace(f func(), setters ...TraceOptionSetter) (*Metrics, error) {
opts := &TraceOptions{
strategy: &DefaultStrategy{metrics: generic.Lockable[*Metrics]{V: &Metrics{}}},
}

// Apply option if provided
for _, setter := range setters {
setter(opts)
}

if opts.strategy == nil {
return nil, errors.New("strategy should not be nil")
}

context := &FunctionContext{
strategy: opts.strategy,
action: f,
before: opts.before,
after: opts.after,
}

context.Execute()
return opts.strategy.GetMetrics(), nil
}
109 changes: 109 additions & 0 deletions trace/trace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package trace

import (
"testing"
"time"
)

func TestFunctionWithBeforeFunction(t *testing.T) {
var beforeCalled bool
_, _ = Trace(func() {
if !beforeCalled {
t.Errorf("Before function was not called before the main function")
}
}, WithBefore(func() {
beforeCalled = true
}))

if !beforeCalled {
t.Errorf("Before function was not called")
}
}

func TestFunctionWithAfterFunction(t *testing.T) {
var afterCalled bool
_, _ = Trace(func() {
if afterCalled {
t.Errorf("After function was called before the main function finished")
}
}, WithAfter(func() {
afterCalled = true
}))

if !afterCalled {
t.Errorf("After function was not called")
}
}

func TestFunctionTracing(t *testing.T) {
metrics, _ := Trace(func() {
time.Sleep(2 * time.Second)
})

if metrics.ExecutionDuration.Seconds() < 2 {
t.Errorf("ExecutionDuration is less than expected: %v", metrics.ExecutionDuration)
}

if len(metrics.Snapshots) == 0 {
t.Errorf("Memory snapshots are not captured")
}

if metrics.MinAllocMemory == 0 {
t.Errorf("MinMemory not computed")
}

if metrics.MaxAllocMemory == 0 {
t.Errorf("MaxMemory not computed")
}

if metrics.AvgAllocMemory == 0 {
t.Errorf("AvgMemory not computed")
}
}

func TestFunctionWithCustomStrategy(t *testing.T) {
var customLogs []string
metrics, _ := Trace(func() {
time.Sleep(1 * time.Second)
}, WithStrategy(&CustomStrategy{metrics: &Metrics{}, logs: &customLogs}))

if len(customLogs) != 2 {
t.Errorf("Custom logs not captured as expected")
}

if customLogs[0] != "Custom Before method started." {
t.Errorf("Expected custom log for Before method not found")
}

if customLogs[1] != "Custom After method executed." {
t.Errorf("Expected custom log for After method not found")
}

if metrics.ExecutionDuration.Seconds() < 1 {
t.Errorf("ExecutionDuration is less than expected: %v", metrics.ExecutionDuration)
}

if len(metrics.Snapshots) != 0 {
t.Errorf("Custom strategy should not capture snapshots")
}
}

type CustomStrategy struct {
metrics *Metrics
logs *[]string
}

func (c *CustomStrategy) Before() {
*c.logs = append(*c.logs, "Custom Before method started.")
c.metrics.StartTime = time.Now()
}

func (c *CustomStrategy) After() {
*c.logs = append(*c.logs, "Custom After method executed.")
c.metrics.FinishTime = time.Now()
c.metrics.ExecutionDuration = c.metrics.FinishTime.Sub(c.metrics.StartTime)
}

func (c *CustomStrategy) GetMetrics() *Metrics {
return c.metrics
}

0 comments on commit 27c68d9

Please sign in to comment.