Skip to content

Commit

Permalink
Merge pull request graph-gophers#203 from matiasanaya/feature/subscri…
Browse files Browse the repository at this point in the history
…ptions

Support subscriptions
  • Loading branch information
pavelnikolov authored Oct 22, 2018
2 parents da17c29 + c3873ab commit b2470f2
Show file tree
Hide file tree
Showing 6 changed files with 584 additions and 13 deletions.
117 changes: 117 additions & 0 deletions gqltesting/subscriptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package gqltesting

import (
"bytes"
"context"
"encoding/json"
"strconv"
"testing"

graphql "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/errors"
)

// TestResponse models the expected response
type TestResponse struct {
Data json.RawMessage
Errors []*errors.QueryError
}

// TestSubscription is a GraphQL test case to be used with RunSubscribe.
type TestSubscription struct {
Name string
Schema *graphql.Schema
Query string
OperationName string
Variables map[string]interface{}
ExpectedResults []TestResponse
ExpectedErr error
}

// RunSubscribes runs the given GraphQL subscription test cases as subtests.
func RunSubscribes(t *testing.T, tests []*TestSubscription) {
for i, test := range tests {
if test.Name == "" {
test.Name = strconv.Itoa(i + 1)
}

t.Run(test.Name, func(t *testing.T) {
RunSubscribe(t, test)
})
}
}

// RunSubscribe runs a single GraphQL subscription test case.
func RunSubscribe(t *testing.T, test *TestSubscription) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

c, err := test.Schema.Subscribe(ctx, test.Query, test.OperationName, test.Variables)
if err != nil {
if err.Error() != test.ExpectedErr.Error() {
t.Fatalf("unexpected error: got %+v, want %+v", err, test.ExpectedErr)
}

return
}

var results []*graphql.Response
for res := range c {
results = append(results, res)
}

for i, expected := range test.ExpectedResults {
res := results[i]

checkErrorStrings(t, expected.Errors, res.Errors)

resData, err := res.Data.MarshalJSON()
if err != nil {
t.Fatal(err)
}
got, err := formatJSON(resData)
if err != nil {
t.Fatalf("got: invalid JSON: %s", err)
}

expectedData, err := expected.Data.MarshalJSON()
if err != nil {
t.Fatal(err)
}
want, err := formatJSON(expectedData)
if err != nil {
t.Fatalf("got: invalid JSON: %s", err)
}

if !bytes.Equal(got, want) {
t.Logf("got: %s", got)
t.Logf("want: %s", want)
t.Fail()
}
}
}

func checkErrorStrings(t *testing.T, expected, actual []*errors.QueryError) {
expectedCount, actualCount := len(expected), len(actual)

if expectedCount != actualCount {
t.Fatalf("unexpected number of errors: want %d, got %d", expectedCount, actualCount)
}

if expectedCount > 0 {
for i, want := range expected {
got := actual[i]

if got.Error() != want.Error() {
t.Fatalf("unexpected error: got %+v, want %+v", got, want)
}
}

// Return because we're done checking.
return
}

for _, err := range actual {
t.Errorf("unexpected error: '%s'", err)
}
}
44 changes: 31 additions & 13 deletions internal/exec/resolvable/resolvable.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (

type Schema struct {
schema.Schema
Query Resolvable
Mutation Resolvable
Resolver reflect.Value
Query Resolvable
Mutation Resolvable
Subscription Resolvable
Resolver reflect.Value
}

type Resolvable interface {
Expand Down Expand Up @@ -57,7 +58,7 @@ func (*Scalar) isResolvable() {}
func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) {
b := newBuilder(s)

var query, mutation Resolvable
var query, mutation, subscription Resolvable

if t, ok := s.EntryPoints["query"]; ok {
if err := b.assignExec(&query, t, reflect.TypeOf(resolver)); err != nil {
Expand All @@ -71,15 +72,22 @@ func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) {
}
}

if t, ok := s.EntryPoints["subscription"]; ok {
if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver)); err != nil {
return nil, err
}
}

if err := b.finish(); err != nil {
return nil, err
}

return &Schema{
Schema: *s,
Resolver: reflect.ValueOf(resolver),
Query: query,
Mutation: mutation,
Schema: *s,
Resolver: reflect.ValueOf(resolver),
Query: query,
Mutation: mutation,
Subscription: subscription,
}, nil
}

Expand Down Expand Up @@ -284,14 +292,19 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.
return nil, fmt.Errorf("too many parameters")
}

if m.Type.NumOut() > 2 {
maxNumOfReturns := 2
if m.Type.NumOut() < maxNumOfReturns-1 {
return nil, fmt.Errorf("too few return values")
}

if m.Type.NumOut() > maxNumOfReturns {
return nil, fmt.Errorf("too many return values")
}

hasError := m.Type.NumOut() == 2
hasError := m.Type.NumOut() == maxNumOfReturns
if hasError {
if m.Type.Out(1) != errorType {
return nil, fmt.Errorf(`must have "error" as its second return value`)
if m.Type.Out(maxNumOfReturns-1) != errorType {
return nil, fmt.Errorf(`must have "error" as its last return value`)
}
}

Expand All @@ -304,7 +317,12 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.
HasError: hasError,
TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name),
}
if err := b.assignExec(&fe.ValueExec, f.Type, m.Type.Out(0)); err != nil {

out := m.Type.Out(0)
if typeName == "Subscription" && out.Kind() == reflect.Chan {
out = m.Type.Out(0).Elem()
}
if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil {
return nil, err
}
return fe, nil
Expand Down
2 changes: 2 additions & 0 deletions internal/exec/selected/selected.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func ApplyOperation(r *Request, s *resolvable.Schema, op *query.Operation) []Sel
obj = s.Query.(*resolvable.Object)
case query.Mutation:
obj = s.Mutation.(*resolvable.Object)
case query.Subscription:
obj = s.Subscription.(*resolvable.Object)
}
return applySelectionSet(r, obj, op.Selections)
}
Expand Down
147 changes: 147 additions & 0 deletions internal/exec/subscribe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package exec

import (
"bytes"
"context"
"encoding/json"
"fmt"
"reflect"
"time"

"github.com/graph-gophers/graphql-go/errors"
"github.com/graph-gophers/graphql-go/internal/exec/resolvable"
"github.com/graph-gophers/graphql-go/internal/exec/selected"
"github.com/graph-gophers/graphql-go/internal/query"
)

type Response struct {
Data json.RawMessage
Errors []*errors.QueryError
}

func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response {
var result reflect.Value
var f *fieldToExec
var err *errors.QueryError
func() {
defer r.handlePanic(ctx)

sels := selected.ApplyOperation(&r.Request, s, op)
var fields []*fieldToExec
collectFieldsToResolve(sels, s.Resolver, &fields, make(map[string]*fieldToExec))

// TODO: move this check into validation.Validate
if len(fields) != 1 {
err = errors.Errorf("%s", "can subscribe to at most one subscription at a time")
return
}
f = fields[0]

var in []reflect.Value
if f.field.HasContext {
in = append(in, reflect.ValueOf(ctx))
}
if f.field.ArgsPacker != nil {
in = append(in, f.field.PackedArgs)
}
callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
result = callOut[0]

if f.field.HasError && !callOut[1].IsNil() {
resolverErr := callOut[1].Interface().(error)
err = errors.Errorf("%s", resolverErr)
err.ResolverError = resolverErr
}
}()

if err != nil {
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}})
}

if ctxErr := ctx.Err(); ctxErr != nil {
return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{errors.Errorf("%s", ctxErr)}})
}

c := make(chan *Response)
// TODO: handle resolver nil channel better?
if result == reflect.Zero(result.Type()) {
close(c)
return c
}

go func() {
for {
// Check subscription context
chosen, resp, ok := reflect.Select([]reflect.SelectCase{
{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(ctx.Done()),
},
{
Dir: reflect.SelectRecv,
Chan: result,
},
})
switch chosen {
// subscription context done
case 0:
close(c)
return
// upstream received
case 1:
// upstream closed
if !ok {
close(c)
return
}

subR := &Request{
Request: selected.Request{
Doc: r.Request.Doc,
Vars: r.Request.Vars,
Schema: r.Request.Schema,
},
Limiter: r.Limiter,
Tracer: r.Tracer,
Logger: r.Logger,
}
var out bytes.Buffer
func() {
// TODO: configurable timeout
subCtx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()

// resolve response
func() {
defer subR.handlePanic(subCtx)

out.WriteString(fmt.Sprintf(`{"%s":`, f.field.Alias))
subR.execSelectionSet(subCtx, f.sels, f.field.Type, &pathSegment{nil, f.field.Alias}, resp, &out)
out.WriteString(`}`)
}()

if err := subCtx.Err(); err != nil {
c <- &Response{Errors: []*errors.QueryError{errors.Errorf("%s", err)}}
return
}

// Send response within timeout
// TODO: maybe block until sent?
select {
case <-subCtx.Done():
case c <- &Response{Data: out.Bytes(), Errors: subR.Errs}:
}
}()
}
}
}()

return c
}

func sendAndReturnClosed(resp *Response) chan *Response {
c := make(chan *Response, 1)
c <- resp
close(c)
return c
}
Loading

0 comments on commit b2470f2

Please sign in to comment.