diff --git a/graphql/handler/apq.go b/graphql/handler/apq.go index ba8c21499c..1f4a456ae7 100644 --- a/graphql/handler/apq.go +++ b/graphql/handler/apq.go @@ -5,6 +5,8 @@ import ( "crypto/sha256" "encoding/hex" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/99designs/gqlgen/graphql" "github.com/mitchellh/mapstructure" ) @@ -20,7 +22,7 @@ const ( // see https://github.com/apollographql/apollo-link-persisted-queries func AutomaticPersistedQuery(cache Cache) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { rc := graphql.GetRequestContext(ctx) if rc.Extensions["persistedQuery"] == nil { diff --git a/graphql/handler/complexity.go b/graphql/handler/complexity.go index 0cae2d5616..41219df0fe 100644 --- a/graphql/handler/complexity.go +++ b/graphql/handler/complexity.go @@ -3,6 +3,8 @@ package handler import ( "context" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/99designs/gqlgen/graphql" ) @@ -11,7 +13,7 @@ import ( // If a query is submitted that exceeds the limit, a 422 status code will be returned. func ComplexityLimit(limit int) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { graphql.GetRequestContext(ctx).ComplexityLimit = limit next(ctx, writer) } @@ -25,7 +27,7 @@ func ComplexityLimit(limit int) Middleware { // If a query is submitted that exceeds the limit, a 422 status code will be returned. func ComplexityLimitFunc(f graphql.ComplexityLimitFunc) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { graphql.GetRequestContext(ctx).ComplexityLimit = f(ctx) next(ctx, writer) } diff --git a/graphql/handler/errors.go b/graphql/handler/errors.go index f4c9f97c47..a238fd3650 100644 --- a/graphql/handler/errors.go +++ b/graphql/handler/errors.go @@ -3,6 +3,8 @@ package handler import ( "context" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/99designs/gqlgen/graphql" ) @@ -11,7 +13,7 @@ import ( // implementation in graphql.DefaultErrorPresenter for an example. func ErrorPresenter(ep graphql.ErrorPresenterFunc) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { graphql.GetRequestContext(ctx).ErrorPresenter = ep next(ctx, writer) } @@ -22,7 +24,7 @@ func ErrorPresenter(ep graphql.ErrorPresenterFunc) Middleware { // and hide internal error types from clients. func RecoverFunc(recover graphql.RecoverFunc) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { graphql.GetRequestContext(ctx).Recover = recover next(ctx, writer) } diff --git a/graphql/handler/introspection.go b/graphql/handler/introspection.go index 279d8fdb0b..0d435ba72f 100644 --- a/graphql/handler/introspection.go +++ b/graphql/handler/introspection.go @@ -3,13 +3,15 @@ package handler import ( "context" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/99designs/gqlgen/graphql" ) // Introspection enables clients to reflect all of the types available on the graph. func Introspection() Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { graphql.GetRequestContext(ctx).DisableIntrospection = false next(ctx, writer) } diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 7675760ca2..091b0f4893 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/vektah/gqlparser/validator" "github.com/99designs/gqlgen/graphql" @@ -17,39 +19,20 @@ import ( type ( Server struct { es graphql.ExecutableSchema - transports []Transport + transports []transport.Transport middlewares []Middleware } - Handler func(ctx context.Context, writer Writer) - - Writer func(*graphql.Response) + Handler func(ctx context.Context, writer transport.Writer) Middleware func(next Handler) Handler - Transport interface { - Supports(r *http.Request) bool - Do(w http.ResponseWriter, r *http.Request) (*graphql.RequestContext, Writer) - } - Option func(Server) ResponseStream func() *graphql.Response ) -func (w Writer) Errorf(format string, args ...interface{}) { - w(&graphql.Response{ - Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}}, - }) -} - -func (w Writer) Error(msg string) { - w(&graphql.Response{ - Errors: gqlerror.List{{Message: msg}}, - }) -} - -func (s *Server) AddTransport(transport Transport) { +func (s *Server) AddTransport(transport transport.Transport) { s.transports = append(s.transports, transport) } @@ -63,7 +46,7 @@ func New(es graphql.ExecutableSchema) *Server { } } -func (s *Server) getTransport(r *http.Request) Transport { +func (s *Server) getTransport(r *http.Request) transport.Transport { for _, t := range s.transports { if t.Supports(r) { return t @@ -96,7 +79,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // executableSchemaHandler is the inner most handler, it invokes the graph directly after all middleware // and sends responses to the transport so it can be returned to the client -func (s *Server) executableSchemaHandler(ctx context.Context, write Writer) { +func (s *Server) executableSchemaHandler(ctx context.Context, write transport.Writer) { rc := graphql.GetRequestContext(ctx) var gerr *gqlerror.Error diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index ad5f0fe60d..d5919f0c25 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/stretchr/testify/assert" "github.com/vektah/gqlparser/ast" ) @@ -31,11 +32,14 @@ func TestServer(t *testing.T) { return &graphql.Response{Data: []byte(`"subscription resp"`)} } }, + SchemaFunc: func() *ast.Schema { + return &ast.Schema{} + }, } srv := New(es) - srv.AddTransport(&HTTPGet{}) + srv.AddTransport(&transport.HTTPGet{}) srv.Use(func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { next(ctx, writer) } }) @@ -67,13 +71,13 @@ func TestServer(t *testing.T) { t.Run("invokes middleware in order", func(t *testing.T) { var calls []string srv.Use(func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { calls = append(calls, "first") next(ctx, writer) } }) srv.Use(func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { calls = append(calls, "second") next(ctx, writer) } diff --git a/graphql/handler/tracer.go b/graphql/handler/tracer.go index 4ab1baed40..ccdb605681 100644 --- a/graphql/handler/tracer.go +++ b/graphql/handler/tracer.go @@ -3,6 +3,8 @@ package handler import ( "context" + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/99designs/gqlgen/graphql" ) @@ -10,7 +12,7 @@ import ( // calling resolver. This is useful for tracing func Tracer(tracer graphql.Tracer) Middleware { return func(next Handler) Handler { - return func(ctx context.Context, writer Writer) { + return func(ctx context.Context, writer transport.Writer) { rc := graphql.GetRequestContext(ctx) rc.AddTracer(tracer) rc.AddRequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { diff --git a/graphql/handler/http_get.go b/graphql/handler/transport/http_get.go similarity index 85% rename from graphql/handler/http_get.go rename to graphql/handler/transport/http_get.go index c31301dc60..15720b2b30 100644 --- a/graphql/handler/http_get.go +++ b/graphql/handler/transport/http_get.go @@ -1,4 +1,4 @@ -package handler +package transport import ( "encoding/json" @@ -24,16 +24,24 @@ func (H HTTPGet) Do(w http.ResponseWriter, r *http.Request) (*graphql.RequestCon reqParams.RawQuery = r.URL.Query().Get("query") reqParams.OperationName = r.URL.Query().Get("operationName") + writer := Writer(func(response *graphql.Response) { + b, err := json.Marshal(response) + if err != nil { + panic(err) + } + w.Write(b) + }) + if variables := r.URL.Query().Get("variables"); variables != "" { if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { - sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") + writer.Errorf("variables could not be decoded") return nil, nil } } if extensions := r.URL.Query().Get("extensions"); extensions != "" { if err := jsonDecode(strings.NewReader(extensions), &reqParams.Extensions); err != nil { - sendErrorf(w, http.StatusBadRequest, "extensions could not be decoded") + writer.Errorf("extensions could not be decoded") return nil, nil } } @@ -43,13 +51,7 @@ func (H HTTPGet) Do(w http.ResponseWriter, r *http.Request) (*graphql.RequestCon // return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} //} - return reqParams, func(response *graphql.Response) { - b, err := json.Marshal(response) - if err != nil { - panic(err) - } - w.Write(b) - } + return reqParams, writer } func jsonDecode(r io.Reader, val interface{}) error { diff --git a/graphql/handler/jsonpost.go b/graphql/handler/transport/jsonpost.go similarity index 98% rename from graphql/handler/jsonpost.go rename to graphql/handler/transport/jsonpost.go index f9395826b1..4f103a7c9f 100644 --- a/graphql/handler/jsonpost.go +++ b/graphql/handler/transport/jsonpost.go @@ -1,4 +1,4 @@ -package handler +package transport import ( "encoding/json" diff --git a/graphql/handler/jsonpost_test.go b/graphql/handler/transport/jsonpost_test.go similarity index 96% rename from graphql/handler/jsonpost_test.go rename to graphql/handler/transport/jsonpost_test.go index 0031098236..3c6d83d767 100644 --- a/graphql/handler/jsonpost_test.go +++ b/graphql/handler/transport/jsonpost_test.go @@ -1,4 +1,4 @@ -package handler +package transport_test import ( "context" @@ -9,6 +9,8 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler" + "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/stretchr/testify/assert" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" @@ -33,8 +35,8 @@ func TestJsonPost(t *testing.T) { `}) }, } - h := New(es) - h.AddTransport(JsonPostTransport{}) + h := handler.New(es) + h.AddTransport(transport.JsonPostTransport{}) t.Run("success", func(t *testing.T) { resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) diff --git a/graphql/handler/requestcontext.go b/graphql/handler/transport/requestcontext.go similarity index 95% rename from graphql/handler/requestcontext.go rename to graphql/handler/transport/requestcontext.go index 98568232c6..b5e443d068 100644 --- a/graphql/handler/requestcontext.go +++ b/graphql/handler/transport/requestcontext.go @@ -1,4 +1,4 @@ -package handler +package transport import "github.com/99designs/gqlgen/graphql" diff --git a/graphql/handler/transport/transport.go b/graphql/handler/transport/transport.go new file mode 100644 index 0000000000..ecb62d7286 --- /dev/null +++ b/graphql/handler/transport/transport.go @@ -0,0 +1,29 @@ +package transport + +import ( + "fmt" + "net/http" + + "github.com/99designs/gqlgen/graphql" + "github.com/vektah/gqlparser/gqlerror" +) + +type ( + Transport interface { + Supports(r *http.Request) bool + Do(w http.ResponseWriter, r *http.Request) (*graphql.RequestContext, Writer) + } + Writer func(*graphql.Response) +) + +func (w Writer) Errorf(format string, args ...interface{}) { + w(&graphql.Response{ + Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}}, + }) +} + +func (w Writer) Error(msg string) { + w(&graphql.Response{ + Errors: gqlerror.List{{Message: msg}}, + }) +} diff --git a/graphql/handler/utils_test.go b/graphql/handler/utils_test.go index 3ec9dae658..2c80698376 100644 --- a/graphql/handler/utils_test.go +++ b/graphql/handler/utils_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/99designs/gqlgen/graphql" + "github.com/99designs/gqlgen/graphql/handler/transport" ) type middlewareContext struct { @@ -19,7 +20,7 @@ func testMiddleware(m Middleware, initialContexts ...graphql.RequestContext) mid initial = &initialContexts[0] } - m(func(ctx context.Context, writer Writer) { + m(func(ctx context.Context, writer transport.Writer) { c.ResultContext = *graphql.GetRequestContext(ctx) c.InvokedNext = true })(graphql.WithRequestContext(context.Background(), initial), func(response *graphql.Response) {