Skip to content

Commit

Permalink
Allow absolute URLs to the GraphQL playground (#2142)
Browse files Browse the repository at this point in the history
* Allow absolute URLs to the GraphQL playground

* Add test for playground URLs

* Close res.Body in playground test
  • Loading branch information
marcusirgens authored May 5, 2022
1 parent 3228f36 commit d38911f
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 10 deletions.
52 changes: 42 additions & 10 deletions graphql/playground/playground.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package playground
import (
"html/template"
"net/http"
"net/url"
)

var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
Expand Down Expand Up @@ -36,9 +37,14 @@ var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
></script>
<script>
const url = location.protocol + '//' + location.host + '{{.endpoint}}';
{{- if .endpointIsAbsolute}}
const url = {{.endpoint}};
const subscriptionUrl = {{.subscriptionEndpoint}};
{{- else}}
const url = location.protocol + '//' + location.host + {{.endpoint}};
const wsProto = location.protocol == 'https:' ? 'wss:' : 'ws:';
const subscriptionUrl = wsProto + '//' + location.host + '{{.endpoint}}';
const subscriptionUrl = wsProto + '//' + location.host + '/foo';
{{- end}}
const fetcher = GraphiQL.createFetcher({ url, subscriptionUrl });
ReactDOM.render(
Expand All @@ -59,17 +65,43 @@ var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
func Handler(title string, endpoint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "text/html")
err := page.Execute(w, map[string]string{
"title": title,
"endpoint": endpoint,
"version": "1.8.2",
"cssSRI": "sha256-CDHiHbYkDSUc3+DS2TU89I9e2W3sJRUOqSmp7JC+LBw=",
"jsSRI": "sha256-X8vqrqZ6Rvvoq4tvRVM3LoMZCQH8jwW92tnX0iPiHPc=",
"reactSRI": "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=",
"reactDOMSRI": "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=",
err := page.Execute(w, map[string]interface{}{
"title": title,
"endpoint": endpoint,
"endpointIsAbsolute": endpointHasScheme(endpoint),
"subscriptionEndpoint": getSubscriptionEndpoint(endpoint),
"version": "1.8.2",
"cssSRI": "sha256-CDHiHbYkDSUc3+DS2TU89I9e2W3sJRUOqSmp7JC+LBw=",
"jsSRI": "sha256-X8vqrqZ6Rvvoq4tvRVM3LoMZCQH8jwW92tnX0iPiHPc=",
"reactSRI": "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=",
"reactDOMSRI": "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=",
})
if err != nil {
panic(err)
}
}
}

// endpointHasScheme checks if the endpoint has a scheme.
func endpointHasScheme(endpoint string) bool {
u, err := url.Parse(endpoint)
return err == nil && u.Scheme != ""
}

// getSubscriptionEndpoint returns the subscription endpoint for the given
// endpoint if it is parsable as a URL, or an empty string.
func getSubscriptionEndpoint(endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
return ""
}

switch u.Scheme {
case "https":
u.Scheme = "wss"
default:
u.Scheme = "ws"
}

return u.String()
}
56 changes: 56 additions & 0 deletions graphql/playground/playground_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package playground

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"regexp"
"testing"
)

func TestHandler_createsAbsoluteURLs(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "https://example.org/query", nil)
h := Handler("example.org API", "https://example.org/query")
h.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, http.StatusOK)
}

b, err := io.ReadAll(res.Body)
if err != nil {
panic(fmt.Errorf("reading res.Body: %w", err))
}

want := regexp.MustCompile(`(?m)^.*url\s*=\s*['"]https:\/\/example\.org\/query["'].*$`)
if !want.Match(b) {
t.Errorf("no match for %s in response body", want.String())
}
}

func TestHandler_createsRelativeURLs(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:8080/query", nil)
h := Handler("example.org API", "/query")
h.ServeHTTP(rec, req)

res := rec.Result()
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, http.StatusOK)
}

b, err := io.ReadAll(res.Body)
if err != nil {
panic(fmt.Errorf("reading res.Body: %w", err))
}

want := regexp.MustCompile(`(?m)^.*url\s*=\s*location.protocol.*$`)
if !want.Match(b) {
t.Errorf("no match for %s in response body", want.String())
}
}

0 comments on commit d38911f

Please sign in to comment.