diff --git a/client/websocket.go b/client/websocket.go index bd92e3c022..d66c872ca3 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -38,6 +38,10 @@ func errorSubscription(err error) *Subscription { } func (p *Client) Websocket(query string, options ...Option) *Subscription { + return p.WebsocketWithPayload(query, nil, options...) +} + +func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription { r := p.mkRequest(query, options...) requestBody, err := json.Marshal(r) if err != nil { @@ -52,7 +56,15 @@ func (p *Client) Websocket(query string, options ...Option) *Subscription { return errorSubscription(fmt.Errorf("dial: %s", err.Error())) } - if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil { + initMessage := operationMessage{Type: connectionInitMsg} + if initPayload != nil { + initMessage.Payload, err = json.Marshal(initPayload) + if err != nil { + return errorSubscription(fmt.Errorf("parse payload: %s", err.Error())) + } + } + + if err = c.WriteJSON(initMessage); err != nil { return errorSubscription(fmt.Errorf("init: %s", err.Error())) } diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 62ded8f19c..2c64cd9f4a 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -95,7 +95,8 @@ type ComplexityRoot struct { } Subscription struct { - Updated func(childComplexity int) int + Updated func(childComplexity int) int + InitPayload func(childComplexity int) int } } @@ -117,6 +118,7 @@ type QueryResolver interface { } type SubscriptionResolver interface { Updated(ctx context.Context) (<-chan string, error) + InitPayload(ctx context.Context) (<-chan string, error) } func field_Query_mapInput_args(rawArgs map[string]interface{}) (map[string]interface{}, error) { @@ -716,6 +718,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.Updated(childComplexity), true + case "Subscription.initPayload": + if e.complexity.Subscription.InitPayload == nil { + break + } + + return e.complexity.Subscription.InitPayload(childComplexity), true + } return 0, false } @@ -1891,6 +1900,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection switch fields[0].Name { case "updated": return ec._Subscription_updated(ctx, fields[0]) + case "initPayload": + return ec._Subscription_initPayload(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } @@ -1916,6 +1927,26 @@ func (ec *executionContext) _Subscription_updated(ctx context.Context, field gra } } +func (ec *executionContext) _Subscription_initPayload(ctx context.Context, field graphql.CollectedField) func() graphql.Marshaler { + ctx = graphql.WithResolverContext(ctx, &graphql.ResolverContext{ + Field: field, + }) + results, err := ec.resolvers.Subscription().InitPayload(ctx) + if err != nil { + ec.Error(ctx, err) + return nil + } + return func() graphql.Marshaler { + res, ok := <-results + if !ok { + return nil + } + var out graphql.OrderedMap + out.Add(field.Alias, func() graphql.Marshaler { return graphql.MarshalString(res) }()) + return &out + } +} + var __DirectiveImplementors = []string{"__Directive"} // nolint: gocyclo, errcheck, gas, goconst @@ -3485,6 +3516,7 @@ var parsedSchema = gqlparser.MustLoadSchema( type Subscription { updated: String! + initPayload: String! } type Error { diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index 15e79011b8..0eb9f66672 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -5,10 +5,12 @@ package testserver import ( "context" + "fmt" "net/http" "net/http/httptest" "reflect" "runtime" + "sort" "testing" "time" @@ -109,6 +111,34 @@ func TestGeneratedServer(t *testing.T) { require.Equal(t, initialGoroutineCount, runtime.NumGoroutine()) }) + + t.Run("will parse init payload", func(t *testing.T) { + sub := c.WebsocketWithPayload(`subscription { initPayload }`, map[string]interface{}{ + "Authorization": "Bearer of the curse", + "number": 32, + "strings": []string{"hello", "world"}, + }) + + var msg struct { + resp struct { + InitPayload string + } + } + + err := sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "AUTH:Bearer of the curse", msg.resp.InitPayload) + err = sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "Authorization = \"Bearer of the curse\"", msg.resp.InitPayload) + err = sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "number = 32", msg.resp.InitPayload) + err = sub.Next(&msg.resp) + require.NoError(t, err) + require.Equal(t, "strings = []interface {}{\"hello\", \"world\"}", msg.resp.InitPayload) + sub.Close() + }) }) } @@ -174,3 +204,33 @@ func (r *testSubscriptionResolver) Updated(ctx context.Context) (<-chan string, }() return res, nil } + +func (r *testSubscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) { + payload := graphql.GetInitPayload(ctx) + channel := make(chan string, len(payload)+1) + + go func() { + <-ctx.Done() + close(channel) + }() + + // Test the helper function separately + auth := payload.Authorization() + if auth != "" { + channel <- "AUTH:" + auth + } else { + channel <- "AUTH:NONE" + } + + // Send them over the channel in alphabetic order + keys := make([]string, 0, len(payload)) + for key := range payload { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + channel <- fmt.Sprintf("%s = %#+v", key, payload[key]) + } + + return channel, nil +} diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index c99d550eae..167a0c2825 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -68,3 +68,6 @@ type subscriptionResolver struct{ *Resolver } func (r *subscriptionResolver) Updated(ctx context.Context) (<-chan string, error) { panic("not implemented") } +func (r *subscriptionResolver) InitPayload(ctx context.Context) (<-chan string, error) { + panic("not implemented") +} diff --git a/codegen/testserver/schema.graphql b/codegen/testserver/schema.graphql index d162856891..d88aa2ee84 100644 --- a/codegen/testserver/schema.graphql +++ b/codegen/testserver/schema.graphql @@ -13,6 +13,7 @@ type Query { type Subscription { updated: String! + initPayload: String! } type Error { diff --git a/graphql/context.go b/graphql/context.go index e298895c41..85e6349552 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -195,3 +195,50 @@ func (c *RequestContext) RegisterExtension(key string, value interface{}) error c.Extensions[key] = value return nil } + +const ( + initpayload key = "initpayload_context" +) + +// InitPayload is a structure that is parsed from the websocket init message payload. TO use +// request headers for non-websocket, instead wrap the graphql handler in a middleware. +type InitPayload map[string]interface{} + +// GetString safely gets a string value from the payload. It returns an empty string if the +// payload is nil or the value isn't set. +func (payload InitPayload) GetString(key string) string { + if payload == nil { + return "" + } + + if value, ok := payload[key]; ok { + res, _ := value.(string) + return res + } + + return "" +} + +// Authorization is a short hand for getting the Authorization header from the +// payload. +func (payload InitPayload) Authorization() string { + if value := payload.GetString("Authorization"); value != "" { + return value + } + + if value := payload.GetString("authorization"); value != "" { + return value + } + + return "" +} + +// WithInitPayload makes a context with the init payload. +func WithInitPayload(ctx context.Context, payload InitPayload) context.Context { + return context.WithValue(ctx, initpayload, payload) +} + +// GetInitPayload gets the payload from context. +func GetInitPayload(ctx context.Context) InitPayload { + return ctx.Value(initpayload).(InitPayload) +} diff --git a/handler/websocket.go b/handler/websocket.go index 2be1e87f9a..2af0433244 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -63,33 +63,42 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req cfg: cfg, } - if !conn.init() { + initPayload, ok := conn.init() + if !ok { return } - conn.run() + conn.run(initPayload) } -func (c *wsConnection) init() bool { +func (c *wsConnection) init() (initPayload graphql.InitPayload, ok bool) { message := c.readOp() if message == nil { c.close(websocket.CloseProtocolError, "decoding error") - return false + return nil, false } + initPayload = make(graphql.InitPayload) + switch message.Type { case connectionInitMsg: + err := json.Unmarshal(message.Payload, &initPayload) + if err != nil { + // Treat an invalid payload as no payload + initPayload = nil + } + c.write(&operationMessage{Type: connectionAckMsg}) case connectionTerminateMsg: c.close(websocket.CloseNormalClosure, "terminated") - return false + return nil, false default: c.sendConnectionError("unexpected message %s", message.Type) c.close(websocket.CloseProtocolError, "unexpected message") - return false + return nil, false } - return true + return initPayload, true } func (c *wsConnection) write(msg *operationMessage) { @@ -98,7 +107,7 @@ func (c *wsConnection) write(msg *operationMessage) { c.mu.Unlock() } -func (c *wsConnection) run() { +func (c *wsConnection) run(initPayload graphql.InitPayload) { for { message := c.readOp() if message == nil { @@ -107,7 +116,7 @@ func (c *wsConnection) run() { switch message.Type { case startMsg: - if !c.subscribe(message) { + if !c.subscribe(message, initPayload) { return } case stopMsg: @@ -131,7 +140,7 @@ func (c *wsConnection) run() { } } -func (c *wsConnection) subscribe(message *operationMessage) bool { +func (c *wsConnection) subscribe(message *operationMessage, initPayload graphql.InitPayload) bool { var reqParams params if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil { c.sendConnectionError("invalid json") @@ -158,6 +167,10 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars) ctx := graphql.WithRequestContext(c.ctx, reqCtx) + if initPayload != nil { + ctx = graphql.WithInitPayload(ctx, initPayload) + } + if op.Operation != ast.Subscription { var result *graphql.Response if op.Operation == ast.Query {