From 5a7c5903f64efb240d575ef947b0ed1d59b1a3d0 Mon Sep 17 00:00:00 2001 From: Adam Date: Sun, 22 Sep 2019 11:55:16 +1000 Subject: [PATCH] Allow changing context in websocket init func --- docs/content/recipes/authentication.md | 57 ++++++++++++++++---------- handler/graphql.go | 4 +- handler/mock.go | 6 ++- handler/websocket.go | 4 +- handler/websocket_test.go | 41 +++++++++++++++--- 5 files changed, 82 insertions(+), 30 deletions(-) diff --git a/docs/content/recipes/authentication.md b/docs/content/recipes/authentication.md index e2b4a896a68..284330a9c08 100644 --- a/docs/content/recipes/authentication.md +++ b/docs/content/recipes/authentication.md @@ -114,29 +114,44 @@ func (r *queryResolver) Hero(ctx context.Context, episode Episode) (Character, e } ``` -Things are different with websockets, and if you do things in the vein of the above example, you have to compute this at every call to `auth.ForContext`. +### Websockets -```golang -// ForContext finds the user from the context. REQUIRES Middleware to have run. -func ForContext(ctx context.Context) *User { - raw, ok := ctx.Value(userCtxKey).(*User) - - if !ok { - payload := handler.GetInitPayload(ctx) - if payload == nil { - return nil - } - - userId, err := validateAndGetUserID(payload["token"]) - if err != nil { - return nil - } - - return getUserByID(db, userId) - } +If you need access to the websocket init payload we can do the same thing with the WebsocketInitFunc: - return raw +```go +func main() { + router := chi.NewRouter() + + router.Use(auth.Middleware(db)) + + router.Handle("/", handler.Playground("Starwars", "/query")) + router.Handle("/query", + handler.GraphQL(starwars.NewExecutableSchema(starwars.NewResolver())), + WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) { + userId, err := validateAndGetUserID(payload["token"]) + if err != nil { + return nil, err + } + + // get the user from the database + user := getUserByID(db, userId) + + // put it in context + userCtx := context.WithValue(r.Context(), userCtxKey, user) + + // and return it so the resolvers can see it + return userCtx, nil + })) + ) + + err := http.ListenAndServe(":8080", router) + if err != nil { + panic(err) + } } ``` -It's a bit inefficient if you have multiple calls to this function (e.g. on a field resolver), but what you might do to mitigate that is to have a session object set on the http request and only populate it upon the first check. \ No newline at end of file +> Note +> +> Subscriptions are long lived, if your tokens can timeout or need to be refreshed you should keep the token in +context too and verify it is still valid in `auth.ForContext`. diff --git a/handler/graphql.go b/handler/graphql.go index 289901f0f0f..fcccf9d7bc5 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -52,7 +52,7 @@ type PersistedQueryCache interface { Get(ctx context.Context, hash string) (string, bool) } -type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error +type websocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) type Config struct { cacheSize int @@ -278,7 +278,7 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { // WebsocketInitFunc is called when the server receives connection init message from the client. // This can be used to check initial payload to see whether to accept the websocket connection. -func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option { +func WebsocketInitFunc(websocketInitFunc websocketInitFunc) Option { return func(cfg *Config) { cfg.websocketInitFunc = websocketInitFunc } diff --git a/handler/mock.go b/handler/mock.go index 3e70cf036bf..a3d73bd1902 100644 --- a/handler/mock.go +++ b/handler/mock.go @@ -9,6 +9,7 @@ import ( ) type executableSchemaMock struct { + QueryFunc func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response MutationFunc func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response } @@ -42,7 +43,10 @@ func (e *executableSchemaMock) Complexity(typeName, field string, childComplexit } func (e *executableSchemaMock) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { - return graphql.ErrorResponse(ctx, "queries are not supported") + if e.QueryFunc == nil { + return graphql.ErrorResponse(ctx, "queries are not supported") + } + return e.QueryFunc(ctx, op) } func (e *executableSchemaMock) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { diff --git a/handler/websocket.go b/handler/websocket.go index 0637cfb19e6..83aa0700926 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -95,11 +95,13 @@ func (c *wsConnection) init() bool { } if c.cfg.websocketInitFunc != nil { - if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil { + ctx, err := c.cfg.websocketInitFunc(c.ctx, c.initPayload) + if err != nil { c.sendConnectionError(err.Error()) c.close(websocket.CloseNormalClosure, "terminated") return false } + c.ctx = ctx } c.write(&operationMessage{Type: connectionAckMsg}) diff --git a/handler/websocket_test.go b/handler/websocket_test.go index 9040720f7bc..44aa65e4f7a 100644 --- a/handler/websocket_test.go +++ b/handler/websocket_test.go @@ -9,8 +9,12 @@ import ( "testing" "time" + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser/ast" ) func TestWebsocket(t *testing.T) { @@ -182,10 +186,37 @@ func TestWebsocketInitFunc(t *testing.T) { require.Equal(t, connectionKeepAliveMsg, readOp(c).Type) }) + t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) { + es := &executableSchemaMock{ + QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response { + assert.Equal(t, "newvalue", ctx.Value("newkey")) + return &graphql.Response{Data: []byte(`{"empty":"ok"}`)} + }, + } + + h := GraphQL(es, + WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) { + return context.WithValue(ctx, "newkey", "newvalue"), nil + })) + + c := client.New(h) + + socket := c.Websocket("{ empty } ") + defer socket.Close() + var resp struct { + Empty string + } + err := socket.Next(&resp) + require.NoError(t, err) + require.Equal(t, "ok", resp.Empty) + }) + t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { - return nil - })) + h := GraphQL(&executableSchemaStub{next}, + WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) { + return context.WithValue(ctx, "newkey", "newvalue"), nil + }), + ) srv := httptest.NewServer(h) defer srv.Close() @@ -199,8 +230,8 @@ func TestWebsocketInitFunc(t *testing.T) { }) t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { - h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error { - return errors.New("invalid init payload") + h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) { + return ctx, errors.New("invalid init payload") })) srv := httptest.NewServer(h) defer srv.Close()