diff --git a/handler/graphql.go b/handler/graphql.go index 6f965f8600..6742095058 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -34,7 +34,7 @@ type params struct { } type extensions struct { - PQ *persistedQuery `json:"persistedQuery"` + PersistedQuery *persistedQuery `json:"persistedQuery"` } type persistedQuery struct { @@ -47,6 +47,11 @@ const ( errPersistedQueryNotFound = "PersistedQueryNotFound" ) +type PersistedQueryCache interface { + Add(ctx context.Context, hash string, query string) + Get(ctx context.Context, hash string) (string, bool) +} + type Config struct { cacheSize int upgrader websocket.Upgrader @@ -61,7 +66,7 @@ type Config struct { connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 uploadMaxSize int64 - apqCacheSize int + apqCache PersistedQueryCache } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -303,11 +308,10 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option { } } -// APQCacheSize sets the maximum size of the automatic persisted query cache. -// If size is less than or equal to 0, the cache is disabled. -func APQCacheSize(size int) Option { +// Add cache that will hold queries for automatic persisted queries (APQ) +func EnablePersistedQueryCache(cache PersistedQueryCache) Option { return func(cfg *Config) { - cfg.apqCacheSize = size + cfg.apqCache = cache } } @@ -353,22 +357,10 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc cfg.tracer = &graphql.NopTracer{} } - var apqCache *lru.Cache - if cfg.apqCacheSize > 0 { - var err error - apqCache, err = lru.New(cfg.apqCacheSize) - if err != nil { - // An error is only returned for non-positive cache size - // and we already checked for that. - panic("unexpected error creating apq cache: " + err.Error()) - } - } - handler := &graphqlHandler{ cfg: cfg, cache: cache, exec: exec, - apqCache: apqCache, } return handler.ServeHTTP @@ -380,7 +372,6 @@ type graphqlHandler struct { cfg *Config cache *lru.Cache exec graphql.ExecutableSchema - apqCache *lru.Cache } func computeQueryHash(query string) string { @@ -461,32 +452,34 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() var queryHash string - apq := reqParams.Extensions != nil && reqParams.Extensions.PQ != nil + apqRegister := false + apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil if apq { // client has enabled apq - queryHash = reqParams.Extensions.PQ.Sha256 - if gh.apqCache == nil { + queryHash = reqParams.Extensions.PersistedQuery.Sha256 + if gh.cfg.apqCache == nil { // server has disabled apq sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported) return } - if reqParams.Extensions.PQ.Version != 1 { + if reqParams.Extensions.PersistedQuery.Version != 1 { sendErrorf(w, http.StatusOK, "Unsupported persisted query version") return } if reqParams.Query == "" { // client sent optimistic query hash without query string - query, ok := gh.apqCache.Get(queryHash) + query, ok := gh.cfg.apqCache.Get(ctx, queryHash) if !ok { sendErrorf(w, http.StatusOK, errPersistedQueryNotFound) return } - reqParams.Query = query.(string) + reqParams.Query = query } else { if computeQueryHash(reqParams.Query) != queryHash { sendErrorf(w, http.StatusOK, "provided sha does not match query") return } + apqRegister = true } } else if reqParams.Query == "" { sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string") @@ -547,9 +540,9 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if apq && gh.apqCache != nil { + if apqRegister && gh.cfg.apqCache != nil { // Add to persisted query cache - gh.apqCache.Add(queryHash, reqParams.Query) + gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query) } switch op.Operation { diff --git a/handler/graphql_test.go b/handler/graphql_test.go index 93a257706d..bfcc11082c 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" + lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vektah/gqlparser/ast" @@ -765,8 +766,30 @@ func TestBytesRead(t *testing.T) { }) } +type memoryPersistedQueryCache struct { + cache *lru.Cache +} + +func newMemoryPersistedQueryCache(size int) (*memoryPersistedQueryCache, error) { + cache, err := lru.New(size) + return &memoryPersistedQueryCache{cache: cache}, err +} + +func (c *memoryPersistedQueryCache) Add(ctx context.Context, hash string, query string) { + c.cache.Add(hash, query) +} + +func (c *memoryPersistedQueryCache) Get(ctx context.Context, hash string) (string, bool) { + val, ok := c.cache.Get(hash) + if !ok { + return "", ok + } + return val.(string), ok +} func TestAutomaticPersistedQuery(t *testing.T) { - h := GraphQL(&executableSchemaStub{}, APQCacheSize(1000)) + cache, err := newMemoryPersistedQueryCache(1000) + require.NoError(t, err) + h := GraphQL(&executableSchemaStub{}, EnablePersistedQueryCache(cache)) t.Run("automatic persisted query POST", func(t *testing.T) { // normal queries should be unaffected resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`)