Skip to content

Commit

Permalink
introduce: NContext
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed Oct 31, 2023
1 parent 187efa4 commit 8de7e1d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
78 changes: 78 additions & 0 deletions context/NContext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package contextutil

import "context"

// A problematic situation when implementing
// context in a function is when there is more than one value is returned by the function
// if function has only one return value we can safely wrap it something like this
/*
func DoSomething() error {}
ch := make(chan error)
go func() {
ch <- DoSomething()
}()
select {
case err := <-ch:
// handle error
case <-ctx.Done():
// handle context cancelation
}
*/
// but what if we have more than one value to return? we can use generics and a struct
// and that is what we are doing here
// unfortunately there is no such thing as variadic return unless we put it in a slice
// we will have to use structs with 2,3 fields

type twoValueCtx[T1 any, T2 any] struct {
var1 T1
var2 T2
}

type threeValueCtx[T1 any, T2 any, T3 any] struct {
var1 T1
var2 T2
var3 T3
}

// ExecFuncWithTwoReturns wraps a function which has two return values given that last one is error
// and executes that function in a goroutine there by implementing context
// if context is cancelled before function returns it will return context error
// otherwise it will return function's return values
func ExecFuncWithTwoReturns[T1 any](ctx context.Context, fn func() (T1, error)) (T1, error) {
var (
ch = make(chan twoValueCtx[T1, error])
)
go func() {
x, y := fn()
ch <- twoValueCtx[T1, error]{var1: x, var2: y}
}()
select {
case <-ctx.Done():
var tmp T1
return tmp, ctx.Err()
case v := <-ch:
return v.var1, v.var2
}
}

// ExecFuncWithThreeReturns wraps a function which has three return values given that last one is error
// and executes that function in a goroutine there by implementing context
// if context is cancelled before function returns it will return context error
// otherwise it will return function's return values
func ExecFuncWithThreeReturns[T1 any, T2 any](ctx context.Context, fn func() (T1, T2, error)) (T1, T2, error) {
var (
ch = make(chan threeValueCtx[T1, T2, error])
)
go func() {
x, y, z := fn()
ch <- threeValueCtx[T1, T2, error]{var1: x, var2: y, var3: z}
}()
select {
case <-ctx.Done():
var tmp1 T1
var tmp2 T2
return tmp1, tmp2, ctx.Err()
case v := <-ch:
return v.var1, v.var2, v.var3
}
}
80 changes: 80 additions & 0 deletions context/Ncontext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package contextutil_test

import (
"context"
"errors"
"testing"
"time"

contextutil "github.com/projectdiscovery/utils/context"
)

func TestExecFuncWithTwoReturns(t *testing.T) {
t.Run("function completes before context cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

fn := func() (int, error) {
time.Sleep(1 * time.Second)
return 42, nil
}

val, err := contextutil.ExecFuncWithTwoReturns(ctx, fn)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if val != 42 {
t.Errorf("Unexpected return value: got %v, want 42", val)
}
})

t.Run("context cancelled before function completes", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

fn := func() (int, error) {
time.Sleep(2 * time.Second)
return 42, nil
}

_, err := contextutil.ExecFuncWithTwoReturns(ctx, fn)
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Expected context deadline exceeded error, got: %v", err)
}
})
}

func TestExecFuncWithThreeReturns(t *testing.T) {
t.Run("function completes before context cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

fn := func() (int, string, error) {
time.Sleep(1 * time.Second)
return 42, "hello", nil
}

val1, val2, err := contextutil.ExecFuncWithThreeReturns(ctx, fn)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if val1 != 42 || val2 != "hello" {
t.Errorf("Unexpected return values: got %v and %v, want 42 and 'hello'", val1, val2)
}
})

t.Run("context cancelled before function completes", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

fn := func() (int, string, error) {
time.Sleep(2 * time.Second)
return 42, "hello", nil
}

_, _, err := contextutil.ExecFuncWithThreeReturns(ctx, fn)
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Expected context deadline exceeded error, got: %v", err)
}
})
}

0 comments on commit 8de7e1d

Please sign in to comment.