Skip to content

Commit

Permalink
Allow changing context in websocket init func
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Sep 22, 2019
1 parent 17f32d2 commit 5a7c590
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 30 deletions.
57 changes: 36 additions & 21 deletions docs/content/recipes/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,44 @@ func (r *queryResolver) Hero(ctx context.Context, episode Episode) (Character, e
}
```

Things are different with websockets, and if you do things in the vein of the above example, you have to compute this at every call to `auth.ForContext`.
### Websockets

```golang
// ForContext finds the user from the context. REQUIRES Middleware to have run.
func ForContext(ctx context.Context) *User {
raw, ok := ctx.Value(userCtxKey).(*User)

if !ok {
payload := handler.GetInitPayload(ctx)
if payload == nil {
return nil
}

userId, err := validateAndGetUserID(payload["token"])
if err != nil {
return nil
}

return getUserByID(db, userId)
}
If you need access to the websocket init payload we can do the same thing with the WebsocketInitFunc:

return raw
```go
func main() {
router := chi.NewRouter()

router.Use(auth.Middleware(db))

router.Handle("/", handler.Playground("Starwars", "/query"))
router.Handle("/query",
handler.GraphQL(starwars.NewExecutableSchema(starwars.NewResolver())),
WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) {
userId, err := validateAndGetUserID(payload["token"])
if err != nil {
return nil, err
}

// get the user from the database
user := getUserByID(db, userId)

// put it in context
userCtx := context.WithValue(r.Context(), userCtxKey, user)

// and return it so the resolvers can see it
return userCtx, nil
}))
)

err := http.ListenAndServe(":8080", router)
if err != nil {
panic(err)
}
}
```

It's a bit inefficient if you have multiple calls to this function (e.g. on a field resolver), but what you might do to mitigate that is to have a session object set on the http request and only populate it upon the first check.
> Note
>
> Subscriptions are long lived, if your tokens can timeout or need to be refreshed you should keep the token in
context too and verify it is still valid in `auth.ForContext`.
4 changes: 2 additions & 2 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type PersistedQueryCache interface {
Get(ctx context.Context, hash string) (string, bool)
}

type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error
type websocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)

type Config struct {
cacheSize int
Expand Down Expand Up @@ -278,7 +278,7 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) {

// WebsocketInitFunc is called when the server receives connection init message from the client.
// This can be used to check initial payload to see whether to accept the websocket connection.
func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option {
func WebsocketInitFunc(websocketInitFunc websocketInitFunc) Option {
return func(cfg *Config) {
cfg.websocketInitFunc = websocketInitFunc
}
Expand Down
6 changes: 5 additions & 1 deletion handler/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
)

type executableSchemaMock struct {
QueryFunc func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response
MutationFunc func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response
}

Expand Down Expand Up @@ -42,7 +43,10 @@ func (e *executableSchemaMock) Complexity(typeName, field string, childComplexit
}

func (e *executableSchemaMock) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
return graphql.ErrorResponse(ctx, "queries are not supported")
if e.QueryFunc == nil {
return graphql.ErrorResponse(ctx, "queries are not supported")
}
return e.QueryFunc(ctx, op)
}

func (e *executableSchemaMock) Mutation(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
Expand Down
4 changes: 3 additions & 1 deletion handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ func (c *wsConnection) init() bool {
}

if c.cfg.websocketInitFunc != nil {
if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
ctx, err := c.cfg.websocketInitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
return false
}
c.ctx = ctx
}

c.write(&operationMessage{Type: connectionAckMsg})
Expand Down
41 changes: 36 additions & 5 deletions handler/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ import (
"testing"
"time"

"github.com/99designs/gqlgen/client"
"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast"
)

func TestWebsocket(t *testing.T) {
Expand Down Expand Up @@ -182,10 +186,37 @@ func TestWebsocketInitFunc(t *testing.T) {
require.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
})

t.Run("can return context for request from WebsocketInitFunc", func(t *testing.T) {
es := &executableSchemaMock{
QueryFunc: func(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
assert.Equal(t, "newvalue", ctx.Value("newkey"))
return &graphql.Response{Data: []byte(`{"empty":"ok"}`)}
},
}

h := GraphQL(es,
WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) {
return context.WithValue(ctx, "newkey", "newvalue"), nil
}))

c := client.New(h)

socket := c.Websocket("{ empty } ")
defer socket.Close()
var resp struct {
Empty string
}
err := socket.Next(&resp)
require.NoError(t, err)
require.Equal(t, "ok", resp.Empty)
})

t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return nil
}))
h := GraphQL(&executableSchemaStub{next},
WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) {
return context.WithValue(ctx, "newkey", "newvalue"), nil
}),
)
srv := httptest.NewServer(h)
defer srv.Close()

Expand All @@ -199,8 +230,8 @@ func TestWebsocketInitFunc(t *testing.T) {
})

t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) error {
return errors.New("invalid init payload")
h := GraphQL(&executableSchemaStub{next}, WebsocketInitFunc(func(ctx context.Context, initPayload InitPayload) (context.Context, error) {
return ctx, errors.New("invalid init payload")
}))
srv := httptest.NewServer(h)
defer srv.Close()
Expand Down

0 comments on commit 5a7c590

Please sign in to comment.