Skip to content

Commit

Permalink
Merge branch 'master' into sanitize-request
Browse files Browse the repository at this point in the history
  • Loading branch information
francislavoie authored Oct 18, 2021
2 parents cad9f90 + 062657d commit dc03e31
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 48 deletions.
20 changes: 14 additions & 6 deletions cmd/commandfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ func cmdBuildInfo(fl Flags) (int, error) {
func cmdListModules(fl Flags) (int, error) {
packages := fl.Bool("packages")
versions := fl.Bool("versions")
skipStandard := fl.Bool("skip-standard")

printModuleInfo := func(mi moduleInfo) {
fmt.Print(mi.caddyModuleID)
Expand Down Expand Up @@ -388,23 +389,30 @@ func cmdListModules(fl Flags) (int, error) {
return caddy.ExitCodeSuccess, nil
}

if len(standard) > 0 {
for _, mod := range standard {
printModuleInfo(mod)
// Standard modules (always shipped with Caddy)
if !skipStandard {
if len(standard) > 0 {
for _, mod := range standard {
printModuleInfo(mod)
}
}
fmt.Printf("\n Standard modules: %d\n", len(standard))
}
fmt.Printf("\n Standard modules: %d\n", len(standard))

// Non-standard modules (third party plugins)
if len(nonstandard) > 0 {
if len(standard) > 0 {
if len(standard) > 0 && !skipStandard {
fmt.Println()
}
for _, mod := range nonstandard {
printModuleInfo(mod)
}
}
fmt.Printf("\n Non-standard modules: %d\n", len(nonstandard))

// Unknown modules (couldn't get Caddy module info)
if len(unknown) > 0 {
if len(standard) > 0 || len(nonstandard) > 0 {
if (len(standard) > 0 && !skipStandard) || len(nonstandard) > 0 {
fmt.Println()
}
for _, mod := range unknown {
Expand Down
1 change: 1 addition & 0 deletions cmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ config file; otherwise the default is assumed.`,
fs := flag.NewFlagSet("list-modules", flag.ExitOnError)
fs.Bool("packages", false, "Print package paths")
fs.Bool("versions", false, "Print version information")
fs.Bool("skip-standard", false, "Skip printing standard modules")
return fs
}(),
})
Expand Down
2 changes: 1 addition & 1 deletion cmd/packagesfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func getModules() (standard, nonstandard, unknown []moduleInfo, err error) {
}

func listModules(path string) error {
cmd := exec.Command(path, "list-modules", "--versions")
cmd := exec.Command(path, "list-modules", "--versions", "--skip-standard")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Run()
Expand Down
39 changes: 27 additions & 12 deletions modules/caddyhttp/reverseproxy/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,11 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *
res.Body = h.bufferedBody(res.Body)
}

// the response body may get closed by a response handler,
// and we need to keep track to make sure we don't try to copy
// the response if it was already closed
bodyClosed := false

// see if any response handler is configured for this response from the backend
for i, rh := range h.HandleResponse {
if rh.Match != nil && !rh.Match.Match(res.StatusCode, res.Header) {
Expand All @@ -652,8 +657,6 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *
continue
}

res.Body.Close()

// set up the replacer so that parts of the original response can be
// used for routing decisions
for field, value := range res.Header {
Expand All @@ -663,7 +666,17 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *
repl.Set("http.reverse_proxy.status_text", res.Status)

h.logger.Debug("handling response", zap.Int("handler", i))
if routeErr := rh.Routes.Compile(next).ServeHTTP(rw, req); routeErr != nil {

// pass the request through the response handler routes
routeErr := rh.Routes.Compile(next).ServeHTTP(rw, req)

// always close the response body afterwards since it's expected
// that the response handler routes will have written to the
// response writer with a new body
res.Body.Close()
bodyClosed = true

if routeErr != nil {
// wrap error in roundtripSucceeded so caller knows that
// the roundtrip was successful and to not retry
return roundtripSucceeded{routeErr}
Expand Down Expand Up @@ -704,15 +717,17 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *
}

rw.WriteHeader(res.StatusCode)
err = h.copyResponse(rw, res.Body, h.flushInterval(req, res))
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if err != nil {
// we're streaming the response and we've already written headers, so
// there's nothing an error handler can do to recover at this point;
// the standard lib's proxy panics at this point, but we'll just log
// the error and abort the stream here
h.logger.Error("aborting with incomplete response", zap.Error(err))
return nil
if !bodyClosed {
err = h.copyResponse(rw, res.Body, h.flushInterval(req, res))
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if err != nil {
// we're streaming the response and we've already written headers, so
// there's nothing an error handler can do to recover at this point;
// the standard lib's proxy panics at this point, but we'll just log
// the error and abort the stream here
h.logger.Error("aborting with incomplete response", zap.Error(err))
return nil
}
}

if len(res.Trailer) > 0 {
Expand Down
1 change: 1 addition & 0 deletions modules/caddyhttp/templates/tplcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ func (c TemplateContext) funcHTTPInclude(uri string) (string, error) {
}
virtReq.Host = c.Req.Host
virtReq.Header = c.Req.Header.Clone()
virtReq.Header.Set("Accept-Encoding", "identity") // https://github.com/caddyserver/caddy/issues/4352
virtReq.Trailer = c.Req.Trailer.Clone()
virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount))

Expand Down
98 changes: 69 additions & 29 deletions modules/caddyhttp/templates/tplcontext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package templates

import (
"bytes"
"context"
"fmt"
"net/http"
"os"
Expand All @@ -25,10 +26,49 @@ import (
"strings"
"testing"
"time"

"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)

type handle struct {
}

func (h *handle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept-Encoding") == "identity" {
w.Write([]byte("good contents"))
} else {
w.Write([]byte("bad cause Accept-Encoding: " + r.Header.Get("Accept-Encoding")))
}
}

func TestHTTPInclude(t *testing.T) {
tplContext := getContextOrFail(t)
for i, test := range []struct {
uri string
handler *handle
expect string
}{
{
uri: "https://example.com/foo/bar",
handler: &handle{},
expect: "good contents",
},
} {
ctx := context.WithValue(tplContext.Req.Context(), caddyhttp.ServerCtxKey, test.handler)
tplContext.Req = tplContext.Req.WithContext(ctx)
tplContext.Req.Header.Add("Accept-Encoding", "gzip")
result, err := tplContext.funcHTTPInclude(test.uri)
if result != test.expect {
t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expect, result)
}
if err != nil {
t.Errorf("Test %d: got error: %v", i, result)
}
}
}

func TestMarkdown(t *testing.T) {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)

for i, test := range []struct {
body string
Expand All @@ -39,7 +79,7 @@ func TestMarkdown(t *testing.T) {
expect: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
},
} {
result, err := context.funcMarkdown(test.body)
result, err := tplContext.funcMarkdown(test.body)
if result != test.expect {
t.Errorf("Test %d: expected '%s' but got '%s'", i, test.expect, result)
}
Expand Down Expand Up @@ -80,9 +120,9 @@ func TestCookie(t *testing.T) {
expect: "cookieValue",
},
} {
context := getContextOrFail(t)
context.Req.AddCookie(test.cookie)
actual := context.Cookie(test.cookieName)
tplContext := getContextOrFail(t)
tplContext.Req.AddCookie(test.cookie)
actual := tplContext.Cookie(test.cookieName)
if actual != test.expect {
t.Errorf("Test %d: Expected cookie value '%s' but got '%s' for cookie with name '%s'",
i, test.expect, actual, test.cookieName)
Expand Down Expand Up @@ -111,22 +151,22 @@ func TestImport(t *testing.T) {
shouldErr: true,
},
} {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)
var absFilePath string

// create files for test case
if test.fileName != "" {
absFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), test.fileName)
absFilePath := filepath.Join(fmt.Sprintf("%s", tplContext.Root), test.fileName)
if err := os.WriteFile(absFilePath, []byte(test.fileContent), os.ModePerm); err != nil {
os.Remove(absFilePath)
t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error())
}
}

// perform test
context.NewTemplate("parent")
actual, err := context.funcImport(test.fileName)
templateWasDefined := strings.Contains(context.tpl.DefinedTemplates(), test.expect)
tplContext.NewTemplate("parent")
actual, err := tplContext.funcImport(test.fileName)
templateWasDefined := strings.Contains(tplContext.tpl.DefinedTemplates(), test.expect)
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: Expected no error, got: '%s'", i, err)
Expand All @@ -135,7 +175,7 @@ func TestImport(t *testing.T) {
t.Errorf("Test %d: Expected error but had none", i)
} else if !templateWasDefined && actual != "" {
// template should be defined, return value should be an empty string
t.Errorf("Test %d: Expected template %s to be define but got %s", i, test.expect, context.tpl.DefinedTemplates())
t.Errorf("Test %d: Expected template %s to be define but got %s", i, test.expect, tplContext.tpl.DefinedTemplates())

}

Expand Down Expand Up @@ -191,20 +231,20 @@ func TestInclude(t *testing.T) {
args: "text",
},
} {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)
var absFilePath string

// create files for test case
if test.fileName != "" {
absFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), test.fileName)
absFilePath := filepath.Join(fmt.Sprintf("%s", tplContext.Root), test.fileName)
if err := os.WriteFile(absFilePath, []byte(test.fileContent), os.ModePerm); err != nil {
os.Remove(absFilePath)
t.Fatalf("Test %d: Expected no error creating file, got: '%s'", i, err.Error())
}
}

// perform test
actual, err := context.funcInclude(test.fileName, test.args)
actual, err := tplContext.funcInclude(test.fileName, test.args)
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: Expected no error, got: '%s'", i, err)
Expand All @@ -225,28 +265,28 @@ func TestInclude(t *testing.T) {
}

func TestCookieMultipleCookies(t *testing.T) {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)

cookieNameBase, cookieValueBase := "cookieName", "cookieValue"

for i := 0; i < 10; i++ {
context.Req.AddCookie(&http.Cookie{
tplContext.Req.AddCookie(&http.Cookie{
Name: fmt.Sprintf("%s%d", cookieNameBase, i),
Value: fmt.Sprintf("%s%d", cookieValueBase, i),
})
}

for i := 0; i < 10; i++ {
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
actualCookieVal := tplContext.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
if actualCookieVal != expectedCookieVal {
t.Errorf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
}
}
}

func TestIP(t *testing.T) {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)
for i, test := range []struct {
inputRemoteAddr string
expect string
Expand All @@ -257,15 +297,15 @@ func TestIP(t *testing.T) {
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
} {
context.Req.RemoteAddr = test.inputRemoteAddr
if actual := context.RemoteIP(); actual != test.expect {
tplContext.Req.RemoteAddr = test.inputRemoteAddr
if actual := tplContext.RemoteIP(); actual != test.expect {
t.Errorf("Test %d: Expected %s but got %s", i, test.expect, actual)
}
}
}

func TestStripHTML(t *testing.T) {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)

for i, test := range []struct {
input string
Expand Down Expand Up @@ -302,7 +342,7 @@ func TestStripHTML(t *testing.T) {
expect: `<h1hi`,
},
} {
actual := context.funcStripHTML(test.input)
actual := tplContext.funcStripHTML(test.input)
if actual != test.expect {
t.Errorf("Test %d: Expected %s, found %s. Input was StripHTML(%s)", i, test.expect, actual, test.input)
}
Expand Down Expand Up @@ -350,13 +390,13 @@ func TestFileListing(t *testing.T) {
verifyErr: os.IsNotExist,
},
} {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)
var dirPath string
var err error

// create files for test case
if test.fileNames != nil {
dirPath, err = os.MkdirTemp(fmt.Sprintf("%s", context.Root), "caddy_ctxtest")
dirPath, err = os.MkdirTemp(fmt.Sprintf("%s", tplContext.Root), "caddy_ctxtest")
if err != nil {
t.Fatalf("Test %d: Expected no error creating directory, got: '%s'", i, err.Error())
}
Expand All @@ -371,7 +411,7 @@ func TestFileListing(t *testing.T) {

// perform test
input := filepath.ToSlash(filepath.Join(filepath.Base(dirPath), test.inputBase))
actual, err := context.funcListFiles(input)
actual, err := tplContext.funcListFiles(input)
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: Expected no error, got: '%s'", i, err)
Expand Down Expand Up @@ -404,7 +444,7 @@ func TestFileListing(t *testing.T) {
}

func TestSplitFrontMatter(t *testing.T) {
context := getContextOrFail(t)
tplContext := getContextOrFail(t)

for i, test := range []struct {
input string
Expand Down Expand Up @@ -465,7 +505,7 @@ title = "Welcome"
body: "\n### Test",
},
} {
result, _ := context.funcSplitFrontMatter(test.input)
result, _ := tplContext.funcSplitFrontMatter(test.input)
if result.Meta["title"] != test.expect {
t.Errorf("Test %d: Expected %s, found %s. Input was SplitFrontMatter(%s)", i, test.expect, result.Meta["title"], test.input)
}
Expand All @@ -477,11 +517,11 @@ title = "Welcome"
}

func getContextOrFail(t *testing.T) TemplateContext {
context, err := initTestContext()
tplContext, err := initTestContext()
if err != nil {
t.Fatalf("failed to prepare test context: %v", err)
}
return context
return tplContext
}

func initTestContext() (TemplateContext, error) {
Expand Down

0 comments on commit dc03e31

Please sign in to comment.