diff --git a/internal/mx/mx_test.go b/internal/mx/mx_test.go index cff08c83..0e3562c9 100644 --- a/internal/mx/mx_test.go +++ b/internal/mx/mx_test.go @@ -408,6 +408,105 @@ func TestMux(t *testing.T) { }) } +func TestMuxFlexiblePattern(t *testing.T) { + t.Parallel() + + tr := &http.Transport{ + // since we are using self-signed certificates, we need to skip verification. + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + + httpsPort := tst.GetPort() + domain := "localhost" + + l := log.New(context.Background(), &bytes.Buffer{}, 500) + + t.Run("flexible pattern accepts all uris", func(t *testing.T) { + t.Parallel() + + msg := "hello world" + rt, err := NewRoute( + "/*", + MethodGet, + someMuxHandler(msg), + ) + attest.Ok(t, err) + mux, err := New( + config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt, + ) + attest.Ok(t, err) + + ts, err := tst.TlsServer(mux, domain, httpsPort) + attest.Ok(t, err) + defer ts.Close() + + { + res, errA := client.Get(ts.URL + "/UnknownUri") + attest.Ok(t, errA) + + rb, errB := io.ReadAll(res.Body) + attest.Ok(t, errB) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + } + + { + res, errC := client.Get(ts.URL + "/") + attest.Ok(t, errC) + + rb, errD := io.ReadAll(res.Body) + attest.Ok(t, errD) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + } + + { + res, errE := client.Get(ts.URL + "/hey/a/b/cool") + attest.Ok(t, errE) + + rb, errF := io.ReadAll(res.Body) + attest.Ok(t, errF) + + attest.Equal(t, res.StatusCode, http.StatusOK) + attest.Equal(t, string(rb), msg) + } + }) + + t.Run("conflict", func(t *testing.T) { + t.Parallel() + + msg := "hello world" + + rt1, err := NewRoute( + "/*", + MethodGet, + someMuxHandler(msg), + ) + attest.Ok(t, err) + + rt2, err := NewRoute( + "/hi", + MethodGet, + thisIsAnotherMuxHandler(), + ) + attest.Ok(t, err) + + _, err = New( + config.WithOpts(domain, httpsPort, tst.SecretKey(), config.DirectIpStrategy, l), + nil, + rt1, + rt2, + ) + attest.Error(t, err) + attest.Subsequence(t, err.Error(), "would conflict") + }) +} + func getManyRoutes(b *testing.B) []Route { b.Helper() diff --git a/internal/mx/route.go b/internal/mx/route.go index 37510172..eb0a6bc8 100644 --- a/internal/mx/route.go +++ b/internal/mx/route.go @@ -70,9 +70,15 @@ func NewRoute( } func (r Route) match(ctx context.Context, segs []string) (context.Context, bool) { + if len(r.segments) == 1 && r.segments[0] == "*" { + // The router is allowed to handle all request paths + return ctx, true + } + if len(segs) > len(r.segments) { return nil, false } + for i, seg := range r.segments { if i > len(segs)-1 { return nil, false @@ -91,6 +97,7 @@ func (r Route) match(ctx context.Context, segs []string) (context.Context, bool) ctx = context.WithValue(ctx, muxContextKey(seg), segs[i]) } } + return ctx, true } @@ -222,6 +229,10 @@ already exists and would conflict`, getfunc(rt.originalHandler), ) + if len(existingSegments) == 1 && existingSegments[0] == "*" && len(incomingSegments) > 0 { + return errMsg + } + if pattern == rt.pattern { return errMsg }