Skip to content

Commit

Permalink
Added parsing of the websocket init message payload, and making it av…
Browse files Browse the repository at this point in the history
…ailable via the context passed to resolvers.

* Added GetInitPayload(ctx) function to graphql
* Added WithInitPayload(ctx) function to graphql
* Added WebsocketWithPayload method to client.Client (Websocket calls it with a nil payload for backwards compability)
* Added tests for these changes in codegen/testserver/generated_test
  • Loading branch information
gissleh committed Sep 19, 2018
1 parent 2bd1cc2 commit 380828f
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 12 deletions.
14 changes: 13 additions & 1 deletion client/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func errorSubscription(err error) *Subscription {
}

func (p *Client) Websocket(query string, options ...Option) *Subscription {
return p.WebsocketWithPayload(query, nil, options...)
}

func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
r := p.mkRequest(query, options...)
requestBody, err := json.Marshal(r)
if err != nil {
Expand All @@ -52,7 +56,15 @@ func (p *Client) Websocket(query string, options ...Option) *Subscription {
return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
}

if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
initMessage := operationMessage{Type: connectionInitMsg}
if initPayload != nil {
initMessage.Payload, err = json.Marshal(initPayload)
if err != nil {
return errorSubscription(fmt.Errorf("parse payload: %s", err.Error()))
}
}

if err = c.WriteJSON(initMessage); err != nil {
return errorSubscription(fmt.Errorf("init: %s", err.Error()))
}

Expand Down
34 changes: 33 additions & 1 deletion codegen/testserver/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions codegen/testserver/generated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ package testserver

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"runtime"
"sort"
"testing"
"time"

Expand Down Expand Up @@ -109,6 +111,34 @@ func TestGeneratedServer(t *testing.T) {

require.Equal(t, initialGoroutineCount, runtime.NumGoroutine())
})

t.Run("will parse init payload", func(t *testing.T) {
sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{
"Authorization": "Bearer of the curse",
"number": 32,
"strings": []string{"hello", "world"},
})

var msg struct {
resp struct {
InitPayload string
}
}

err := sub.Next(&msg.resp)
require.NoError(t, err)
require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload)
err = sub.Next(&msg.resp)
require.NoError(t, err)
require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload)
err = sub.Next(&msg.resp)
require.NoError(t, err)
require.Equal(t, "number = 32", msg.resp.InitPayload)
err = sub.Next(&msg.resp)
require.NoError(t, err)
require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload)
sub.Close()
})
})
}

Expand Down Expand Up @@ -174,3 +204,33 @@ func (r *testSubscriptionResolver) Updated(ctx context.Context) (<-chan string,
}()
return res, nil
}

func (r *testSubscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) {
payload := graphql.GetInitPayload(ctx)
channel := make(chan string, len(payload)+1)

go func() {
<-ctx.Done()
close(channel)
}()

// Test the helper function separately
auth := payload.Authorization()
if auth != "" {
channel <- "AUTH:" + auth
} else {
channel <- "AUTH:NONE"
}

// Send them over the channel in alphabetic order
keys := make([]string, 0, len(payload))
for key := range payload {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
channel <- fmt.Sprintf("%s = %#+v", key, payload[key])
}

return channel, nil
}
3 changes: 3 additions & 0 deletions codegen/testserver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,6 @@ type subscriptionResolver struct{ *Resolver }
func (r *subscriptionResolver) Updated(ctx context.Context) (<-chan string, error) {
panic("not implemented")
}
func (r *subscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) {
panic("not implemented")
}
1 change: 1 addition & 0 deletions codegen/testserver/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Query {

type Subscription {
updated: String!
initPayload: String!
}

type Error {
Expand Down
47 changes: 47 additions & 0 deletions graphql/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,50 @@ func (c *RequestContext) RegisterExtension(key string, value interface{}) error
c.Extensions[key] = value
return nil
}

const (
initpayload key = "initpayload_context"
)

// InitPayload is a structure that is parsed from the websocket init message payload. TO use
// request headers for non-websocket, instead wrap the graphql handler in a middleware.
type InitPayload map[string]interface{}

// GetString safely gets a string value from the payload. It returns an empty string if the
// payload is nil or the value isn't set.
func (payload InitPayload) GetString(key string) string {
if payload == nil {
return ""
}

if value, ok := payload[key]; ok {
res, _ := value.(string)
return res
}

return ""
}

// Authorization is a short hand for getting the Authorization header from the
// payload.
func (payload InitPayload) Authorization() string {
if value := payload.GetString("Authorization"); value != "" {
return value
}

if value := payload.GetString("authorization"); value != "" {
return value
}

return ""
}

// WithInitPayload makes a context with the init payload.
func WithInitPayload(ctx context.Context, payload InitPayload) context.Context {
return context.WithValue(ctx, initpayload, payload)
}

// GetInitPayload gets the payload from context.
func GetInitPayload(ctx context.Context) InitPayload {
return ctx.Value(initpayload).(InitPayload)
}
33 changes: 23 additions & 10 deletions handler/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,33 +63,42 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req
cfg: cfg,
}

if !conn.init() {
initPayload, ok := conn.init()
if !ok {
return
}

conn.run()
conn.run(initPayload)
}

func (c *wsConnection) init() bool {
func (c *wsConnection) init() (initPayload graphql.InitPayload, ok bool) {
message := c.readOp()
if message == nil {
c.close(websocket.CloseProtocolError, "decoding error")
return false
return nil, false
}

initPayload = make(graphql.InitPayload)

switch message.Type {
case connectionInitMsg:
err := json.Unmarshal(message.Payload, &initPayload)
if err != nil {
// Treat an invalid payload as no payload
initPayload = nil
}

c.write(&operationMessage{Type: connectionAckMsg})
case connectionTerminateMsg:
c.close(websocket.CloseNormalClosure, "terminated")
return false
return nil, false
default:
c.sendConnectionError("unexpected message %s", message.Type)
c.close(websocket.CloseProtocolError, "unexpected message")
return false
return nil, false
}

return true
return initPayload, true
}

func (c *wsConnection) write(msg *operationMessage) {
Expand All @@ -98,7 +107,7 @@ func (c *wsConnection) write(msg *operationMessage) {
c.mu.Unlock()
}

func (c *wsConnection) run() {
func (c *wsConnection) run(initPayload graphql.InitPayload) {
for {
message := c.readOp()
if message == nil {
Expand All @@ -107,7 +116,7 @@ func (c *wsConnection) run() {

switch message.Type {
case startMsg:
if !c.subscribe(message) {
if !c.subscribe(message, initPayload) {
return
}
case stopMsg:
Expand All @@ -131,7 +140,7 @@ func (c *wsConnection) run() {
}
}

func (c *wsConnection) subscribe(message *operationMessage) bool {
func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql.InitPayload) bool {
var reqParams params
if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
c.sendConnectionError("invalid json")
Expand All @@ -158,6 +167,10 @@ func (c *wsConnection) subscribe(message *operationMessage) bool {
reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars)
ctx := graphql.WithRequestContext(c.ctx, reqCtx)

if initPayload != nil {
ctx = graphql.WithInitPayload(ctx, initPayload)
}

if op.Operation != ast.Subscription {
var result *graphql.Response
if op.Operation == ast.Query {
Expand Down

0 comments on commit 380828f

Please sign in to comment.