Skip to content

Commit

Permalink
shared_client: Clean up waiters after timeouts
Browse files Browse the repository at this point in the history
Tell handler to delete waiters after request times out.

Signed-off-by: Jarno Rajahalme <[email protected]>
  • Loading branch information
jrajahalme committed Apr 29, 2024
1 parent c8d105a commit b993c3d
Show file tree
Hide file tree
Showing 2 changed files with 290 additions and 42 deletions.
200 changes: 158 additions & 42 deletions shared_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dns

import (
"container/heap"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -138,6 +139,145 @@ func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err er
return c.ExchangeSharedContext(context.Background(), m)
}

// waiter is a single waiter in the deadline heap
type waiter struct {
id uint16
index int32
start time.Time
deadline time.Time
ch chan sharedClientResponse
}

// waiterQueue is a collection of waiters, implements container/heap
// so that the the waiter with the closest deadline is on the top of the heap.
// Map is used as an index on the DNS request ID for quick access.
type waiterQueue struct {
waiters []*waiter
index map[uint16]*waiter
}

// Len from heap.Interface
func (wq waiterQueue) Len() int { return len(wq.waiters) }

// Less from heap.Interface
func (wq waiterQueue) Less(i, j int) bool {
return wq.waiters[j].deadline.After(wq.waiters[i].deadline)
}

// Swap from heap.Interface
func (wq waiterQueue) Swap(i, j int) {
wq.waiters[i], wq.waiters[j] = wq.waiters[j], wq.waiters[i]
wq.waiters[i].index = int32(i)
wq.waiters[j].index = int32(j)
}

// Push from heap.Interface
func (wq *waiterQueue) Push(x any) {
wtr := x.(*waiter)
wtr.index = int32(len(wq.waiters))
wq.waiters = append(wq.waiters, wtr)
wq.index[wtr.id] = wtr
}

// Pop from heap.Interface
func (wq *waiterQueue) Pop() any {
old := wq.waiters
n := len(old)
wtr := old[n-1]
old[n-1] = nil // avoid memory leak
wtr.index = -1 // for safety
wq.waiters = old[:n-1]

delete(wq.index, wtr.id)
return wtr
}

// newWaiterQueue returns a new waiterQueue
func newWaiterQueue() *waiterQueue {
wq := waiterQueue{
waiters: make([]*waiter, 0, 1),
index: make(map[uint16]*waiter),
}
heap.Init(&wq)
return &wq
}

func (wq *waiterQueue) Insert(id uint16, start, deadline time.Time, ch chan sharedClientResponse) {
if wq.Exists(id) {
panic("waiterQueue: entry exists")
}

heap.Push(wq, &waiter{id, -1, start, deadline, ch})
}

func (wq *waiterQueue) Exists(id uint16) bool {
_, exists := wq.index[id]
return exists
}

func (wq *waiterQueue) FailAll(err error) {
for wq.Len() > 0 {
wtr := heap.Pop(wq).(*waiter)
wtr.ch <- sharedClientResponse{nil, 0, err}
close(wtr.ch)
}
}

func (wq *waiterQueue) Dequeue(id uint16) *waiter {
wtr, ok := wq.index[id]
if !ok {
return nil // not found
}
if wtr.id != id {
panic(fmt.Sprintf("waiterQueue: invalid id %d != %d", wtr.id, id))
}
if wtr.index < 0 {
panic(fmt.Sprintf("waiterQueue: invalid index: %d", wtr.index))
}

heap.Remove(wq, int(wtr.index))

return wtr
}

// GetTimeout returns the time from now till the earliers deadline
func (wq *waiterQueue) GetTimeout() time.Duration {
// return 10 minutes if there are no waiters
if wq.Len() == 0 {
return 10 * time.Minute
}
return wq.waiters[0].deadline.Sub(time.Now())
}

// Expired sends a timeout response to all timed out waiters
func (wq *waiterQueue) Expired() {
now := time.Now()
for wq.Len() > 0 {
if wq.waiters[0].deadline.After(now) {
break
}
wtr := heap.Pop(wq).(*waiter)
wtr.ch <- sharedClientResponse{nil, 0, context.DeadlineExceeded}
close(wtr.ch)
}
}

// Respond passes the DNS response to the waiter
func (wq *waiterQueue) Respond(resp sharedClientResponse) {
if resp.err != nil {
// ReadMsg failed, but we cannot match it to a request,
// so complete all pending requests.
wq.FailAll(resp.err)
} else if resp.msg != nil {
wtr := wq.Dequeue(resp.msg.Id)
if wtr != nil {
resp.rtt = time.Since(wtr.start)
wtr.ch <- resp
close(wtr.ch)
}
}
}

// handler is started when the connection is dialed
func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan request) {
defer wg.Done()
Expand Down Expand Up @@ -197,11 +337,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
}
}()

type waiter struct {
ch chan sharedClientResponse
start time.Time
}
waitingResponses := make(map[uint16]waiter)
waiters := newWaiterQueue()
defer func() {
conn.Close()
close(receiverTrigger)
Expand All @@ -214,10 +350,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
}

// Now cancel all the remaining waiters
for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
close(waiter.ch)
}
waiters.FailAll(net.ErrClosed)

// Drain requests in case they come in while we are closing
// down. This loop is done only after 'requests' channel is closed in
Expand All @@ -229,8 +362,14 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
}
}()

deadline := time.NewTimer(0)
for {
// update timer
deadline.Reset(waiters.GetTimeout())
select {
case <-deadline.C:
waiters.Expired()

case req, ok := <-requests:
if !ok {
// 'requests' is closed when SharedClient is recycled, which happens
Expand All @@ -243,10 +382,11 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
// Due to birthday paradox and the fact that ID is uint16
// it's likely to happen with small number (~200) of concurrent requests
// which would result in goroutine leak as we would never close req.ch
if _, duplicate := waitingResponses[req.msg.Id]; duplicate {
if waiters.Exists(req.msg.Id) {
// find next available ID
duplicate := true
for id := req.msg.Id + 1; id != req.msg.Id; id++ {
if _, duplicate = waitingResponses[id]; !duplicate {
if duplicate = waiters.Exists(id); !duplicate {
req.msg.Id = id
break
}
Expand All @@ -264,7 +404,8 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
req.ch <- sharedClientResponse{nil, 0, err}
close(req.ch)
} else {
waitingResponses[req.msg.Id] = waiter{req.ch, start}
deadline := time.Now().Add(client.getTimeoutForRequest(client.readTimeout()))
waiters.Insert(req.msg.Id, start, deadline, req.ch)

// Wake up the receiver that may be waiting to receive again
triggerReceiver()
Expand All @@ -276,22 +417,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
// nothing can be received any more
return
}
if resp.err != nil {
// ReadMsg failed, but we cannot match it to a request,
// so complete all pending requests.
for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, resp.err}
close(waiter.ch)
}
waitingResponses = make(map[uint16]waiter)
} else if resp.msg != nil {
if waiter, ok := waitingResponses[resp.msg.Id]; ok {
delete(waitingResponses, resp.msg.Id)
resp.rtt = time.Since(waiter.start)
waiter.ch <- resp
close(waiter.ch)
}
}
waiters.Respond(resp)
}
}
}
Expand Down Expand Up @@ -325,20 +451,10 @@ func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Ms
return nil, 0, ctx.Err()
}

// Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh',
// unless the response never arrives.
// Quit when ctx is done so that each request times out individually in case there is no
// response at all.
readTimeout := c.getTimeoutForRequest(c.Client.readTimeout())
ctx2, cancel2 := context.WithTimeout(ctx, readTimeout)
defer cancel2()
select {
case resp := <-respCh:
return resp.msg, resp.rtt, resp.err
case <-ctx2.Done():
// TODO: handler now has a stale waiter that remains if the response never arrives.
return nil, 0, ctx2.Err()
}
// Handler eventually responds to every request, possibly after the request times out.
resp := <-respCh

return resp.msg, resp.rtt, resp.err
}

// close closes and waits for the close to finish.
Expand Down
Loading

0 comments on commit b993c3d

Please sign in to comment.