From 1f981682e046c35c522046c41041ed6ba5c150ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fan=20Can=20Bak=C4=B1r?= Date: Tue, 5 Sep 2023 11:43:13 +0000 Subject: [PATCH] add before and after func --- trace/trace.go | 26 ++++++++++++++++++++++++++ trace/trace_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/trace/trace.go b/trace/trace.go index 9ac2db9..f38bcf0 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -33,12 +33,22 @@ type Metrics struct { type FunctionContext struct { strategy ActionStrategy action func() + before func() + after func() } func (f *FunctionContext) Execute() { + if f.before != nil { + f.before() + } + f.strategy.Before() f.action() f.strategy.After() + + if f.after != nil { + f.after() + } } type ActionStrategy interface { @@ -117,6 +127,8 @@ func (d *DefaultStrategy) GetMetrics() *Metrics { type TraceOptions struct { strategy ActionStrategy + before func() + after func() } type TraceOptionSetter func(opts *TraceOptions) @@ -127,6 +139,18 @@ func WithStrategy(s ActionStrategy) TraceOptionSetter { } } +func WithBefore(b func()) TraceOptionSetter { + return func(opts *TraceOptions) { + opts.before = b + } +} + +func WithAfter(a func()) TraceOptionSetter { + return func(opts *TraceOptions) { + opts.after = a + } +} + func Trace(f func(), setters ...TraceOptionSetter) (*Metrics, error) { opts := &TraceOptions{ strategy: &DefaultStrategy{metrics: generic.Lockable[*Metrics]{V: &Metrics{}}}, @@ -144,6 +168,8 @@ func Trace(f func(), setters ...TraceOptionSetter) (*Metrics, error) { context := &FunctionContext{ strategy: opts.strategy, action: f, + before: opts.before, + after: opts.after, } context.Execute() diff --git a/trace/trace_test.go b/trace/trace_test.go index 659fef6..30add4a 100644 --- a/trace/trace_test.go +++ b/trace/trace_test.go @@ -5,6 +5,36 @@ import ( "time" ) +func TestFunctionWithBeforeFunction(t *testing.T) { + var beforeCalled bool + _, _ = Trace(func() { + if !beforeCalled { + t.Errorf("Before function was not called before the main function") + } + }, WithBefore(func() { + beforeCalled = true + })) + + if !beforeCalled { + t.Errorf("Before function was not called") + } +} + +func TestFunctionWithAfterFunction(t *testing.T) { + var afterCalled bool + _, _ = Trace(func() { + if afterCalled { + t.Errorf("After function was called before the main function finished") + } + }, WithAfter(func() { + afterCalled = true + })) + + if !afterCalled { + t.Errorf("After function was not called") + } +} + func TestFunctionTracing(t *testing.T) { metrics, _ := Trace(func() { time.Sleep(2 * time.Second)