diff --git a/README.md b/README.md index 77fcadf..21f76e1 100644 --- a/README.md +++ b/README.md @@ -71,3 +71,10 @@ use the `-yml` and `-port` flags. ``` $ sally -yml site.yaml -port 5000 ``` + +### Custom Templates + +You can provide your own custom templates. For this, create a directory with `.html` +templates and provide it via the `-templates` flag. You only need to provide the +templates you want to override. See [templates](./templates/) for the available +templates. diff --git a/config_test.go b/config_test.go index e7ac247..21e67ef 100644 --- a/config_test.go +++ b/config_test.go @@ -9,7 +9,7 @@ import ( ) func TestParse(t *testing.T) { - path, clean := TempFile(t, ` + path := TempFile(t, ` url: google.golang.org packages: @@ -19,7 +19,6 @@ packages: vcs: svn `) - defer clean() config, err := Parse(path) assert.NoError(t, err) @@ -33,7 +32,7 @@ packages: } func TestParsePackageLevelURL(t *testing.T) { - path, clean := TempFile(t, ` + path := TempFile(t, ` url: google.golang.org packages: @@ -42,7 +41,6 @@ packages: url: go.uber.org `) - defer clean() config, err := Parse(path) assert.NoError(t, err) @@ -65,7 +63,7 @@ func TestParseGodocServer(t *testing.T) { for _, tt := range tests { t.Run(tt.give, func(t *testing.T) { - path, clean := TempFile(t, fmt.Sprintf(` + path := TempFile(t, fmt.Sprintf(` godoc: host: %q url: google.golang.org @@ -73,7 +71,6 @@ packages: grpc: repo: github.com/grpc/grpc-go `, tt.give)) - defer clean() config, err := Parse(path) require.NoError(t, err) diff --git a/handler.go b/handler.go index ccf8d27..d25ed38 100644 --- a/handler.go +++ b/handler.go @@ -2,26 +2,26 @@ package main import ( "cmp" - "fmt" + "embed" + "errors" "html/template" "net/http" "path" "slices" "strings" - - "go.uber.org/sally/templates" ) var ( - indexTemplate = template.Must( - template.New("index.html").Parse(templates.Index)) - packageTemplate = template.Must( - template.New("package.html").Parse(templates.Package)) + //go:embed templates/*.html + templateFiles embed.FS + + _templates = template.Must(template.ParseFS(templateFiles, "templates/*.html")) ) -// CreateHandler builds a new handler -// with the provided package configuration. -// The returned handler provides the following endpoints: +// CreateHandler builds a new handler with the provided package configuration, +// and templates. The templates object must contain the following: index.html, +// package.html, and 404.html. The returned handler provides the following +// endpoints: // // GET / // Index page listing all packages. @@ -32,7 +32,22 @@ var ( // assuming that there's no package with the given name. // GET // // Package page for the given subpackage. -func CreateHandler(config *Config) http.Handler { +func CreateHandler(config *Config, templates *template.Template) (http.Handler, error) { + indexTemplate := templates.Lookup("index.html") + if indexTemplate == nil { + return nil, errors.New("template index.html is missing") + } + + notFoundTemplate := templates.Lookup("404.html") + if notFoundTemplate == nil { + return nil, errors.New("template 404.html is missing") + } + + packageTemplate := templates.Lookup("package.html") + if packageTemplate == nil { + return nil, errors.New("template package.html is missing") + } + mux := http.NewServeMux() pkgs := make([]*sallyPackage, 0, len(config.Packages)) for name, pkg := range config.Packages { @@ -56,13 +71,13 @@ func CreateHandler(config *Config) http.Handler { // Double-register so that "/foo" // does not redirect to "/foo/" with a 300. - handler := &packageHandler{Pkg: pkg} + handler := &packageHandler{pkg: pkg, template: packageTemplate} mux.Handle("/"+name, handler) mux.Handle("/"+name+"/", handler) } - mux.Handle("/", newIndexHandler(pkgs)) - return requireMethod(http.MethodGet, mux) + mux.Handle("/", newIndexHandler(pkgs, indexTemplate, notFoundTemplate)) + return requireMethod(http.MethodGet, mux), nil } func requireMethod(method string, handler http.Handler) http.Handler { @@ -99,18 +114,22 @@ type sallyPackage struct { } type indexHandler struct { - pkgs []*sallyPackage // sorted by name + pkgs []*sallyPackage // sorted by name + indexTemplate *template.Template + notFoundTemplate *template.Template } var _ http.Handler = (*indexHandler)(nil) -func newIndexHandler(pkgs []*sallyPackage) *indexHandler { +func newIndexHandler(pkgs []*sallyPackage, indexTemplate, notFoundTemplate *template.Template) *indexHandler { slices.SortFunc(pkgs, func(a, b *sallyPackage) int { return cmp.Compare(a.Name, b.Name) }) return &indexHandler{ - pkgs: pkgs, + pkgs: pkgs, + indexTemplate: indexTemplate, + notFoundTemplate: notFoundTemplate, } } @@ -145,22 +164,20 @@ func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If start == end, then there are no packages if start == end { - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "no packages found under path: %v\n", path) + serveHTML(w, http.StatusNotFound, h.notFoundTemplate, struct{ Path string }{ + Path: path, + }) return } - err := indexTemplate.Execute(w, - struct{ Packages []*sallyPackage }{ - Packages: h.pkgs[start:end], - }) - if err != nil { - http.Error(w, err.Error(), 500) - } + serveHTML(w, http.StatusOK, h.indexTemplate, struct{ Packages []*sallyPackage }{ + Packages: h.pkgs[start:end], + }) } type packageHandler struct { - Pkg *sallyPackage + pkg *sallyPackage + template *template.Template } var _ http.Handler = (*packageHandler)(nil) @@ -169,24 +186,38 @@ func (h *packageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Extract the relative path to subpackages, if any. // "/foo/bar" => "/bar" // "/foo" => "" - relPath := strings.TrimPrefix(r.URL.Path, "/"+h.Pkg.Name) + relPath := strings.TrimPrefix(r.URL.Path, "/"+h.pkg.Name) - err := packageTemplate.Execute(w, struct { + serveHTML(w, http.StatusOK, h.template, struct { ModulePath string VCS string RepoURL string DocURL string }{ - ModulePath: h.Pkg.ModulePath, - VCS: h.Pkg.VCS, - RepoURL: h.Pkg.RepoURL, - DocURL: h.Pkg.DocURL + relPath, + ModulePath: h.pkg.ModulePath, + VCS: h.pkg.VCS, + RepoURL: h.pkg.RepoURL, + DocURL: h.pkg.DocURL + relPath, }) - if err != nil { - http.Error(w, err.Error(), 500) - } } func descends(from, to string) bool { return to == from || (strings.HasPrefix(to, from) && to[len(from)] == '/') } + +func serveHTML(w http.ResponseWriter, status int, template *template.Template, data interface{}) { + if status >= 400 { + w.Header().Set("Cache-Control", "no-cache") + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + + err := template.Execute(w, data) + if err != nil { + // The status has already been sent, so we cannot use [http.Error] - otherwise + // we'll get a superfluous call warning. The other option is to execute the template + // to a temporary buffer, but memory. + _, _ = w.Write([]byte(err.Error())) + } +} diff --git a/handler_test.go b/handler_test.go index ac0c1e4..b8c667b 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,6 +1,7 @@ package main import ( + "html/template" "io" "net/http" "net/http/httptest" @@ -31,7 +32,7 @@ packages: ` func TestIndex(t *testing.T) { - rr := CallAndRecord(t, config, "/") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/") assert.Equal(t, 200, rr.Code) body := rr.Body.String() @@ -43,7 +44,7 @@ func TestIndex(t *testing.T) { } func TestSubindex(t *testing.T) { - rr := CallAndRecord(t, config, "/net") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/net") assert.Equal(t, 200, rr.Code) body := rr.Body.String() @@ -54,7 +55,7 @@ func TestSubindex(t *testing.T) { } func TestPackageShouldExist(t *testing.T) { - rr := CallAndRecord(t, config, "/yarpc") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/yarpc") AssertResponse(t, rr, 200, ` @@ -70,14 +71,24 @@ func TestPackageShouldExist(t *testing.T) { } func TestNonExistentPackageShould404(t *testing.T) { - rr := CallAndRecord(t, config, "/nonexistent") - AssertResponse(t, rr, 404, ` -no packages found under path: nonexistent + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/nonexistent") + assert.Equal(t, "no-cache", rr.Header().Get("Cache-Control")) + AssertResponse(t, rr, 404, ` + + + + + +
+

No packages found under: "nonexistent".

+
+ + `) } func TestTrailingSlash(t *testing.T) { - rr := CallAndRecord(t, config, "/yarpc/") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/yarpc/") AssertResponse(t, rr, 200, ` @@ -93,7 +104,7 @@ func TestTrailingSlash(t *testing.T) { } func TestDeepImports(t *testing.T) { - rr := CallAndRecord(t, config, "/yarpc/heeheehee") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/yarpc/heeheehee") AssertResponse(t, rr, 200, ` @@ -107,7 +118,7 @@ func TestDeepImports(t *testing.T) { `) - rr = CallAndRecord(t, config, "/yarpc/heehee/hawhaw") + rr = CallAndRecord(t, config, getTestTemplates(t, nil), "/yarpc/heehee/hawhaw") AssertResponse(t, rr, 200, ` @@ -123,7 +134,7 @@ func TestDeepImports(t *testing.T) { } func TestPackageLevelURL(t *testing.T) { - rr := CallAndRecord(t, config, "/zap") + rr := CallAndRecord(t, config, getTestTemplates(t, nil), "/zap") AssertResponse(t, rr, 200, ` @@ -141,14 +152,15 @@ func TestPackageLevelURL(t *testing.T) { func TestPostRejected(t *testing.T) { t.Parallel() - h := CreateHandler(&Config{ + h, err := CreateHandler(&Config{ URL: "go.uberalt.org", Packages: map[string]PackageConfig{ "zap": { Repo: "github.com/uber-go/zap", }, }, - }) + }, getTestTemplates(t, nil)) + require.NoError(t, err) srv := httptest.NewServer(h) t.Cleanup(srv.Close) @@ -248,7 +260,8 @@ func TestIndexHandler_rangeOf(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - h := newIndexHandler(tt.pkgs) + templates := getTestTemplates(t, nil) + h := newIndexHandler(tt.pkgs, templates.Lookup("index.html"), templates.Lookup("404.html")) start, end := h.rangeOf(tt.path) var got []string @@ -260,8 +273,56 @@ func TestIndexHandler_rangeOf(t *testing.T) { } } +func TestCustomTemplates(t *testing.T) { + t.Run("missing", func(t *testing.T) { + for _, name := range []string{"index.html", "package.html", "404.html"} { + templatesText := map[string]string{ + "index.html": "index", + "package.html": "package", + "404.html": "404", + } + delete(templatesText, name) + + templates := template.New("") + for tplName, tplText := range templatesText { + var err error + templates, err = templates.New(tplName).Parse(tplText) + require.NoError(t, err) + } + + _, err := CreateHandler(&Config{}, templates) + require.Error(t, err, name) + } + }) + + t.Run("replace", func(t *testing.T) { + templates := getTestTemplates(t, map[string]string{ + "404.html": "not found: {{ .Path }}", + }) + + // Overrides 404.html + rr := CallAndRecord(t, config, templates, "/blah") + require.Equal(t, http.StatusNotFound, rr.Result().StatusCode) + + // But not package.html + rr = CallAndRecord(t, config, templates, "/zap") + AssertResponse(t, rr, 200, ` + + + + + + + + Nothing to see here. Please move along. + + +`) + }) +} + func BenchmarkHandlerDispatch(b *testing.B) { - handler := CreateHandler(&Config{ + handler, err := CreateHandler(&Config{ URL: "go.uberalt.org", Packages: map[string]PackageConfig{ "zap": { @@ -271,7 +332,8 @@ func BenchmarkHandlerDispatch(b *testing.B) { Repo: "github.com/yarpc/metrics", }, }, - }) + }, getTestTemplates(b, nil)) + require.NoError(b, err) resw := new(nopResponseWriter) tests := []struct { @@ -297,6 +359,6 @@ func BenchmarkHandlerDispatch(b *testing.B) { type nopResponseWriter struct{} -func (nopResponseWriter) Header() http.Header { return nil } +func (nopResponseWriter) Header() http.Header { return http.Header{} } func (nopResponseWriter) Write([]byte) (int, error) { return 0, nil } func (nopResponseWriter) WriteHeader(int) {} diff --git a/main.go b/main.go index 5d54071..522ecf8 100644 --- a/main.go +++ b/main.go @@ -5,12 +5,15 @@ package main // import "go.uber.org/sally" import ( "flag" "fmt" + "html/template" "log" "net/http" + "path/filepath" ) func main() { yml := flag.String("yml", "sally.yaml", "yaml file to read config from") + tpls := flag.String("templates", "", "directory of .html templates to use") port := flag.Int("port", 8080, "port to listen and serve on") flag.Parse() @@ -20,9 +23,34 @@ func main() { log.Fatalf("Failed to parse %s: %v", *yml, err) } + var templates *template.Template + if *tpls != "" { + log.Printf("Parsing templates at path: %s\n", *tpls) + templates, err = getCombinedTemplates(*tpls) + if err != nil { + log.Fatalf("Failed to parse templates at %s: %v", *tpls, err) + } + } else { + templates = _templates + } + log.Printf("Creating HTTP handler with config: %v", config) - handler := CreateHandler(config) + handler, err := CreateHandler(config, templates) + if err != nil { + log.Fatalf("Failed to create handler: %v", err) + } log.Printf(`Starting HTTP handler on ":%d"`, *port) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), handler)) } + +func getCombinedTemplates(dir string) (*template.Template, error) { + // Clones default templates to then merge with the user defined templates. + // This allows for the user to only override certain templates, but not all + // if they don't want. + templates, err := _templates.Clone() + if err != nil { + return nil, err + } + return templates.ParseGlob(filepath.Join(dir, "*.html")) +} diff --git a/templates/404.html b/templates/404.html new file mode 100644 index 0000000..ab5b34a --- /dev/null +++ b/templates/404.html @@ -0,0 +1,11 @@ + + + + + + +
+

No packages found under: "{{ .Path }}".

+
+ + diff --git a/templates/templates.go b/templates/templates.go deleted file mode 100644 index 4217a5a..0000000 --- a/templates/templates.go +++ /dev/null @@ -1,15 +0,0 @@ -// Package templates exposes the template used by Sally -// to render the HTML pages. -package templates - -import _ "embed" // needed for go:embed - -// Index holds the contents of the index.html template. -// -//go:embed index.html -var Index string - -// Package holds the contents of the package.html template. -// -//go:embed package.html -var Package string diff --git a/utils_test.go b/utils_test.go index 5eaf0ba..09442bb 100644 --- a/utils_test.go +++ b/utils_test.go @@ -2,9 +2,11 @@ package main import ( "bytes" + "html/template" "net/http" "net/http/httptest" "os" + "path/filepath" "strings" "testing" @@ -14,46 +16,43 @@ import ( ) // TempFile persists contents and returns the path and a clean func -func TempFile(t *testing.T, contents string) (path string, clean func()) { +func TempFile(t *testing.T, contents string) (path string) { content := []byte(contents) tmpfile, err := os.CreateTemp("", "sally-tmp") - if err != nil { - t.Fatal("Unable to create tmpfile", err) - } + require.NoError(t, err, "unable to create tmpfile") - if _, err := tmpfile.Write(content); err != nil { - t.Fatal("Unable to write tmpfile", err) - } - if err := tmpfile.Close(); err != nil { - t.Fatal("Unable to close tmpfile", err) - } + _, err = tmpfile.Write(content) + require.NoError(t, err, "unable to write tmpfile") + + err = tmpfile.Close() + require.NoError(t, err, "unable to close tmpfile") - return tmpfile.Name(), func() { + t.Cleanup(func() { _ = os.Remove(tmpfile.Name()) - } + }) + + return tmpfile.Name() } // CreateHandlerFromYAML builds the Sally handler from a yaml config string -func CreateHandlerFromYAML(t *testing.T, content string) (handler http.Handler, clean func()) { - path, clean := TempFile(t, content) +func CreateHandlerFromYAML(t *testing.T, templates *template.Template, content string) (handler http.Handler) { + path := TempFile(t, content) config, err := Parse(path) - if err != nil { - t.Fatalf("Unable to parse %s: %v", path, err) - } + require.NoError(t, err, "unable to parse path %s", path) + + handler, err = CreateHandler(config, templates) + require.NoError(t, err) - return CreateHandler(config), clean + return handler } // CallAndRecord makes a GET request to the Sally handler and returns a response recorder -func CallAndRecord(t *testing.T, config string, uri string) *httptest.ResponseRecorder { - handler, clean := CreateHandlerFromYAML(t, config) - defer clean() +func CallAndRecord(t *testing.T, config string, templates *template.Template, uri string) *httptest.ResponseRecorder { + handler := CreateHandlerFromYAML(t, templates, config) req, err := http.NewRequest("GET", uri, nil) - if err != nil { - t.Fatalf("Unable to create request to %s: %v", uri, err) - } + require.NoError(t, err, "unable to create request to %s", uri) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) @@ -67,6 +66,29 @@ func AssertResponse(t *testing.T, rr *httptest.ResponseRecorder, code int, want assert.Equal(t, reformatHTML(t, want), reformatHTML(t, rr.Body.String())) } +// getTestTemplates returns a [template.Template] object with the default templates, +// overwritten by the [overrideTemplates]. If [overrideTemplates] is nil, the returned +// templates are a clone of the global [_templates]. +func getTestTemplates(tb testing.TB, overrideTemplates map[string]string) *template.Template { + if len(overrideTemplates) == 0 { + // We must clone! Cloning can only be done before templates are executed. Therefore, + // we cannot run some tests without cloning, and then attempt cloning it. It'll panic. + templates, err := _templates.Clone() + require.NoError(tb, err) + return templates + } + + templatesDir := tb.TempDir() // This is automatically removed at the end of the test. + for name, content := range overrideTemplates { + err := os.WriteFile(filepath.Join(templatesDir, name), []byte(content), 0o666) + require.NoError(tb, err) + } + + templates, err := getCombinedTemplates(templatesDir) + require.NoError(tb, err) + return templates +} + func reformatHTML(t *testing.T, s string) string { n, err := html.Parse(strings.NewReader(s)) require.NoError(t, err)