Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add brotli support to proxy, warn on unsupported encoding #695

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.668
0.2.668
2 changes: 1 addition & 1 deletion cmd/templ/generatecmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err erro
if cmd.Args.ProxyBind == "" {
cmd.Args.ProxyBind = "127.0.0.1"
}
p = proxy.New(cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
p = proxy.New(cmd.Log, cmd.Args.ProxyBind, cmd.Args.ProxyPort, target)
go func() {
cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String()))
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cmd.Args.ProxyBind, cmd.Args.ProxyPort), p); err != nil {
Expand Down
116 changes: 73 additions & 43 deletions cmd/templ/generatecmd/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"compress/gzip"
"fmt"
"io"
"log"
stdlog "log"
"log/slog"
"math"
"net/http"
"net/http/httputil"
Expand All @@ -16,6 +17,7 @@ import (
"time"

"github.com/a-h/templ/cmd/templ/generatecmd/sse"
"github.com/andybalholm/brotli"

_ "embed"
)
Expand All @@ -26,85 +28,113 @@ var script string
const scriptTag = `<script src="/_templ/reload/script.js"></script>`

type Handler struct {
log *slog.Logger
URL string
Target *url.URL
p *httputil.ReverseProxy
sse *sse.Handler
}

func updateGzipResponse(r *http.Response) error {
plainr, err := gzip.NewReader(r.Body)
func insertScriptTagIntoBody(body string) (updated string) {
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
}

type passthroughWriteCloser struct {
io.Writer
}

func (pwc passthroughWriteCloser) Close() error {
return nil
}

const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."

func (h *Handler) modifyResponse(r *http.Response) error {
if r.Header.Get("templ-skip-modify") == "true" {
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
return nil
}

// Set up readers and writers.
newReader := func(in io.Reader) (out io.Reader, err error) {
return in, nil
}
newWriter := func(out io.Writer) io.WriteCloser {
return passthroughWriteCloser{out}
}
switch r.Header.Get("Content-Encoding") {
case "gzip":
newReader = func(in io.Reader) (out io.Reader, err error) {
return gzip.NewReader(in)
}
newWriter = func(out io.Writer) io.WriteCloser {
return gzip.NewWriter(out)
}
case "br":
newReader = func(in io.Reader) (out io.Reader, err error) {
return brotli.NewReader(in), nil
}
newWriter = func(out io.Writer) io.WriteCloser {
return brotli.NewWriter(out)
}
case "":
// No content encoding.
default:
h.log.Warn(unsupportedContentEncoding, slog.String("encoding", r.Header.Get("Content-Encoding")))
}

// Read the encoded body.
encr, err := newReader(r.Body)
if err != nil {
return err
}
defer plainr.Close()
body, err := io.ReadAll(plainr)
defer r.Body.Close()
body, err := io.ReadAll(encr)
if err != nil {
return err
}

// Update it.
updated := insertScriptTagIntoBody(string(body))

// Encode the response.
var buf bytes.Buffer
gzw := gzip.NewWriter(&buf)
defer gzw.Close()
_, err = gzw.Write([]byte(updated))
encw := newWriter(&buf)
_, err = encw.Write([]byte(updated))
if err != nil {
return err
}
err = gzw.Close()
err = encw.Close()
if err != nil {
return err
}

// Update the response.
r.Body = io.NopCloser(&buf)
r.ContentLength = int64(buf.Len())
r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
return nil
}

func updatePlainResponse(r *http.Response) error {
body, err := io.ReadAll(r.Body)
if err != nil {
return err
}
updated := insertScriptTagIntoBody(string(body))
r.Body = io.NopCloser(strings.NewReader(updated))
r.ContentLength = int64(len(updated))
r.Header.Set("Content-Length", strconv.Itoa(len(updated)))
return nil
}

func insertScriptTagIntoBody(body string) (updated string) {
return strings.Replace(body, "</body>", scriptTag+"</body>", -1)
}

func modifyResponse(r *http.Response) error {
if r.Header.Get("templ-skip-modify") == "true" {
return nil
}
if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
return nil
}
modifier := updatePlainResponse
if r.Header.Get("Content-Encoding") == "gzip" {
modifier = updateGzipResponse
}
return modifier(r)
}

func New(bind string, port int, target *url.URL) *Handler {
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
p := httputil.NewSingleHostReverseProxy(target)
p.ErrorLog = log.New(os.Stderr, "Proxy to target error: ", 0)
p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
p.Transport = &roundTripper{
maxRetries: 10,
initialDelay: 100 * time.Millisecond,
backoffExponent: 1.5,
}
p.ModifyResponse = modifyResponse
return &Handler{
h = &Handler{
log: log,
URL: fmt.Sprintf("http://%s:%d", bind, port),
Target: target,
p: p,
sse: sse.New(),
}
p.ModifyResponse = h.modifyResponse
return h
}

func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down
138 changes: 132 additions & 6 deletions cmd/templ/generatecmd/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/andybalholm/brotli"
"github.com/google/go-cmp/cmp"
)

Expand Down Expand Up @@ -57,7 +60,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", "16")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -85,7 +90,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("templ-skip-modify", "true")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -117,7 +124,9 @@ func TestProxy(t *testing.T) {
}

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -147,7 +156,9 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", "16")

// Act
err := modifyResponse(r)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -195,7 +206,10 @@ func TestProxy(t *testing.T) {
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))

// Act
if err = modifyResponse(r); err != nil {
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

Expand All @@ -216,7 +230,57 @@ func TestProxy(t *testing.T) {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("brotli: body tags get the script inserted", func(t *testing.T) {
// Arrange
body := `<html><body></body></html>`
var buf bytes.Buffer
brw := brotli.NewWriter(&buf)
_, err := brw.Write([]byte(body))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()

expectedString := insertScriptTagIntoBody(body)

var expectedBytes bytes.Buffer
brw = brotli.NewWriter(&expectedBytes)
_, err = brw.Write([]byte(expectedString))
if err != nil {
t.Fatalf("unexpected error writing gzip: %v", err)
}
brw.Close()
expectedLength := len(expectedBytes.Bytes())

r := &http.Response{
Body: io.NopCloser(&buf),
Header: make(http.Header),
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "br")
r.Header.Set("Content-Length", fmt.Sprintf("%d", expectedLength))

// Act
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err = h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Assert
if r.Header.Get("Content-Length") != fmt.Sprintf("%d", expectedLength) {
t.Errorf("expected content length to be %d, got %v", expectedLength, r.Header.Get("Content-Length"))
}

actualBody, err := io.ReadAll(brotli.NewReader(r.Body))
if err != nil {
t.Fatalf("unexpected error reading response: %v", err)
}
if diff := cmp.Diff(expectedString, string(actualBody)); diff != "" {
t.Errorf("unexpected response body (-got +want):\n%s", diff)
}
})
t.Run("notify-proxy: sending POST request to /_templ/reload/events should receive reload sse event", func(t *testing.T) {
// Arrange 1: create a test proxy server.
dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
Expand All @@ -227,7 +291,8 @@ func TestProxy(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error parsing URL: %v", err)
}
handler := New("0.0.0.0", 0, u)
log := slog.New(slog.NewJSONHandler(io.Discard, nil))
handler := New(log, "0.0.0.0", 0, u)
proxyServer := httptest.NewServer(handler)
defer proxyServer.Close()

Expand Down Expand Up @@ -305,4 +370,65 @@ func TestProxy(t *testing.T) {
t.Fatalf("timeout waiting for sse response")
}
})
t.Run("unsupported encodings result in a warning", func(t *testing.T) {
// Arrange
r := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("<p>Data</p>"))),
Header: make(http.Header),
}
r.Header.Set("Content-Type", "text/html, charset=utf-8")
r.Header.Set("Content-Encoding", "weird-encoding")

// Act
lh := newTestLogHandler()
log := slog.New(lh)
h := New(log, "127.0.0.1", 7474, &url.URL{Scheme: "http", Host: "example.com"})
err := h.modifyResponse(r)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Assert
if len(lh.records) != 1 {
t.Fatalf("expected 1 log entry, but got %d", len(lh.records))
}
record := lh.records[0]
if record.Message != unsupportedContentEncoding {
t.Errorf("expected warning message %q, got %q", unsupportedContentEncoding, record.Message)
}
if record.Level != slog.LevelWarn {
t.Errorf("expected warning, got level %v", record.Level)
}
})
}

func newTestLogHandler() *testLogHandler {
return &testLogHandler{
m: new(sync.Mutex),
records: nil,
}
}

type testLogHandler struct {
m *sync.Mutex
records []slog.Record
}

func (h *testLogHandler) Enabled(context.Context, slog.Level) bool {
return true
}

func (h *testLogHandler) Handle(ctx context.Context, r slog.Record) error {
h.m.Lock()
defer h.m.Unlock()
h.records = append(h.records, r)
return nil
}

func (h *testLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return h
}

func (h *testLogHandler) WithGroup(name string) slog.Handler {
return h
}
Loading
Loading