-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
187efa4
commit 8de7e1d
Showing
2 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package contextutil | ||
|
||
import "context" | ||
|
||
// A problematic situation when implementing | ||
// context in a function is when there is more than one value is returned by the function | ||
// if function has only one return value we can safely wrap it something like this | ||
/* | ||
func DoSomething() error {} | ||
ch := make(chan error) | ||
go func() { | ||
ch <- DoSomething() | ||
}() | ||
select { | ||
case err := <-ch: | ||
// handle error | ||
case <-ctx.Done(): | ||
// handle context cancelation | ||
} | ||
*/ | ||
// but what if we have more than one value to return? we can use generics and a struct | ||
// and that is what we are doing here | ||
// unfortunately there is no such thing as variadic return unless we put it in a slice | ||
// we will have to use structs with 2,3 fields | ||
|
||
type twoValueCtx[T1 any, T2 any] struct { | ||
var1 T1 | ||
var2 T2 | ||
} | ||
|
||
type threeValueCtx[T1 any, T2 any, T3 any] struct { | ||
var1 T1 | ||
var2 T2 | ||
var3 T3 | ||
} | ||
|
||
// ExecFuncWithTwoReturns wraps a function which has two return values given that last one is error | ||
// and executes that function in a goroutine there by implementing context | ||
// if context is cancelled before function returns it will return context error | ||
// otherwise it will return function's return values | ||
func ExecFuncWithTwoReturns[T1 any](ctx context.Context, fn func() (T1, error)) (T1, error) { | ||
var ( | ||
ch = make(chan twoValueCtx[T1, error]) | ||
) | ||
go func() { | ||
x, y := fn() | ||
ch <- twoValueCtx[T1, error]{var1: x, var2: y} | ||
}() | ||
select { | ||
case <-ctx.Done(): | ||
var tmp T1 | ||
return tmp, ctx.Err() | ||
case v := <-ch: | ||
return v.var1, v.var2 | ||
} | ||
} | ||
|
||
// ExecFuncWithThreeReturns wraps a function which has three return values given that last one is error | ||
// and executes that function in a goroutine there by implementing context | ||
// if context is cancelled before function returns it will return context error | ||
// otherwise it will return function's return values | ||
func ExecFuncWithThreeReturns[T1 any, T2 any](ctx context.Context, fn func() (T1, T2, error)) (T1, T2, error) { | ||
var ( | ||
ch = make(chan threeValueCtx[T1, T2, error]) | ||
) | ||
go func() { | ||
x, y, z := fn() | ||
ch <- threeValueCtx[T1, T2, error]{var1: x, var2: y, var3: z} | ||
}() | ||
select { | ||
case <-ctx.Done(): | ||
var tmp1 T1 | ||
var tmp2 T2 | ||
return tmp1, tmp2, ctx.Err() | ||
case v := <-ch: | ||
return v.var1, v.var2, v.var3 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package contextutil_test | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
contextutil "github.com/projectdiscovery/utils/context" | ||
) | ||
|
||
func TestExecFuncWithTwoReturns(t *testing.T) { | ||
t.Run("function completes before context cancellation", func(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
defer cancel() | ||
|
||
fn := func() (int, error) { | ||
time.Sleep(1 * time.Second) | ||
return 42, nil | ||
} | ||
|
||
val, err := contextutil.ExecFuncWithTwoReturns(ctx, fn) | ||
if err != nil { | ||
t.Errorf("Unexpected error: %v", err) | ||
} | ||
if val != 42 { | ||
t.Errorf("Unexpected return value: got %v, want 42", val) | ||
} | ||
}) | ||
|
||
t.Run("context cancelled before function completes", func(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) | ||
defer cancel() | ||
|
||
fn := func() (int, error) { | ||
time.Sleep(2 * time.Second) | ||
return 42, nil | ||
} | ||
|
||
_, err := contextutil.ExecFuncWithTwoReturns(ctx, fn) | ||
if !errors.Is(err, context.DeadlineExceeded) { | ||
t.Errorf("Expected context deadline exceeded error, got: %v", err) | ||
} | ||
}) | ||
} | ||
|
||
func TestExecFuncWithThreeReturns(t *testing.T) { | ||
t.Run("function completes before context cancellation", func(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) | ||
defer cancel() | ||
|
||
fn := func() (int, string, error) { | ||
time.Sleep(1 * time.Second) | ||
return 42, "hello", nil | ||
} | ||
|
||
val1, val2, err := contextutil.ExecFuncWithThreeReturns(ctx, fn) | ||
if err != nil { | ||
t.Errorf("Unexpected error: %v", err) | ||
} | ||
if val1 != 42 || val2 != "hello" { | ||
t.Errorf("Unexpected return values: got %v and %v, want 42 and 'hello'", val1, val2) | ||
} | ||
}) | ||
|
||
t.Run("context cancelled before function completes", func(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) | ||
defer cancel() | ||
|
||
fn := func() (int, string, error) { | ||
time.Sleep(2 * time.Second) | ||
return 42, "hello", nil | ||
} | ||
|
||
_, _, err := contextutil.ExecFuncWithThreeReturns(ctx, fn) | ||
if !errors.Is(err, context.DeadlineExceeded) { | ||
t.Errorf("Expected context deadline exceeded error, got: %v", err) | ||
} | ||
}) | ||
} |