diff --git a/layer4/routes_test.go b/layer4/routes_test.go index a66558c..82a2eff 100644 --- a/layer4/routes_test.go +++ b/layer4/routes_test.go @@ -7,6 +7,7 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddyhttp" "io" "net" + "os" "testing" "time" @@ -89,3 +90,100 @@ func TestMatchingTimeoutWorks(t *testing.T) { t.Fatal("handler was called but should not") } } + +// used to test the timeout of udp associations +type testIoUdpMatcher struct { +} + +func (testIoUdpMatcher) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.matchers.testIoUdpMatcher", + New: func() caddy.Module { return new(testIoUdpMatcher) }, + } +} + +var ( + testConnection *Connection + handlingDone chan struct{} +) + +func (m *testIoUdpMatcher) Match(cx *Connection) (bool, error) { + // normally deadline exceeded error is handled during prefetch, and custom matcher can't + // read more than what's prefetched, but it's a test. + cx.matching = false + buf := make([]byte, 10) + n, err := io.ReadFull(cx, buf) + if err != nil { + cx.SetVar("time", time.Now()) + cx.SetVar("err", err) + testConnection = cx + close(handlingDone) + } + return n > 0, err +} + +func TestMatchingTimeoutWorksUDP(t *testing.T) { + ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) + defer cancel() + + caddy.RegisterModule(testIoUdpMatcher{}) + + routes := RouteList{&Route{ + MatcherSetsRaw: caddyhttp.RawMatcherSets{ + caddy.ModuleMap{"testIoUdpMatcher": json.RawMessage("{}")}, // any io using matcher + }, + }} + + err := routes.Provision(ctx) + if err != nil { + t.Fatalf("provision failed | %s", err) + } + + matchingTimeout := time.Second + + compiledRoutes := routes.Compile(zap.NewNop(), matchingTimeout, + HandlerFunc(func(con *Connection) error { + return nil + })) + + handlingDone = make(chan struct{}) + + // Because udp is connectionless and every read can be from different addresses. A mapping between + // addresses and data read is created. A virtual connection can only read data from a certain address. + // Using real udp sockets and server to test timeout. + // We can't wait for the handler to finish this way, but that is tested above. + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen | %s", err) + } + defer func() { _ = pc.Close() }() + + server := new(Server) + server.compiledRoute = compiledRoutes + server.logger = zap.NewNop() + go server.servePacket(pc) + + now := time.Now() + + client, err := net.Dial("udp", pc.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to dial | %s", err) + } + defer func() { _ = client.Close() }() + + _, err = client.Write([]byte("hello")) + if err != nil { + t.Fatalf("failed to write | %s", err) + } + + // only wait for the matcher to return + <-handlingDone + if !errors.Is(testConnection.GetVar("err").(error), os.ErrDeadlineExceeded) { + t.Fatalf("expected deadline exceeded error but got %s", testConnection.GetVar("err")) + } + + elasped := testConnection.GetVar("time").(time.Time).Sub(now) + if !(matchingTimeout <= elasped && elasped <= 2*matchingTimeout) { + t.Fatalf("timeout takes too long %s", elasped) + } +} diff --git a/layer4/server.go b/layer4/server.go index 05135fb..74d6014 100644 --- a/layer4/server.go +++ b/layer4/server.go @@ -20,7 +20,9 @@ import ( "fmt" "io" "net" + "os" "sync" + "sync/atomic" "time" "github.com/caddyserver/caddy/v2" @@ -238,7 +240,30 @@ type packetConn struct { // from the buffer, and this packet will be reused in the next Read() // without waiting for readCh. lastPacket *packet - lastBuf *bytes.Buffer + lastBuf *bytes.Reader + + // stores time.Time as Unix as Read maybe called concurrently with SetReadDeadline + deadline atomic.Int64 + deadlineTimer *time.Timer + idleTimer *time.Timer +} + +// SetReadDeadline sets the deadline to wait for data from the underlying net.PacketConn. +func (pc *packetConn) SetReadDeadline(t time.Time) error { + pc.deadline.Store(t.Unix()) + if pc.deadlineTimer != nil { + pc.deadlineTimer.Reset(time.Until(t)) + } else { + pc.deadlineTimer = time.NewTimer(time.Until(t)) + } + return nil +} + +// TODO: idle timeout should be configurable per server +const udpAssociationIdleTimeout = 30 * time.Second + +func isDeadlineExceeded(t time.Time) bool { + return !t.IsZero() && t.Before(time.Now()) } func (pc *packetConn) Read(b []byte) (n int, err error) { @@ -253,27 +278,47 @@ func (pc *packetConn) Read(b []byte) (n int, err error) { } return } - select { - case pkt := <-pc.readCh: - if pkt == nil { - // Channel is closed. Return EOF below. + // check deadline + if isDeadlineExceeded(time.Unix(pc.deadline.Load(), 0)) { + return 0, os.ErrDeadlineExceeded + } + // set or refresh idle timeout + if pc.idleTimer == nil { + pc.idleTimer = time.NewTimer(udpAssociationIdleTimeout) + } else { + pc.idleTimer.Reset(udpAssociationIdleTimeout) + } + var done bool + for !done { + select { + case pkt := <-pc.readCh: + if pkt == nil { + // Channel is closed. Return EOF below. + done = true + break + } + buf := bytes.NewReader(pkt.pooledBuf[:pkt.n]) + n, err = buf.Read(b) + if buf.Len() == 0 { + // Buffer fully consumed, release it. + udpBufPool.Put(pkt.pooledBuf) + } else { + // Buffer only partially consumed. Keep track of it for + // next Read() call. + pc.lastPacket = pkt + pc.lastBuf = buf + } + return + case <-pc.deadlineTimer.C: + // deadline may change during the wait, recheck + if isDeadlineExceeded(time.Unix(pc.deadline.Load(), 0)) { + return 0, os.ErrDeadlineExceeded + } + // next loop will run. Don't call Read as that will reset the idle timer. + case <-pc.idleTimer.C: + done = true break } - buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n]) - n, err = buf.Read(b) - if buf.Len() == 0 { - // Buffer fully consumed, release it. - udpBufPool.Put(pkt.pooledBuf) - } else { - // Buffer only partially consumed. Keep track of it for - // next Read() call. - pc.lastPacket = pkt - pc.lastBuf = buf - } - return - // TODO: idle timeout should be configurable per server - case <-time.After(30 * time.Second): - break } // Idle timeout simulates socket closure. //