diff --git a/graphql/handler/apq.go b/graphql/handler/apq.go new file mode 100644 index 00000000000..ba8c21499c5 --- /dev/null +++ b/graphql/handler/apq.go @@ -0,0 +1,70 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + + "github.com/99designs/gqlgen/graphql" + "github.com/mitchellh/mapstructure" +) + +const ( + errPersistedQueryNotSupported = "PersistedQueryNotSupported" + errPersistedQueryNotFound = "PersistedQueryNotFound" +) + +// AutomaticPersistedQuery saves client upload by optimistically sending only the hashes of queries, if the server +// does not yet know what the query is for the hash it will respond telling the client to send the query along with the +// hash in the next request. +// see https://github.com/apollographql/apollo-link-persisted-queries +func AutomaticPersistedQuery(cache Cache) Middleware { + return func(next Handler) Handler { + return func(ctx context.Context, writer Writer) { + rc := graphql.GetRequestContext(ctx) + + if rc.Extensions["persistedQuery"] == nil { + next(ctx, writer) + return + } + + var extension struct { + Sha256 string `json:"sha256Hash"` + Version int64 `json:"version"` + } + + if err := mapstructure.Decode(rc.Extensions["persistedQuery"], &extension); err != nil { + writer.Error("Invalid APQ extension data") + return + } + + if extension.Version != 1 { + writer.Error("Unsupported APQ version") + return + } + + if rc.RawQuery == "" { + // client sent optimistic query hash without query string, get it from the cache + query, ok := cache.Get(extension.Sha256) + if !ok { + writer.Error(errPersistedQueryNotFound) + return + } + rc.RawQuery = query.(string) + } else { + // client sent optimistic query hash with query string, verify and store it + if computeQueryHash(rc.RawQuery) != extension.Sha256 { + writer.Error("Provided APQ hash does not match query") + return + } + cache.Add(extension.Sha256, rc.RawQuery) + } + next(ctx, writer) + } + } +} + +func computeQueryHash(query string) string { + b := sha256.Sum256([]byte(query)) + return hex.EncodeToString(b[:]) +} diff --git a/graphql/handler/apq_test.go b/graphql/handler/apq_test.go new file mode 100644 index 00000000000..28f2f825862 --- /dev/null +++ b/graphql/handler/apq_test.go @@ -0,0 +1,128 @@ +package handler + +import ( + "testing" + + "github.com/99designs/gqlgen/graphql" + "github.com/stretchr/testify/require" +) + +func TestAPQ(t *testing.T) { + const query = "{ me { name } }" + const hash = "b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88" + + t.Run("with query and no hash", func(t *testing.T) { + rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{ + RawQuery: "original query", + }) + + require.True(t, rc.InvokedNext) + require.Equal(t, "original query", rc.ResultContext.RawQuery) + }) + + t.Run("with hash miss and no query", func(t *testing.T) { + rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{ + RawQuery: "", + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha256": hash, + "version": 1, + }, + }, + }) + + require.False(t, rc.InvokedNext) + require.Equal(t, "PersistedQueryNotFound", rc.Response.Errors[0].Message) + }) + + t.Run("with hash miss and query", func(t *testing.T) { + cache := MapCache{} + rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ + RawQuery: query, + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha256": hash, + "version": 1, + }, + }, + }) + + require.True(t, rc.InvokedNext, rc.Response.Errors) + require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + require.Equal(t, "{ me { name } }", cache[hash]) + }) + + t.Run("with hash miss and query", func(t *testing.T) { + cache := MapCache{} + rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ + RawQuery: query, + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha256": hash, + "version": 1, + }, + }, + }) + + require.True(t, rc.InvokedNext, rc.Response.Errors) + require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + require.Equal(t, "{ me { name } }", cache[hash]) + }) + + t.Run("with hash hit and no query", func(t *testing.T) { + cache := MapCache{ + hash: query, + } + rc := testMiddleware(AutomaticPersistedQuery(cache), graphql.RequestContext{ + RawQuery: "", + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha256": hash, + "version": 1, + }, + }, + }) + + require.True(t, rc.InvokedNext, rc.Response.Errors) + require.Equal(t, "{ me { name } }", rc.ResultContext.RawQuery) + }) + + t.Run("with malformed extension payload", func(t *testing.T) { + rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{ + Extensions: map[string]interface{}{ + "persistedQuery": "asdf", + }, + }) + + require.False(t, rc.InvokedNext) + require.Equal(t, "Invalid APQ extension data", rc.Response.Errors[0].Message) + }) + + t.Run("with invalid extension version", func(t *testing.T) { + rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{ + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "version": 2, + }, + }, + }) + + require.False(t, rc.InvokedNext) + require.Equal(t, "Unsupported APQ version", rc.Response.Errors[0].Message) + }) + + t.Run("with hash mismatch", func(t *testing.T) { + rc := testMiddleware(AutomaticPersistedQuery(MapCache{}), graphql.RequestContext{ + RawQuery: query, + Extensions: map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha256": "badhash", + "version": 1, + }, + }, + }) + + require.False(t, rc.InvokedNext) + require.Equal(t, "Provided APQ hash does not match query", rc.Response.Errors[0].Message) + }) +} diff --git a/graphql/handler/cache.go b/graphql/handler/cache.go new file mode 100644 index 00000000000..dff56702779 --- /dev/null +++ b/graphql/handler/cache.go @@ -0,0 +1,24 @@ +package handler + +// Cache is a shared store for APQ and query AST caching +type Cache interface { + // Get looks up a key's value from the cache. + Get(key string) (value interface{}, ok bool) + + // Add adds a value to the cache. + Add(key, value string) +} + +// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests +type MapCache map[string]interface{} + +// Get looks up a key's value from the cache. +func (m MapCache) Get(key string) (value interface{}, ok bool) { + v, ok := m[key] + return v, ok +} + +// Add adds a value to the cache. +func (m MapCache) Add(key, value string) { + m[key] = value +} diff --git a/graphql/handler/complexity_test.go b/graphql/handler/complexity_test.go index 8ee346fa38d..c53e858a364 100644 --- a/graphql/handler/complexity_test.go +++ b/graphql/handler/complexity_test.go @@ -13,7 +13,7 @@ func TestComplexityLimit(t *testing.T) { })) require.True(t, rc.InvokedNext) - require.Equal(t, 10, rc.ComplexityLimit) + require.Equal(t, 10, rc.ResultContext.ComplexityLimit) } func TestComplexityLimitFunc(t *testing.T) { @@ -22,5 +22,5 @@ func TestComplexityLimitFunc(t *testing.T) { })) require.True(t, rc.InvokedNext) - require.Equal(t, 22, rc.ComplexityLimit) + require.Equal(t, 22, rc.ResultContext.ComplexityLimit) } diff --git a/graphql/handler/errors_test.go b/graphql/handler/errors_test.go index 683d2a0bd2d..2ccacba9386 100644 --- a/graphql/handler/errors_test.go +++ b/graphql/handler/errors_test.go @@ -18,7 +18,7 @@ func TestErrorPresenter(t *testing.T) { require.True(t, rc.InvokedNext) // cant test for function equality in go, so testing the return type instead - require.Equal(t, "boom", rc.ErrorPresenter(nil, nil).Message) + require.Equal(t, "boom", rc.ResultContext.ErrorPresenter(nil, nil).Message) } func TestRecoverFunc(t *testing.T) { @@ -28,5 +28,5 @@ func TestRecoverFunc(t *testing.T) { require.True(t, rc.InvokedNext) // cant test for function equality in go, so testing the return type instead - assert.Equal(t, "boom", rc.Recover(nil, nil).Error()) + assert.Equal(t, "boom", rc.ResultContext.Recover(nil, nil).Error()) } diff --git a/graphql/handler/introspection_test.go b/graphql/handler/introspection_test.go index d4b65236b0f..e40de31cb04 100644 --- a/graphql/handler/introspection_test.go +++ b/graphql/handler/introspection_test.go @@ -15,5 +15,5 @@ func TestIntrospection(t *testing.T) { require.True(t, rc.InvokedNext) // cant test for function equality in go, so testing the return type instead - assert.False(t, rc.DisableIntrospection) + assert.False(t, rc.ResultContext.DisableIntrospection) } diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 9bff842c9dd..6fd06ab3787 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -35,6 +35,18 @@ type ( ResponseStream func() *graphql.Response ) +func (w Writer) Errorf(format string, args ...interface{}) { + w(&graphql.Response{ + Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}}, + }) +} + +func (w Writer) Error(msg string) { + w(&graphql.Response{ + Errors: gqlerror.List{{Message: msg}}, + }) +} + func (s *Server) AddTransport(transport Transport) { s.transports = append(s.transports, transport) } diff --git a/graphql/handler/utils_test.go b/graphql/handler/utils_test.go index f5feae8318f..3ec9dae658f 100644 --- a/graphql/handler/utils_test.go +++ b/graphql/handler/utils_test.go @@ -7,24 +7,24 @@ import ( ) type middlewareContext struct { - *graphql.RequestContext - InvokedNext bool + InvokedNext bool + ResultContext graphql.RequestContext + Response graphql.Response } func testMiddleware(m Middleware, initialContexts ...graphql.RequestContext) middlewareContext { - rc := &graphql.RequestContext{} + var c middlewareContext + initial := &graphql.RequestContext{} if len(initialContexts) > 0 { - rc = &initialContexts[0] + initial = &initialContexts[0] } m(func(ctx context.Context, writer Writer) { - rc = graphql.GetRequestContext(ctx) - })(graphql.WithRequestContext(context.Background(), rc), noopWriter) + c.ResultContext = *graphql.GetRequestContext(ctx) + c.InvokedNext = true + })(graphql.WithRequestContext(context.Background(), initial), func(response *graphql.Response) { + c.Response = *response + }) - return middlewareContext{ - InvokedNext: rc != nil, - RequestContext: rc, - } + return c } - -func noopWriter(response *graphql.Response) {} diff --git a/handler/graphql.go b/handler/graphql.go index ed0317a3f2e..28f4a9e7030 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -2,8 +2,6 @@ package handler import ( "context" - "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -42,16 +40,6 @@ type persistedQuery struct { Version int64 `json:"version"` } -const ( - errPersistedQueryNotSupported = "PersistedQueryNotSupported" - errPersistedQueryNotFound = "PersistedQueryNotFound" -) - -type PersistedQueryCache interface { - Add(ctx context.Context, hash string, query string) - Get(ctx context.Context, hash string) (string, bool) -} - type websocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) type Config struct { @@ -69,7 +57,6 @@ type Config struct { connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 uploadMaxSize int64 - apqCache PersistedQueryCache } func (c *Config) newRequestContext(ctx context.Context, es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, operationName, query string, variables map[string]interface{}) (*graphql.RequestContext, error) { @@ -271,13 +258,6 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option { } } -// Add cache that will hold queries for automatic persisted queries (APQ) -func EnablePersistedQueryCache(cache PersistedQueryCache) Option { - return func(cfg *Config) { - cfg.apqCache = cache - } -} - const DefaultCacheSize = 1000 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second @@ -337,11 +317,6 @@ type graphqlHandler struct { exec graphql.ExecutableSchema } -func computeQueryHash(query string) string { - b := sha256.Sum256([]byte(query)) - return hex.EncodeToString(b[:]) -} - func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -414,37 +389,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - var queryHash string - apqRegister := false - apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil - if apq { - // client has enabled apq - queryHash = reqParams.Extensions.PersistedQuery.Sha256 - if gh.cfg.apqCache == nil { - // server has disabled apq - sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported) - return - } - if reqParams.Extensions.PersistedQuery.Version != 1 { - sendErrorf(w, http.StatusOK, "Unsupported persisted query version") - return - } - if reqParams.Query == "" { - // client sent optimistic query hash without query string - query, ok := gh.cfg.apqCache.Get(ctx, queryHash) - if !ok { - sendErrorf(w, http.StatusOK, errPersistedQueryNotFound) - return - } - reqParams.Query = query - } else { - if computeQueryHash(reqParams.Query) != queryHash { - sendErrorf(w, http.StatusOK, "provided sha does not match query") - return - } - apqRegister = true - } - } else if reqParams.Query == "" { + if reqParams.Query == "" { sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string") return } @@ -507,11 +452,6 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if apqRegister && gh.cfg.apqCache != nil { - // Add to persisted query cache - gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query) - } - switch op.Operation { case ast.Query: b, err := json.Marshal(gh.exec.Query(ctx, op)) diff --git a/handler/graphql_test.go b/handler/graphql_test.go index 06ba718429a..c932e9a0729 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -15,7 +15,6 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" - lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vektah/gqlparser/ast" @@ -765,68 +764,3 @@ func TestBytesRead(t *testing.T) { require.Equal(t, "0193456789", string(got)) }) } - -type memoryPersistedQueryCache struct { - cache *lru.Cache -} - -func newMemoryPersistedQueryCache(size int) (*memoryPersistedQueryCache, error) { - cache, err := lru.New(size) - return &memoryPersistedQueryCache{cache: cache}, err -} - -func (c *memoryPersistedQueryCache) Add(ctx context.Context, hash string, query string) { - c.cache.Add(hash, query) -} - -func (c *memoryPersistedQueryCache) Get(ctx context.Context, hash string) (string, bool) { - val, ok := c.cache.Get(hash) - if !ok { - return "", ok - } - return val.(string), ok -} -func TestAutomaticPersistedQuery(t *testing.T) { - cache, err := newMemoryPersistedQueryCache(1000) - require.NoError(t, err) - h := GraphQL(&executableSchemaStub{}, EnablePersistedQueryCache(cache)) - t.Run("automatic persisted query POST", func(t *testing.T) { - // normal queries should be unaffected - resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - - // first pass: optimistic hash without query string - resp = doRequest(h, "POST", "/graphql", `{"extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"errors":[{"message":"PersistedQueryNotFound"}],"data":null}`, resp.Body.String()) - // second pass: query with query string and query hash - resp = doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }", "extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - // future requests without query string - resp = doRequest(h, "POST", "/graphql", `{"extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - }) - - t.Run("automatic persisted query GET", func(t *testing.T) { - // normal queries should be unaffected - resp := doRequest(h, "GET", "/graphql?query={me{name}}", "") - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - - // first pass: optimistic hash without query string - resp = doRequest(h, "GET", `/graphql?extensions={"persistedQuery":{"version":1,"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065"}}`, "") - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"errors":[{"message":"PersistedQueryNotFound"}],"data":null}`, resp.Body.String()) - // second pass: query with query string and query hash - resp = doRequest(h, "GET", `/graphql?query={me{name}}&extensions={"persistedQuery":{"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065","version":1}}}`, "") - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - // future requests without query string - resp = doRequest(h, "GET", `/graphql?extensions={"persistedQuery":{"version":1,"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065"}}`, "") - assert.Equal(t, http.StatusOK, resp.Code) - assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - }) -}