diff --git a/internal/orchestrion/context.go b/internal/orchestrion/context.go index d5fda50f6d..f66d12c576 100644 --- a/internal/orchestrion/context.go +++ b/internal/orchestrion/context.go @@ -30,13 +30,15 @@ func WrapContext(ctx context.Context) context.Context { // CtxWithValue runs context.WithValue, adds the result to the GLS slot of orchestrion, and returns it. // If orchestrion is not enabled, it will run context.WithValue and return the result. +// Since we don't support cross-goroutine switch of the GLS we still run context.WithValue in the case +// we are switching goroutines. func CtxWithValue(parent context.Context, key, val any) context.Context { if !Enabled() { return context.WithValue(parent, key, val) } getDDContextStack().Push(key, val) - return WrapContext(parent) + return context.WithValue(WrapContext(parent), key, val) } // GLSPopValue pops the value from the GLS slot of orchestrion and returns it. Using context.Context values usually does diff --git a/internal/orchestrion/context_test.go b/internal/orchestrion/context_test.go index 3c70623484..9bbf48cd7b 100644 --- a/internal/orchestrion/context_test.go +++ b/internal/orchestrion/context_test.go @@ -7,8 +7,9 @@ package orchestrion import ( "context" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) type key string @@ -71,11 +72,18 @@ func TestCtxWithValue(t *testing.T) { t.Run("true", func(t *testing.T) { enabled = true ctx := CtxWithValue(context.Background(), key("key"), "value") - require.Equal(t, &glsContext{context.Background()}, ctx) + require.Equal(t, context.WithValue(&glsContext{context.Background()}, key("key"), "value"), ctx) require.Equal(t, "value", ctx.Value(key("key"))) require.Equal(t, "value", getDDContextStack().Peek(key("key"))) require.Equal(t, "value", GLSPopValue(key("key"))) require.Nil(t, getDDContextStack().Peek(key("key"))) - require.Nil(t, ctx.Value(key("key"))) + }) + + t.Run("cross-goroutine switch", func(t *testing.T) { + enabled = true + ctx := CtxWithValue(context.Background(), key("key"), "value") + go func() { + require.Equal(t, "value", ctx.Value(key("key"))) + }() }) }