diff --git a/graphql/e2e/common/common.go b/graphql/e2e/common/common.go index 0bae94521d8..c4868926ad4 100644 --- a/graphql/e2e/common/common.go +++ b/graphql/e2e/common/common.go @@ -221,6 +221,9 @@ func RunAll(t *testing.T) { // schema tests t.Run("graphql descriptions", graphQLDescriptions) + // header tests + t.Run("touched uids header", touchedUidsHeader) + // encoding t.Run("gzip compression", gzipCompression) t.Run("gzip compression header", gzipCompressionHeader) diff --git a/graphql/e2e/common/query.go b/graphql/e2e/common/query.go index 1bf07fa50ec..b2e01f46e5d 100644 --- a/graphql/e2e/common/query.go +++ b/graphql/e2e/common/query.go @@ -19,8 +19,11 @@ package common import ( "encoding/json" "fmt" + "io/ioutil" "math/rand" + "net/http" "sort" + "strconv" "strings" "testing" "time" @@ -61,6 +64,33 @@ func queryCountryByRegExp(t *testing.T, regexp string, expectedCountries []*coun } } +func touchedUidsHeader(t *testing.T) { + query := &GraphQLParams{ + Query: `query { + queryCountry { + name + } + }`, + } + req, err := query.createGQLPost(graphqlURL) + require.NoError(t, err) + + client := http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + + // confirm that the header value is a non-negative integer + touchedUidsInHeader, err := strconv.ParseUint(resp.Header.Get("Graphql-TouchedUids"), 10, 64) + require.NoError(t, err) + + // confirm that the value in header is same as the value in body + var gqlResp GraphQLResponse + b, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(b, &gqlResp)) + require.Equal(t, touchedUidsInHeader, uint64(gqlResp.Extensions["touched_uids"].(float64))) +} + // This test checks that all the different combinations of // request sending compressed / uncompressed query and receiving // compressed / uncompressed result. diff --git a/graphql/schema/response.go b/graphql/schema/response.go index 156b9f08c52..2b583c1414b 100644 --- a/graphql/schema/response.go +++ b/graphql/schema/response.go @@ -32,6 +32,14 @@ type Extensions struct { TouchedUids uint64 `json:"touched_uids,omitempty"` } +// GetTouchedUids returns TouchedUids +func (e *Extensions) GetTouchedUids() uint64 { + if e == nil { + return 0 + } + return e.TouchedUids +} + // Merge merges ext with e func (e *Extensions) Merge(ext *Extensions) { if e == nil || ext == nil { @@ -69,6 +77,14 @@ func ErrorResponse(err error) *Response { } } +// GetExtensions returns a *Extensions +func (r *Response) GetExtensions() *Extensions { + if r == nil { + return nil + } + return r.Extensions +} + // WithError generates GraphQL errors from err and records those in r. func (r *Response) WithError(err error) { r.Errors = append(r.Errors, AsGQLErrors(err)...) diff --git a/graphql/web/http.go b/graphql/web/http.go index fd24c112c84..998132d142b 100644 --- a/graphql/web/http.go +++ b/graphql/web/http.go @@ -20,6 +20,7 @@ import ( "compress/gzip" "context" "encoding/json" + "strconv" "io" "io/ioutil" @@ -39,6 +40,8 @@ import ( "go.opencensus.io/trace" ) +const touchedUidsHeader = "Graphql-TouchedUids" + // An IServeGraphQL can serve a GraphQL endpoint (currently only ons http) type IServeGraphQL interface { @@ -86,6 +89,9 @@ func (gh *graphqlHandler) Resolve(ctx context.Context, gqlReq *schema.Request) * func write(w http.ResponseWriter, rr *schema.Response, acceptGzip bool) { var out io.Writer = w + // set TouchedUids header + w.Header().Set(touchedUidsHeader, strconv.FormatUint(rr.GetExtensions().GetTouchedUids(), 10)) + // If the receiver accepts gzip, then we would update the writer // and send gzipped content instead. if acceptGzip {