Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non blocking matchers & matching timeout #72

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 103 additions & 37 deletions layer4/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
package layer4

import (
"bytes"
"context"
"io"
"errors"
"net"
"sync"

Expand All @@ -30,7 +29,7 @@ import (
// and variable table. This function is intended for use at the start of a
// connection handler chain where the underlying connection is not yet a layer4
// Connection value.
func WrapConnection(underlying net.Conn, buf *bytes.Buffer, logger *zap.Logger) *Connection {
func WrapConnection(underlying net.Conn, buf []byte, logger *zap.Logger) *Connection {
ydylla marked this conversation as resolved.
Show resolved Hide resolved
repl := caddy.NewReplacer()
repl.Set("l4.conn.remote_addr", underlying.RemoteAddr())
repl.Set("l4.conn.local_addr", underlying.LocalAddr())
Expand Down Expand Up @@ -66,48 +65,52 @@ type Connection struct {

Logger *zap.Logger

buf *bytes.Buffer // stores recordings
bufReader io.Reader // used to read buf so it doesn't discard bytes
buf []byte // stores recorded data
offset int
recording bool

bytesRead, bytesWritten uint64
}

// ErrConsumedAllPrefetchedBytes is returned by Read during matching when
// fewer prefetched bytes remain than the matcher asked to read.
var ErrConsumedAllPrefetchedBytes = errors.New("consumed all prefetched bytes")

// ErrMatchingBufferFull is returned by Read during matching when the
// recorded data has reached MaxMatchingBytes.
var ErrMatchingBufferFull = errors.New("matching buffer is full")

// Read implements io.Reader in such a way that reads first
// deplete any associated buffer from the prior recording,
// and once depleted (or if there isn't one), it continues
// reading from the underlying connection.
func (cx *Connection) Read(p []byte) (n int, err error) {
if cx.recording {
if len(cx.buf) == 0 {
err = cx.prefetch()
if err != nil {
return 0, err
}
}
if len(cx.buf[cx.offset:])-len(p) < 0 {
return 0, ErrConsumedAllPrefetchedBytes
}
}

// if there is a buffer we should read from, start
// with that; we only read from the underlying conn
// after the buffer has been "depleted"
if cx.bufReader != nil {
n, err = cx.bufReader.Read(p)
if err == io.EOF {
cx.bufReader = nil
err = nil
}
// prevent first read from returning 0 bytes because of empty bufReader
if !(n == 0 && err == nil) {
return
}
if cx.offset < len(cx.buf) {
n := copy(p, cx.buf[cx.offset:])
cx.offset += n
return n, nil
}

// buffer has been "depleted" so read from
// underlying connection
n, err = cx.Conn.Read(p)
cx.bytesRead += uint64(n)

if !cx.recording {
return
}

// since we're recording at this point, anything that
// was read needs to be written to the buffer, even
// if there was an error
if n > 0 {
if nw, errw := cx.buf.Write(p[:n]); errw != nil {
return nw, errw
if cx.recording {
cx.buf = append(cx.buf, p[:n]...)
cx.offset += n
if len(cx.buf) >= MaxMatchingBytes {
return n, ErrMatchingBufferFull
}
}

Expand All @@ -117,6 +120,13 @@ func (cx *Connection) Read(p []byte) (n int, err error) {
// Write implements io.Writer by forwarding p to the underlying
// connection and accounting for the number of bytes written.
func (cx *Connection) Write(p []byte) (n int, err error) {
	n, err = cx.Conn.Write(p)
	cx.bytesWritten += uint64(n)

	// A write means matching is over; discard any prefetched bytes so
	// a later Read during matching triggers prefetch again.
	if len(cx.buf) != 0 {
		cx.buf = cx.buf[:0]
		cx.offset = 0
	}

	return n, err
}

Expand All @@ -130,33 +140,69 @@ func (cx *Connection) Wrap(conn net.Conn) *Connection {
Context: cx.Context,
Logger: cx.Logger,
buf: cx.buf,
bufReader: cx.bufReader,
offset: cx.offset,
recording: cx.recording,
bytesRead: cx.bytesRead,
bytesWritten: cx.bytesWritten,
}
}

// record starts recording the stream into cx.buf. It also creates a reader
// to read from the buffer but not to discard any byte.
// prefetch tries to read all bytes that a client initially sent us without blocking.
// It reads in PrefetchChunkSize steps until the buffer holds MaxMatchingBytes,
// a read errors, or a read returns less than a full chunk (which suggests the
// client has nothing more queued right now — NOTE(review): this relies on the
// connection's read semantics; confirm short reads imply an empty socket buffer).
func (cx *Connection) prefetch() (err error) {
	var n int
	var tmp []byte

	for len(cx.buf) < MaxMatchingBytes {
		if len(cx.buf) == 0 && cap(cx.buf) >= PrefetchChunkSize {
			// Empty buffer with enough capacity: read directly into it
			// to avoid an extra copy through a scratch buffer.
			n, err = cx.Conn.Read(cx.buf[:PrefetchChunkSize])
			cx.buf = cx.buf[:n]
		} else {
			// Otherwise read through a pooled scratch buffer and append.
			if tmp == nil {
				tmp = bufPool.Get().([]byte)
				tmp = tmp[:PrefetchChunkSize]
				defer bufPool.Put(tmp) // executes once; guarded by the nil check above
			}
			n, err = cx.Conn.Read(tmp)
			cx.buf = append(cx.buf, tmp[:n]...)
		}

		// Count bytes even on error, since n bytes may still have been read.
		cx.bytesRead += uint64(n)

		if err != nil {
			return err
		}

		// Short read: stop instead of blocking on another Read.
		if n < PrefetchChunkSize {
			break
		}
	}

	if cx.Logger.Core().Enabled(zap.DebugLevel) {
		cx.Logger.Debug("prefetched",
			zap.String("remote", cx.RemoteAddr().String()),
			zap.Int("bytes", len(cx.buf)),
		)
	}

	return nil
}

// record starts recording the stream into cx.buf.
func (cx *Connection) record() {
	cx.recording = true
	// NOTE(review): the next line is shown by this diff view as the REMOVED
	// pre-change code (the bufReader field is deleted elsewhere in this PR);
	// in the final revision record only sets the recording flag.
	cx.bufReader = bytes.NewReader(cx.buf.Bytes()) // Don't discard bytes.
}

// rewind stops recording and creates a reader for the
// buffer so that the next reads from an associated
// recordableConn come from the buffer first, then
// continue with the underlying conn.
// rewind stops recording and resets the buffer offset
// so that the next reads come from the buffer first.
func (cx *Connection) rewind() {
	cx.recording = false
	// NOTE(review): the next line is the REMOVED pre-change bufReader
	// assignment shown by this diff view; the new revision replaces it with
	// the offset reset below.
	cx.bufReader = cx.buf // Actually consume bytes.
	cx.offset = 0
}

// SetVar sets a value in the context's variable table with
// the given key. It overwrites any previous value with the
// same key.
func (cx Connection) SetVar(key string, value interface{}) {
func (cx *Connection) SetVar(key string, value interface{}) {
varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
if !ok {
return
Expand All @@ -167,14 +213,27 @@ func (cx Connection) SetVar(key string, value interface{}) {
// GetVar gets a value from the context's variable table with
// the given key. It returns the value if found, and true if
// it found a value with that key; false otherwise.
func (cx Connection) GetVar(key string) interface{} {
func (cx *Connection) GetVar(key string) interface{} {
varMap, ok := cx.Context.Value(VarsCtxKey).(map[string]interface{})
if !ok {
return nil
}
return varMap[key]
}

// MatchingBytes returns the bytes currently available for matching as a
// read-only view into the internal buffer. Callers must not write into the
// returned slice, since it aliases connection state.
func (cx *Connection) MatchingBytes() ([]byte, error) {
	// If matching started before the first Read, the buffer is still
	// empty; prefetch now so matchers have data to inspect.
	if cx.recording && len(cx.buf) == 0 {
		if err := cx.prefetch(); err != nil {
			return nil, err
		}
	}
	return cx.buf[cx.offset:], nil
}

var (
// VarsCtxKey is the key used to store the variables table
// in a Connection's context.
Expand All @@ -187,8 +246,15 @@ var (
listenerCtxKey caddy.CtxKey = "listener"
)

// PrefetchChunkSize is the number of bytes a single prefetch read attempts at once.
const PrefetchChunkSize = 1024

// MaxMatchingBytes is the amount of bytes that are at most prefetched during matching.
// This is probably most relevant for the http matcher since http requests do not have a size limit.
// 8 KiB should cover most use-cases and is similar to popular webservers.
const MaxMatchingBytes = 8 * 1024

// bufPool recycles prefetch scratch buffers across connections.
var bufPool = sync.Pool{
	New: func() interface{} {
		// NOTE(review): this diff view shows both the REMOVED (*bytes.Buffer)
		// and the ADDED ([]byte) New implementations; only the make call
		// survives in the final revision.
		return new(bytes.Buffer)
		return make([]byte, 0, PrefetchChunkSize)
	},
}
2 changes: 1 addition & 1 deletion layer4/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestConnection_RecordAndRewind(t *testing.T) {
defer in.Close()
defer out.Close()

cx := WrapConnection(out, &bytes.Buffer{}, zap.NewNop())
cx := WrapConnection(out, []byte{}, zap.NewNop())
defer cx.Close()

matcherData := []byte("foo")
Expand Down
10 changes: 5 additions & 5 deletions layer4/listener.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package layer4

import (
"bytes"
"context"
"errors"
"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
"net"
"runtime"
"sync"
"time"

"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
)

func init() {
Expand Down Expand Up @@ -115,8 +115,8 @@ func (l *listener) handle(conn net.Conn) {
}
}()

buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
buf := bufPool.Get().([]byte)
buf = buf[:0]
defer bufPool.Put(buf)

cx := WrapConnection(conn, buf, l.logger)
Expand Down
33 changes: 33 additions & 0 deletions layer4/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ package layer4

import (
"encoding/json"
"errors"
"fmt"
"os"
"time"

"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
Expand All @@ -39,6 +42,8 @@ type Route struct {
// executed in sequential order if the route's matchers match.
HandlersRaw []json.RawMessage `json:"handle,omitempty" caddy:"namespace=layer4.handlers inline_key=handler"`

MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"`

matcherSets MatcherSets
middleware []Middleware
}
Expand Down Expand Up @@ -68,6 +73,11 @@ func (r *Route) Provision(ctx caddy.Context) error {
r.middleware = append(r.middleware, wrapHandler(midhandler))
}

// timeouts
if r.MatchingTimeout == 0 {
r.MatchingTimeout = caddy.Duration(10 * time.Second)
}

return nil
}

Expand Down Expand Up @@ -104,6 +114,8 @@ func (routes RouteList) Compile(next Handler, logger *zap.Logger) Handler {
return stack
}

// ErrMatchingTimeout is returned (logged, then swallowed) when connection
// matching is aborted because the route's MatchingTimeout deadline expired.
var ErrMatchingTimeout = errors.New("aborted matching according to timeout")

// wrapRoute wraps route with a middleware and handler so that it can
// be chained in and defer evaluation of its matchers to request-time.
// Like wrapMiddleware, it is vital that this wrapping takes place in
Expand All @@ -124,16 +136,37 @@ func wrapRoute(route *Route, logger *zap.Logger) Middleware {
// but I just thought this made more sense
nextCopy := next

if route.MatchingTimeout > 0 {
// timeout matching to protect against malicious or very slow clients
err := cx.Conn.SetReadDeadline(time.Now().Add(time.Duration(route.MatchingTimeout)))
if err != nil {
return err
}
}

// route must match at least one of the matcher sets
matched, err := route.matcherSets.AnyMatch(cx)
if err == ErrConsumedAllPrefetchedBytes {
matched = false
err = nil
}
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrMatchingTimeout
}
logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
return nil // return nil so the error does not get logged again
}
if !matched {
return nextCopy.Handle(cx)
}

// remove deadline after we matched
err = cx.Conn.SetReadDeadline(time.Time{})
if err != nil {
return err
}

// TODO: other routing features?

// // if route is part of a group, ensure only the
Expand Down
Loading