diff --git a/client/client.go b/client/client.go index fefe7948d3..903e26cef4 100644 --- a/client/client.go +++ b/client/client.go @@ -1,4 +1,5 @@ // client is used internally for testing. See readme for alternatives + package client import ( @@ -7,82 +8,63 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httptest" "github.com/mitchellh/mapstructure" ) -// Client for graphql requests -type Client struct { - url string - client *http.Client -} - -// New creates a graphql client -func New(url string, client ...*http.Client) *Client { - p := &Client{ - url: url, +type ( + // Client used for testing GraphQL servers. Not for production use. + Client struct { + h http.Handler + opts []Option } - if len(client) > 0 { - p.client = client[0] - } else { - p.client = http.DefaultClient + // Option implements a visitor that mutates an outgoing GraphQL request + // + // This is the Option pattern - https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis + Option func(bd *Request) + + // Request represents an outgoing GraphQL request + Request struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables,omitempty"` + OperationName string `json:"operationName,omitempty"` + HTTP *http.Request `json:"-"` } - return p -} - -type Request struct { - Query string `json:"query"` - Variables map[string]interface{} `json:"variables,omitempty"` - OperationName string `json:"operationName,omitempty"` -} - -type Option func(r *Request) - -func Var(name string, value interface{}) Option { - return func(r *Request) { - if r.Variables == nil { - r.Variables = map[string]interface{}{} - } - r.Variables[name] = value + // Response is a GraphQL layer response from a handler. + Response struct { + Data interface{} + Errors json.RawMessage + Extensions map[string]interface{} } -} +) -func Operation(name string) Option { - return func(r *Request) { - r.OperationName = name +// New creates a graphql client +// Options can be set that should be applied to all requests made with this client +func New(h http.Handler, opts ...Option) *Client { + p := &Client{ + h: h, + opts: opts, } + + return p } +// MustPost is a convenience wrapper around Post that automatically panics on error func (p *Client) MustPost(query string, response interface{}, options ...Option) { if err := p.Post(query, response, options...); err != nil { panic(err) } } -func (p *Client) mkRequest(query string, options ...Option) Request { - r := Request{ - Query: query, - } - - for _, option := range options { - option(&r) - } - - return r -} - -type ResponseData struct { - Data interface{} - Errors json.RawMessage - Extensions map[string]interface{} -} - -func (p *Client) Post(query string, response interface{}, options ...Option) (resperr error) { - respDataRaw, resperr := p.RawPost(query, options...) - if resperr != nil { - return resperr +// Post sends a http POST request to the graphql endpoint with the given query then unpacks +// the response into the given object. +func (p *Client) Post(query string, response interface{}, options ...Option) error { + respDataRaw, err := p.RawPost(query, options...) + if err != nil { + return err } // we want to unpack even if there is an error, so we can see partial responses @@ -94,35 +76,26 @@ func (p *Client) Post(query string, response interface{}, options ...Option) (re return unpackErr } -func (p *Client) RawPost(query string, options ...Option) (*ResponseData, error) { - r := p.mkRequest(query, options...) - requestBody, err := json.Marshal(r) +// RawPost is similar to Post, except it skips decoding the raw json response +// unpacked onto Response. This is used to test extension keys which are not +// available when using Post. +func (p *Client) RawPost(query string, options ...Option) (*Response, error) { + r, err := p.newRequest(query, options...) if err != nil { - return nil, fmt.Errorf("encode: %s", err.Error()) + return nil, fmt.Errorf("build: %s", err.Error()) } - rawResponse, err := p.client.Post(p.url, "application/json", bytes.NewBuffer(requestBody)) - if err != nil { - return nil, fmt.Errorf("post: %s", err.Error()) - } - defer func() { - _ = rawResponse.Body.Close() - }() + w := httptest.NewRecorder() + p.h.ServeHTTP(w, r) - if rawResponse.StatusCode >= http.StatusBadRequest { - responseBody, _ := ioutil.ReadAll(rawResponse.Body) - return nil, fmt.Errorf("http %d: %s", rawResponse.StatusCode, responseBody) - } - - responseBody, err := ioutil.ReadAll(rawResponse.Body) - if err != nil { - return nil, fmt.Errorf("read: %s", err.Error()) + if w.Code >= http.StatusBadRequest { + return nil, fmt.Errorf("http %d: %s", w.Code, w.Body.String()) } // decode it into map string first, let mapstructure do the final decode // because it can be much stricter about unknown fields. - respDataRaw := &ResponseData{} - err = json.Unmarshal(responseBody, &respDataRaw) + respDataRaw := &Response{} + err = json.Unmarshal(w.Body.Bytes(), &respDataRaw) if err != nil { return nil, fmt.Errorf("decode: %s", err.Error()) } @@ -130,12 +103,34 @@ func (p *Client) RawPost(query string, options ...Option) (*ResponseData, error) return respDataRaw, nil } -type RawJsonError struct { - json.RawMessage -} +func (p *Client) newRequest(query string, options ...Option) (*http.Request, error) { + bd := &Request{ + Query: query, + HTTP: httptest.NewRequest(http.MethodPost, "/", nil), + } + bd.HTTP.Header.Set("Content-Type", "application/json") + + // per client options from client.New apply first + for _, option := range options { + option(bd) + } + // per request options + for _, option := range options { + option(bd) + } + + switch bd.HTTP.Header.Get("Content-Type") { + case "application/json": + requestBody, err := json.Marshal(bd) + if err != nil { + return nil, fmt.Errorf("encode: %s", err.Error()) + } + bd.HTTP.Body = ioutil.NopCloser(bytes.NewBuffer(requestBody)) + default: + panic("unsupported encoding" + bd.HTTP.Header.Get("Content-Type")) + } -func (r RawJsonError) Error() string { - return string(r.RawMessage) + return bd.HTTP, nil } func unpack(data interface{}, into interface{}) error { diff --git a/client/client_test.go b/client/client_test.go index af58feb140..98d907e5c1 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -12,7 +11,7 @@ import ( ) func TestClient(t *testing.T) { - h := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := ioutil.ReadAll(r.Body) if err != nil { panic(err) @@ -27,9 +26,9 @@ func TestClient(t *testing.T) { if err != nil { panic(err) } - })) + }) - c := client.New(h.URL) + c := client.New(h) var resp struct { Name string @@ -39,3 +38,53 @@ func TestClient(t *testing.T) { require.Equal(t, "bob", resp.Name) } + +func TestAddHeader(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "ASDF", r.Header.Get("Test-Key")) + + w.Write([]byte(`{}`)) + }) + + c := client.New(h) + + var resp struct{} + c.MustPost("{ id }", &resp, + client.AddHeader("Test-Key", "ASDF"), + ) +} + +func TestBasicAuth(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + require.True(t, ok) + require.Equal(t, "user", user) + require.Equal(t, "pass", pass) + + w.Write([]byte(`{}`)) + }) + + c := client.New(h) + + var resp struct{} + c.MustPost("{ id }", &resp, + client.BasicAuth("user", "pass"), + ) +} + +func TestAddCookie(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie("foo") + require.NoError(t, err) + require.Equal(t, "value", c.Value) + + w.Write([]byte(`{}`)) + }) + + c := client.New(h) + + var resp struct{} + c.MustPost("{ id }", &resp, + client.AddCookie(&http.Cookie{Name: "foo", Value: "value"}), + ) +} diff --git a/client/errors.go b/client/errors.go new file mode 100644 index 0000000000..7bbdd41735 --- /dev/null +++ b/client/errors.go @@ -0,0 +1,12 @@ +package client + +import "encoding/json" + +// RawJsonError is a json formatted error from a GraphQL server. +type RawJsonError struct { + json.RawMessage +} + +func (r RawJsonError) Error() string { + return string(r.RawMessage) +} diff --git a/client/options.go b/client/options.go new file mode 100644 index 0000000000..e600f38285 --- /dev/null +++ b/client/options.go @@ -0,0 +1,50 @@ +package client + +import "net/http" + +// Var adds a variable into the outgoing request +func Var(name string, value interface{}) Option { + return func(bd *Request) { + if bd.Variables == nil { + bd.Variables = map[string]interface{}{} + } + + bd.Variables[name] = value + } +} + +// Operation sets the operation name for the outgoing request +func Operation(name string) Option { + return func(bd *Request) { + bd.OperationName = name + } +} + +// Path sets the url that this request will be made against, useful if you are mounting your entire router +// and need to specify the url to the graphql endpoint. +func Path(url string) Option { + return func(bd *Request) { + bd.HTTP.URL.Path = url + } +} + +// AddHeader adds a header to the outgoing request. This is useful for setting expected Authentication headers for example. +func AddHeader(key string, value string) Option { + return func(bd *Request) { + bd.HTTP.Header.Add(key, value) + } +} + +// BasicAuth authenticates the request using http basic auth. +func BasicAuth(username, password string) Option { + return func(bd *Request) { + bd.HTTP.SetBasicAuth(username, password) + } +} + +// AddCookie adds a cookie to the outgoing request +func AddCookie(cookie *http.Cookie) Option { + return func(bd *Request) { + bd.HTTP.AddCookie(cookie) + } +} diff --git a/client/websocket.go b/client/websocket.go index b036c3176e..a1fdc8fb2e 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -3,6 +3,8 @@ package client import ( "encoding/json" "fmt" + "io/ioutil" + "net/http/httptest" "strings" "github.com/gorilla/websocket" @@ -13,7 +15,7 @@ const ( connectionInitMsg = "connection_init" // Client -> Server startMsg = "start" // Client -> Server connectionAckMsg = "connection_ack" // Server -> Client - connectionKa = "ka" // Server -> Client + connectionKaMsg = "ka" // Server -> Client dataMsg = "data" // Server -> Client errorMsg = "error" // Server -> Client ) @@ -43,16 +45,23 @@ func (p *Client) Websocket(query string, options ...Option) *Subscription { } func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription { - r := p.mkRequest(query, options...) - requestBody, err := json.Marshal(r) + r, err := p.newRequest(query, options...) if err != nil { - return errorSubscription(fmt.Errorf("encode: %s", err.Error())) + return errorSubscription(fmt.Errorf("request: %s", err.Error())) } + r.Header.Set("Host", "99designs.com") - url := strings.Replace(p.url, "http://", "ws://", -1) + requestBody, err := ioutil.ReadAll(r.Body) + if err != nil { + return errorSubscription(fmt.Errorf("parse body: %s", err.Error())) + } + + srv := httptest.NewServer(p.h) + url := strings.Replace(srv.URL, "http://", "ws://", -1) url = strings.Replace(url, "https://", "wss://", -1) - c, resp, err := websocket.DefaultDialer.Dial(url, nil) + c, _, err := websocket.DefaultDialer.Dial(url, r.Header) + if err != nil { return errorSubscription(fmt.Errorf("dial: %s", err.Error())) } @@ -80,11 +89,11 @@ func (p *Client) WebsocketWithPayload(query string, initPayload map[string]inter var ka operationMessage if err = c.ReadJSON(&ka); err != nil { - return errorSubscription(fmt.Errorf("ka: %s", err.Error())) + return errorSubscription(fmt.Errorf("ack: %s", err.Error())) } - if ka.Type != connectionKa { - return errorSubscription(fmt.Errorf("expected ka message, got %#v", ack)) + if ka.Type != connectionKaMsg { + return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) } if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil { @@ -93,9 +102,8 @@ func (p *Client) WebsocketWithPayload(query string, initPayload map[string]inter return &Subscription{ Close: func() error { - c.Close() - resp.Body.Close() - return nil + srv.Close() + return c.Close() }, Next: func(response interface{}) error { var op operationMessage diff --git a/codegen/testserver/complexity_test.go b/codegen/testserver/complexity_test.go index 9caca33f0c..a48b74303b 100644 --- a/codegen/testserver/complexity_test.go +++ b/codegen/testserver/complexity_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -13,8 +12,7 @@ import ( func TestComplexityCollisions(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) { return &OverlappingFields{ @@ -50,8 +48,7 @@ func TestComplexityFuncs(t *testing.T) { cfg.Complexity.OverlappingFields.Foo = func(childComplexity int) int { return 1000 } cfg.Complexity.OverlappingFields.NewFoo = func(childComplexity int) int { return 5 } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(cfg), handler.ComplexityLimit(10))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(cfg), handler.ComplexityLimit(10))) resolvers.QueryResolver.Overlapping = func(ctx context.Context) (fields *OverlappingFields, e error) { return &OverlappingFields{ diff --git a/codegen/testserver/directive_test.go b/codegen/testserver/directive_test.go index 376dc66951..579a3939a4 100644 --- a/codegen/testserver/directive_test.go +++ b/codegen/testserver/directive_test.go @@ -3,7 +3,6 @@ package testserver import ( "context" "fmt" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -72,7 +71,7 @@ func TestDirectives(t *testing.T) { return &s, nil } - srv := httptest.NewServer( + srv := handler.GraphQL( NewExecutableSchema(Config{ Resolvers: resolvers, @@ -160,8 +159,8 @@ func TestDirectives(t *testing.T) { path, _ := ctx.Value("path").([]int) return next(context.WithValue(ctx, "path", append(path, 2))) }), - )) - c := client.New(srv.URL) + ) + c := client.New(srv) t.Run("arg directives", func(t *testing.T) { t.Run("when function errors on directives", func(t *testing.T) { diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index bcc00d21e5..38362c131f 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -6,7 +6,6 @@ package testserver import ( "context" "net/http" - "net/http/httptest" "reflect" "testing" @@ -40,8 +39,7 @@ func TestUnionFragments(t *testing.T) { return &Circle{Radius: 32}, nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("inline fragment on union", func(t *testing.T) { var resp struct { diff --git a/codegen/testserver/input_test.go b/codegen/testserver/input_test.go index 7e8f1f3b77..851d1f5eec 100644 --- a/codegen/testserver/input_test.go +++ b/codegen/testserver/input_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -13,9 +12,7 @@ import ( func TestInput(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("when function errors on directives", func(t *testing.T) { resolvers.QueryResolver.InputSlice = func(ctx context.Context, arg []string) (b bool, e error) { diff --git a/codegen/testserver/introspection_test.go b/codegen/testserver/introspection_test.go index 546390749a..9717879866 100644 --- a/codegen/testserver/introspection_test.go +++ b/codegen/testserver/introspection_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -16,14 +15,10 @@ func TestIntrospection(t *testing.T) { t.Run("disabled", func(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.IntrospectionEnabled(false), - ), - ) - - c := client.New(srv.URL) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.IntrospectionEnabled(false), + )) var resp interface{} err := c.Post(introspection.Query, &resp) @@ -33,13 +28,9 @@ func TestIntrospection(t *testing.T) { t.Run("enabled by default", func(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - ), - ) - - c := client.New(srv.URL) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + )) var resp interface{} err := c.Post(introspection.Query, &resp) @@ -55,7 +46,6 @@ func TestIntrospection(t *testing.T) { } }` - c := client.New(srv.URL) var resp struct { Type struct { Fields []struct { @@ -75,18 +65,14 @@ func TestIntrospection(t *testing.T) { t.Run("disabled by middleware", func(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { - graphql.GetRequestContext(ctx).DisableIntrospection = true - - return next(ctx) - }), - ), - ) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { + graphql.GetRequestContext(ctx).DisableIntrospection = true - c := client.New(srv.URL) + return next(ctx) + }), + )) var resp interface{} err := c.Post(introspection.Query, &resp) diff --git a/codegen/testserver/maps_test.go b/codegen/testserver/maps_test.go index e2490909a7..f77134b016 100644 --- a/codegen/testserver/maps_test.go +++ b/codegen/testserver/maps_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -16,12 +15,9 @@ func TestMaps(t *testing.T) { return in, nil } - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolver}), - )) - defer srv.Close() - c := client.New(srv.URL) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolver}), + )) t.Run("unset", func(t *testing.T) { var resp struct { MapStringInterface map[string]interface{} diff --git a/codegen/testserver/middleware_test.go b/codegen/testserver/middleware_test.go index 26ec8e1c62..ebfd0b217b 100644 --- a/codegen/testserver/middleware_test.go +++ b/codegen/testserver/middleware_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -28,24 +27,23 @@ func TestMiddleware(t *testing.T) { } areMethods := []bool{} - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 1))) - }), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 2))) - }), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - areMethods = append(areMethods, graphql.GetResolverContext(ctx).IsMethod) - return next(ctx) - }), - )) + srv := handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 1))) + }), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 2))) + }), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + areMethods = append(areMethods, graphql.GetResolverContext(ctx).IsMethod) + return next(ctx) + }), + ) - c := client.New(srv.URL) + c := client.New(srv) var resp struct { User struct { diff --git a/codegen/testserver/modelmethod_test.go b/codegen/testserver/modelmethod_test.go index d484aa745a..6174c15eab 100644 --- a/codegen/testserver/modelmethod_test.go +++ b/codegen/testserver/modelmethod_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -19,12 +18,9 @@ func TestModelMethods(t *testing.T) { return true, nil } - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolver}), - )) - defer srv.Close() - c := client.New(srv.URL) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolver}), + )) t.Run("without context", func(t *testing.T) { var resp struct { ModelMethods struct { diff --git a/codegen/testserver/nulls_test.go b/codegen/testserver/nulls_test.go index a9d502a678..eac27a9847 100644 --- a/codegen/testserver/nulls_test.go +++ b/codegen/testserver/nulls_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -23,8 +22,7 @@ func TestNullBubbling(t *testing.T) { return &Error{ID: "E1234"}, nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("when function errors on non required field", func(t *testing.T) { var resp struct { diff --git a/codegen/testserver/panics_test.go b/codegen/testserver/panics_test.go index 1972dffed9..92319ed4ab 100644 --- a/codegen/testserver/panics_test.go +++ b/codegen/testserver/panics_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/stretchr/testify/require" @@ -23,8 +22,7 @@ func TestPanics(t *testing.T) { return []MarshalPanic{MarshalPanic("aa"), MarshalPanic("bb")}, nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("panics in marshallers will not kill server", func(t *testing.T) { var resp interface{} diff --git a/codegen/testserver/primitive_objects_test.go b/codegen/testserver/primitive_objects_test.go index 6ceef5c8f7..cef6a84b2a 100644 --- a/codegen/testserver/primitive_objects_test.go +++ b/codegen/testserver/primitive_objects_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -20,8 +19,7 @@ func TestPrimitiveObjects(t *testing.T) { return int(*obj), nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("can fetch value", func(t *testing.T) { var resp struct { @@ -53,8 +51,7 @@ func TestPrimitiveStringObjects(t *testing.T) { return len(string(*obj)), nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("can fetch value", func(t *testing.T) { var resp struct { diff --git a/codegen/testserver/response_extension_test.go b/codegen/testserver/response_extension_test.go index 026aa31716..985c7aa09e 100644 --- a/codegen/testserver/response_extension_test.go +++ b/codegen/testserver/response_extension_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -17,7 +16,7 @@ func TestResponseExtension(t *testing.T) { return "Ok", nil } - srv := httptest.NewServer(handler.GraphQL( + srv := handler.GraphQL( NewExecutableSchema(Config{Resolvers: resolvers}), handler.RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { rctx := graphql.GetRequestContext(ctx) @@ -26,8 +25,8 @@ func TestResponseExtension(t *testing.T) { } return next(ctx) }), - )) - c := client.New(srv.URL) + ) + c := client.New(srv) raw, _ := c.RawPost(`query { valid }`) require.Equal(t, raw.Extensions["example"], "value") diff --git a/codegen/testserver/scalar_default_test.go b/codegen/testserver/scalar_default_test.go index 4b438560ce..d891737b78 100644 --- a/codegen/testserver/scalar_default_test.go +++ b/codegen/testserver/scalar_default_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -13,8 +12,7 @@ import ( func TestDefaultScalarImplementation(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) resolvers.QueryResolver.DefaultScalar = func(ctx context.Context, arg string) (i string, e error) { return arg, nil diff --git a/codegen/testserver/slices_test.go b/codegen/testserver/slices_test.go index e67dd6fdb4..6d314b71bb 100644 --- a/codegen/testserver/slices_test.go +++ b/codegen/testserver/slices_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -13,8 +12,7 @@ import ( func TestSlices(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("nulls vs empty slices", func(t *testing.T) { resolvers.QueryResolver.Slices = func(ctx context.Context) (slices *Slices, e error) { diff --git a/codegen/testserver/subscription_test.go b/codegen/testserver/subscription_test.go index 16fdc8007e..9e34599155 100644 --- a/codegen/testserver/subscription_test.go +++ b/codegen/testserver/subscription_test.go @@ -3,7 +3,6 @@ package testserver import ( "context" "fmt" - "net/http/httptest" "runtime" "sort" "testing" @@ -67,19 +66,18 @@ func TestSubscriptions(t *testing.T) { return res, nil } - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 1))) - }), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 2))) - }), - )) - c := client.New(srv.URL) + srv := handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 1))) + }), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 2))) + }), + ) + c := client.New(srv) t.Run("wont leak goroutines", func(t *testing.T) { runtime.GC() // ensure no go-routines left from preceding tests diff --git a/codegen/testserver/time_test.go b/codegen/testserver/time_test.go index fc293563e7..1b41df3aea 100644 --- a/codegen/testserver/time_test.go +++ b/codegen/testserver/time_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "time" @@ -14,8 +13,7 @@ import ( func TestTime(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) resolvers.QueryResolver.User = func(ctx context.Context, id int) (user *User, e error) { return &User{}, nil diff --git a/codegen/testserver/tracer_test.go b/codegen/testserver/tracer_test.go index 461e6426d7..07c39b080f 100644 --- a/codegen/testserver/tracer_test.go +++ b/codegen/testserver/tracer_test.go @@ -3,7 +3,6 @@ package testserver import ( "context" "fmt" - "net/http/httptest" "sync" "testing" @@ -23,36 +22,34 @@ func TestTracer(t *testing.T) { var tracerLog []string var mu sync.Mutex - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 1))) - }), - handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - path, _ := ctx.Value("path").([]int) - return next(context.WithValue(ctx, "path", append(path, 2))) - }), - handler.Tracer(&testTracer{ - id: 1, - append: func(s string) { - mu.Lock() - defer mu.Unlock() - tracerLog = append(tracerLog, s) - }, - }), - handler.Tracer(&testTracer{ - id: 2, - append: func(s string) { - mu.Lock() - defer mu.Unlock() - tracerLog = append(tracerLog, s) - }, - }), - )) - defer srv.Close() - c := client.New(srv.URL) + srv := handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 1))) + }), + handler.ResolverMiddleware(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + path, _ := ctx.Value("path").([]int) + return next(context.WithValue(ctx, "path", append(path, 2))) + }), + handler.Tracer(&testTracer{ + id: 1, + append: func(s string) { + mu.Lock() + defer mu.Unlock() + tracerLog = append(tracerLog, s) + }, + }), + handler.Tracer(&testTracer{ + id: 2, + append: func(s string) { + mu.Lock() + defer mu.Unlock() + tracerLog = append(tracerLog, s) + }, + }), + ) + c := client.New(srv) var resp struct { User struct { @@ -157,13 +154,10 @@ func TestTracer(t *testing.T) { }, } - srv := httptest.NewServer( - handler.GraphQL( - NewExecutableSchema(Config{Resolvers: resolvers}), - handler.Tracer(configurableTracer), - )) - defer srv.Close() - c := client.New(srv.URL) + c := client.New(handler.GraphQL( + NewExecutableSchema(Config{Resolvers: resolvers}), + handler.Tracer(configurableTracer), + )) var resp struct { User struct { diff --git a/codegen/testserver/typefallback_test.go b/codegen/testserver/typefallback_test.go index 0b0f83135e..ac74f36ac7 100644 --- a/codegen/testserver/typefallback_test.go +++ b/codegen/testserver/typefallback_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -13,8 +12,7 @@ import ( func TestTypeFallback(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) resolvers.QueryResolver.Fallback = func(ctx context.Context, arg FallbackToStringEncoding) (FallbackToStringEncoding, error) { return arg, nil diff --git a/codegen/testserver/validtypes_test.go b/codegen/testserver/validtypes_test.go index 7c6df2fafe..412ebbe41d 100644 --- a/codegen/testserver/validtypes_test.go +++ b/codegen/testserver/validtypes_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -19,8 +18,7 @@ func TestValidType(t *testing.T) { }, nil } - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) t.Run("fields with differing cases can be distinguished", func(t *testing.T) { var resp struct { diff --git a/codegen/testserver/wrapped_type_test.go b/codegen/testserver/wrapped_type_test.go index 0842fa3ba0..4785307f2b 100644 --- a/codegen/testserver/wrapped_type_test.go +++ b/codegen/testserver/wrapped_type_test.go @@ -2,7 +2,6 @@ package testserver import ( "context" - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -14,11 +13,10 @@ import ( func TestWrappedTypes(t *testing.T) { resolvers := &Stub{} - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) resolvers.QueryResolver.WrappedScalar = func(ctx context.Context) (scalar WrappedScalar, e error) { - return WrappedScalar("hello"), nil + return "hello", nil } resolvers.QueryResolver.WrappedStruct = func(ctx context.Context) (wrappedStruct *WrappedStruct, e error) { diff --git a/example/chat/chat_test.go b/example/chat/chat_test.go index 23c6673ef6..c3eff1832f 100644 --- a/example/chat/chat_test.go +++ b/example/chat/chat_test.go @@ -1,7 +1,6 @@ package chat import ( - "net/http/httptest" "testing" "time" @@ -12,8 +11,7 @@ import ( ) func TestChatSubscriptions(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New()))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(New()))) sub := c.Websocket(`subscription @user(username:"vektah") { messageAdded(roomName:"#gophers") { text createdBy } }`) defer sub.Close() diff --git a/example/dataloader/dataloader_test.go b/example/dataloader/dataloader_test.go index 7283f12089..ef20367ad3 100644 --- a/example/dataloader/dataloader_test.go +++ b/example/dataloader/dataloader_test.go @@ -1,7 +1,6 @@ package dataloader import ( - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -11,8 +10,7 @@ import ( ) func TestTodo(t *testing.T) { - srv := httptest.NewServer(LoaderMiddleware(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}})))) - c := client.New(srv.URL) + c := client.New(LoaderMiddleware(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}})))) t.Run("create a new todo", func(t *testing.T) { var resp interface{} diff --git a/example/scalars/scalar_test.go b/example/scalars/scalar_test.go index 15ac529070..8268c9acf8 100644 --- a/example/scalars/scalar_test.go +++ b/example/scalars/scalar_test.go @@ -1,7 +1,6 @@ package scalars import ( - "net/http/httptest" "testing" "time" @@ -22,8 +21,7 @@ type RawUser struct { } func TestScalars(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}}))) t.Run("marshaling", func(t *testing.T) { var resp struct { diff --git a/example/selection/selection_test.go b/example/selection/selection_test.go index a72683a003..9ccf3a150f 100644 --- a/example/selection/selection_test.go +++ b/example/selection/selection_test.go @@ -1,7 +1,6 @@ package selection import ( - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -10,8 +9,7 @@ import ( ) func TestSelection(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}}))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: &Resolver{}}))) query := `{ events { diff --git a/example/starwars/starwars_test.go b/example/starwars/starwars_test.go index 53f33d34f5..bfb22dbf5c 100644 --- a/example/starwars/starwars_test.go +++ b/example/starwars/starwars_test.go @@ -1,7 +1,6 @@ package starwars import ( - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -12,8 +11,7 @@ import ( ) func TestStarwars(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(generated.NewExecutableSchema(NewResolver()))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(generated.NewExecutableSchema(NewResolver()))) t.Run("Lukes starships", func(t *testing.T) { var resp struct { diff --git a/example/todo/todo_test.go b/example/todo/todo_test.go index 5bd188625b..965f62b760 100644 --- a/example/todo/todo_test.go +++ b/example/todo/todo_test.go @@ -1,7 +1,6 @@ package todo import ( - "net/http/httptest" "testing" "github.com/99designs/gqlgen/client" @@ -11,8 +10,7 @@ import ( ) func TestTodo(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New()))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(New()))) var resp struct { CreateTodo struct{ ID string } @@ -182,8 +180,7 @@ func TestTodo(t *testing.T) { } func TestSkipAndIncludeDirectives(t *testing.T) { - srv := httptest.NewServer(handler.GraphQL(NewExecutableSchema(New()))) - c := client.New(srv.URL) + c := client.New(handler.GraphQL(NewExecutableSchema(New()))) t.Run("skip on field", func(t *testing.T) { var resp map[string]interface{} diff --git a/go.mod b/go.mod index 9ff313d865..8df60052cb 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/99designs/gqlgen +go 1.12 + require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-chi/chi v3.3.2+incompatible