diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index 8ea80bef1b..dab97398f4 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -127,6 +127,7 @@ func TestGeneratedServer(t *testing.T) { t.Run("subscriptions", func(t *testing.T) { t.Run("wont leak goroutines", func(t *testing.T) { + runtime.GC() // ensure no go-routines left from preceding tests initialGoroutineCount := runtime.NumGoroutine() sub := c.Websocket(`subscription { updated }`) diff --git a/handler/graphql.go b/handler/graphql.go index eb8880deb3..918671a9a2 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -58,8 +58,6 @@ func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDo if hook := c.tracer; hook != nil { reqCtx.Tracer = hook - } else { - reqCtx.Tracer = &graphql.NopTracer{} } if c.complexityLimit > 0 { @@ -259,7 +257,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc var cache *lru.Cache if cfg.cacheSize > 0 { var err error - cache, err = lru.New(DefaultCacheSize) + cache, err = lru.New(cfg.cacheSize) if err != nil { // An error is only returned for non-positive cache size // and we already checked for that. @@ -295,7 +293,7 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if strings.Contains(r.Header.Get("Upgrade"), "websocket") { - connectWs(gh.exec, w, r, gh.cfg) + connectWs(gh.exec, w, r, gh.cfg, gh.cache) return } diff --git a/handler/websocket.go b/handler/websocket.go index dae262bdf3..c3dc38e0ea 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" + "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" @@ -43,11 +44,12 @@ type wsConnection struct { active map[string]context.CancelFunc mu sync.Mutex cfg *Config + cache *lru.Cache initPayload InitPayload } -func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) { +func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) { ws, err := cfg.upgrader.Upgrade(w, r, http.Header{ "Sec-Websocket-Protocol": []string{"graphql-ws"}, }) @@ -63,6 +65,7 @@ func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Req conn: ws, ctx: r.Context(), cfg: cfg, + cache: cache, } if !conn.init() { @@ -148,10 +151,27 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { return false } - doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query) - if qErr != nil { - c.sendError(message.ID, qErr...) - return true + var ( + doc *ast.QueryDocument + cacheHit bool + ) + if c.cache != nil { + val, ok := c.cache.Get(reqParams.Query) + if ok { + doc = val.(*ast.QueryDocument) + cacheHit = true + } + } + if !cacheHit { + var qErr gqlerror.List + doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query) + if qErr != nil { + c.sendError(message.ID, qErr...) + return true + } + if c.cache != nil { + c.cache.Add(reqParams.Query, doc) + } } op := doc.Operations.ForName(reqParams.OperationName)