Merge pull request #225 from zenhack/no-MaxConcurrentCalls
Don't block when queuing methods on *Server.
lthibault authored Apr 18, 2022
2 parents 47559d5 + 9bd1e54 commit 4f0a052
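The core of the change: start() previously used the fixed-size ongoing slice as a semaphore and could block the caller once Policy.MaxConcurrentCalls slots were taken; now it pushes each call onto an unbounded multi-producer single-consumer queue that a lone handleCalls() goroutine drains. Below is a minimal, self-contained sketch of that pattern; the toy queue type only mimics the Send/Recv surface of the internal mpsc.Queue used in the diff and is not the library's implementation.

package main

import (
	"context"
	"fmt"
	"sync"
)

// queue is a toy unbounded MPSC queue standing in for mpsc.Queue:
// Send never blocks, so producers are never throttled by a
// MaxConcurrentCalls-style semaphore.
type queue[T any] struct {
	mu    sync.Mutex
	items []T
	ready chan struct{} // holds a token while items may be non-empty
}

func newQueue[T any]() *queue[T] {
	return &queue[T]{ready: make(chan struct{}, 1)}
}

// Send enqueues v and returns immediately.
func (q *queue[T]) Send(v T) {
	q.mu.Lock()
	q.items = append(q.items, v)
	q.mu.Unlock()
	select {
	case q.ready <- struct{}{}: // wake the consumer if it is parked
	default:
	}
}

// Recv dequeues the next value, blocking until one arrives or ctx ends.
func (q *queue[T]) Recv(ctx context.Context) (T, error) {
	for {
		q.mu.Lock()
		if len(q.items) > 0 {
			v := q.items[0]
			q.items = q.items[1:]
			q.mu.Unlock()
			return v, nil
		}
		q.mu.Unlock()
		select {
		case <-q.ready:
		case <-ctx.Done():
			var zero T
			return zero, ctx.Err()
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	q := newQueue[int]()

	// Many producers may enqueue without ever blocking...
	for i := 0; i < 5; i++ {
		q.Send(i)
	}

	// ...while a single consumer drains in order, like handleCalls().
	for {
		v, err := q.Recv(ctx)
		if err != nil {
			return
		}
		fmt.Println("handled call", v)
		if v == 4 {
			cancel() // like Shutdown cancelling handleCallsCtx
		}
	}
}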
server/server.go (283 changes: 114 additions & 169 deletions)
@@ -10,6 +10,7 @@ import (
 
 	"capnproto.org/go/capnp/v3"
 	"capnproto.org/go/capnp/v3/exc"
+	"capnproto.org/go/capnp/v3/internal/mpsc"
 )
 
 // A Method describes a single capability method on a server object.
@@ -21,30 +22,24 @@ type Method struct {
 
 // Call holds the state of an ongoing capability method call.
 // A Call cannot be used after the server method returns.
 type Call struct {
-	args capnp.Struct
+	ctx    context.Context
+	cancel context.CancelFunc
+	method *Method
+	recv   capnp.Recv
+	aq     *answerQueue
+	srv    *Server
 
 	alloced bool
-	alloc   resultsAllocer
 	results capnp.Struct
 
-	ack   chan<- struct{}
 	acked bool
 }
 
-func newCall(args capnp.Struct, ra resultsAllocer) (*Call, <-chan struct{}) {
-	ack := make(chan struct{})
-	return &Call{
-		args:  args,
-		alloc: ra,
-		ack:   ack,
-	}, ack
-}
-
 // Args returns the call's arguments. Args is not safe to
 // reference after a method implementation returns. Args is safe to
 // call and read from multiple goroutines.
 func (c *Call) Args() capnp.Struct {
-	return c.args
+	return c.recv.Args
 }
 
 // AllocResults allocates the results struct. It is an error to call
@@ -55,7 +50,7 @@ func (c *Call) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
 	}
 	var err error
 	c.alloced = true
-	c.results, err = c.alloc.AllocResults(sz)
+	c.results, err = c.recv.Returner.AllocResults(sz)
 	return c.results, err
 }

@@ -73,8 +68,8 @@ func (c *Call) Ack() {
 	if c.acked {
 		return
 	}
-	close(c.ack)
 	c.acked = true
+	go c.srv.handleCalls(c.srv.handleCallsCtx)
 }
 
 // Shutdowner is the interface that wraps the Shutdown method.
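To make the new acknowledgement flow concrete, here is a hedged sketch of a user-written method implementation; the method body, its name, and the nil arguments to New are illustrative, not part of this commit. Until the implementation calls Ack (or returns), handleCalls() delivers no further calls, preserving delivery order; Ack spawns a replacement dispatch goroutine so queued calls behind this one can proceed.

package main

import (
	"context"
	"time"

	"capnproto.org/go/capnp/v3/server"
)

// slowMethod is a hypothetical implementation of some capability method.
// Calling Ack releases the dispatch goroutine; without it, every call
// queued behind this one would wait out the full second.
func slowMethod(ctx context.Context, call *server.Call) error {
	call.Ack()
	select {
	case <-time.After(time.Second): // stand-in for long-running work
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	// Method metadata (capnp.Method) is omitted; generated code normally
	// fills it in. brand, shutdown, and policy are all nil here.
	srv := server.New([]server.Method{{Impl: slowMethod}}, nil, nil, nil)
	defer srv.Shutdown()
}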
@@ -88,32 +83,32 @@ type Server struct {
 	methods  sortedMethods
 	brand    interface{}
 	shutdown Shutdowner
-	policy   Policy
 
-	// mu protects the following fields.
-	// mu should never be held while calling application code.
-	mu sync.Mutex
 
-	// ongoing is a fixed-size list of ongoing calls.
-	// It is used as a semaphore: when all elements are set, no new work
-	// can be started until an element is cleared.
-	ongoing []cstate
+	// Cancels handleCallsCtx
+	cancelHandleCalls context.CancelFunc
 
-	// starting is non-nil if start() is waiting for acknowledgement of a
-	// call. It is closed when the acknowledgement is received.
-	starting <-chan struct{}
+	// Context used by the goroutine running handleCalls(). Note
+	// the calls themselves will have different contexts, which
+	// are not children of this context, but are supplied by
+	// start(). See cancelCurrentCall.
+	handleCallsCtx context.Context
 
-	// full is non-nil if a start() is waiting for a space in ongoing to
-	// free up. It is closed and set to nil when the next call returns.
-	full chan<- struct{}
+	// wg is incremented each time a method is queued, and
+	// decremented after it is handled.
+	wg sync.WaitGroup
 
-	// drain is non-nil when Shutdown starts and is closed by the last
-	// call to return.
-	drain chan struct{}
-}
+	// Calls are inserted into this queue, to be handled
+	// by a goroutine running handleCalls()
+	callQueue *mpsc.Queue[*Call]
 
-type cstate struct {
-	cancel context.CancelFunc // nil if slot free
+	// When a call is in progress, this channel will contain the
+	// CancelFunc for that call's context. A goroutine may receive
+	// on this to fetch the function, and is then responsible for calling
+	// it. This happens in Shutdown().
+	//
+	// The caller must call cancelHandleCalls() *before* calling
+	// the received CancelFunc.
+	cancelCurrentCall chan context.CancelFunc
 }
 
 // Policy is a set of behavioral parameters for a Server.
@@ -135,20 +130,20 @@ type Policy struct {
 // return or acknowledgment of the previous call. See Call.Ack for more
 // details.
 func New(methods []Method, brand interface{}, shutdown Shutdowner, policy *Policy) *Server {
+	ctx, cancel := context.WithCancel(context.Background())
+
 	srv := &Server{
-		methods:  make(sortedMethods, len(methods)),
-		brand:    brand,
-		shutdown: shutdown,
+		methods:           make(sortedMethods, len(methods)),
+		brand:             brand,
+		shutdown:          shutdown,
+		callQueue:         mpsc.New[*Call](),
+		cancelHandleCalls: cancel,
+		handleCallsCtx:    ctx,
+		cancelCurrentCall: make(chan context.CancelFunc, 1),
 	}
 	copy(srv.methods, methods)
 	sort.Sort(srv.methods)
-	if policy != nil {
-		srv.policy = *policy
-	}
-	if srv.policy.MaxConcurrentCalls < 1 {
-		srv.policy.MaxConcurrentCalls = 2
-	}
-	srv.ongoing = make([]cstate, srv.policy.MaxConcurrentCalls)
+	go srv.handleCalls(ctx)
	return srv
 }

@@ -186,128 +181,84 @@ func (srv *Server) Recv(ctx context.Context, r capnp.Recv) capnp.PipelineCaller {
 	return srv.start(ctx, mm, r)
 }
 
-func (srv *Server) start(ctx context.Context, m *Method, r capnp.Recv) capnp.PipelineCaller {
-	// Acquire "starting" condition variable.
-	srv.mu.Lock()
+func (srv *Server) handleCalls(ctx context.Context) {
 	for {
-		if srv.drain != nil {
-			srv.mu.Unlock()
-			r.Reject(exc.New(exc.Failed, "capnp server", "call after shutdown"))
-			return nil
-		}
-		if srv.starting == nil {
+		call, err := srv.callQueue.Recv(ctx)
+		if err != nil {
 			break
 		}
-		wait := srv.starting
-		srv.mu.Unlock()
-		select {
-		case <-wait:
-		case <-ctx.Done():
-			r.Reject(ctx.Err())
-			return nil
+
+		srv.handleCall(ctx, call)
+		if call.acked {
+			// Another goroutine has taken over; time
+			// to retire.
+			return
 		}
-		srv.mu.Lock()
 	}
-	starting := make(chan struct{})
-	srv.starting = starting
 
-	// Acquire an ID (semaphore).
-	id := srv.nextID()
-	if id == -1 {
-		full := make(chan struct{})
-		srv.full = full
-		srv.mu.Unlock()
-		select {
-		case <-full:
-		case <-ctx.Done():
-			srv.mu.Lock()
-			srv.starting = nil
-			close(starting)
-			srv.full = nil // full could be nil or non-nil, ensure it is nil.
-			srv.mu.Unlock()
-			r.Reject(ctx.Err())
-			return nil
-		}
-		srv.mu.Lock()
-		id = srv.nextID()
-		if srv.drain != nil {
-			srv.starting = nil
-			close(starting)
-			srv.mu.Unlock()
-			r.Reject(exc.New(exc.Failed, "capnp server", "call after shutdown"))
-			return nil
+	for {
+		// Context has been canceled; drain the rest of the queue,
+		// cancelling each call.
+		call, ok := srv.callQueue.TryRecv()
+		if !ok {
+			return
 		}
+		call.cancel()
+		srv.handleCall(ctx, call)
 	}
+}
 
-	// Bookkeeping: set starting to indicate we're waiting for an ack and
-	// record the cancel function for draining.
-	ctx, cancel := context.WithCancel(ctx)
-	srv.ongoing[id] = cstate{cancel}
-	srv.mu.Unlock()
+func (srv *Server) handleCall(ctx context.Context, c *Call) {
+	defer srv.wg.Done()
+	defer c.cancel()
 
-	// Call implementation function.
-	call, ack := newCall(r.Args, r.Returner)
-	aq := newAnswerQueue(r.Method)
-	done := make(chan struct{})
-	go func() {
-		err := m.Impl(ctx, call)
-		r.ReleaseArgs()
-		if err == nil {
-			aq.fulfill(call.results)
-			r.Returner.Return(nil)
-		} else {
-			aq.reject(err)
-			r.Returner.Return(err)
-		}
-		srv.mu.Lock()
-		srv.ongoing[id].cancel()
-		srv.ongoing[id] = cstate{}
-		if srv.drain != nil && !srv.hasOngoing() {
-			close(srv.drain)
-		}
-		if srv.full != nil {
-			close(srv.full)
-			srv.full = nil
+	// Store this in the channel, in case Shutdown() gets called
+	// while we're servicing the method call.
+	srv.cancelCurrentCall <- c.cancel
+	defer func() {
+		select {
+		case <-srv.cancelCurrentCall:
+		default:
 		}
-		srv.mu.Unlock()
-		close(done)
 	}()
-	var pcall capnp.PipelineCaller
-	select {
-	case <-ack:
-		pcall = aq
-	case <-done:
+
+	// Implementation functions may not call Ack, which is fine for
+	// smaller functions.
+
+	// Handling the contexts is tricky here, since neither one
+	// is necessarily a parent of the other. We need to check
+	// the context that was passed to us (which manages the
+	// handleCalls loop) some time *after* storing c.cancel,
+	// above, to avoid a race between this code and Shutdown(),
+	// which cancels ctx before attempting to receive c.cancel.
+	err := ctx.Err()
+	if err == nil {
+		err = c.ctx.Err()
 	}
+	if err == nil {
+		err = c.method.Impl(c.ctx, c)
+	}
-	srv.mu.Lock()
-	srv.starting = nil
-	close(starting)
-	srv.mu.Unlock()
-	return pcall
-}
 
-// nextID returns the next available index in srv.ongoing or -1 if
-// there are too many ongoing calls. The caller must be holding onto
-// srv.mu.
-func (srv *Server) nextID() int {
-	for i := range srv.ongoing {
-		if srv.ongoing[i].cancel == nil {
-			return i
-		}
+	c.recv.ReleaseArgs()
+	if err == nil {
+		c.aq.fulfill(c.results)
+	} else {
+		c.aq.reject(err)
 	}
-	return -1
+	c.recv.Returner.Return(err)
 }
 
-// hasOngoing reports whether there are any ongoing calls.
-// The caller must be holding onto srv.mu.
-func (srv *Server) hasOngoing() bool {
-	for i := range srv.ongoing {
-		if srv.ongoing[i].cancel != nil {
-			return true
-		}
-	}
-	return false
+func (srv *Server) start(ctx context.Context, m *Method, r capnp.Recv) capnp.PipelineCaller {
+	srv.wg.Add(1)
+
+	ctx, cancel := context.WithCancel(ctx)
+
+	aq := newAnswerQueue(r.Method)
+	srv.callQueue.Send(&Call{
+		ctx:    ctx,
+		cancel: cancel,
+		method: m,
+		recv:   r,
+		aq:     aq,
+		srv:    srv,
+	})
+	return aq
 }
 
 // Brand returns a value that will match IsServer.
@@ -319,24 +270,18 @@ func (srv *Server) Brand() capnp.Brand {
 // Shutdowner passed into NewServer. Shutdown must not be called more
 // than once.
 func (srv *Server) Shutdown() {
-	srv.mu.Lock()
-	if srv.drain != nil {
-		srv.mu.Unlock()
-		panic("capnp server: Shutdown called multiple times")
-	}
-	srv.drain = make(chan struct{})
-	if srv.hasOngoing() {
-		for _, cs := range srv.ongoing {
-			if cs.cancel != nil {
-				cs.cancel()
-			}
-		}
-		srv.mu.Unlock()
-		<-srv.drain
-	} else {
-		close(srv.drain)
-		srv.mu.Unlock()
+	// Cancel the loop in handleCalls(), and then cancel the outstanding
+	// call, if any. The order here is critical; if we cancel the
+	// outstanding call first, the loop may start another call before
+	// we cancel it.
+	srv.cancelHandleCalls()
+	select {
+	case cancel := <-srv.cancelCurrentCall:
+		cancel()
+	default:
 	}
+
+	srv.wg.Wait()
 	if srv.shutdown != nil {
 		srv.shutdown.Shutdown()
 	}
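The ordering that comment insists on is easy to demonstrate in isolation. Below is a self-contained sketch of the same handoff; all names are illustrative, not the library's. The dispatch loop parks the in-flight call's CancelFunc in a 1-buffered channel; shutdown cancels the loop first and only then drains the channel, so the loop cannot launch a fresh call after the drain.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	loopCtx, cancelLoop := context.WithCancel(context.Background())
	cancelCurrent := make(chan context.CancelFunc, 1)
	done := make(chan struct{})

	// Stands in for the handleCalls() loop: each iteration publishes the
	// current call's CancelFunc, then "runs" the call.
	go func() {
		defer close(done)
		for i := 0; loopCtx.Err() == nil; i++ {
			callCtx, cancelCall := context.WithCancel(context.Background())
			cancelCurrent <- cancelCall // publish before running the call
			<-callCtx.Done()            // stand-in for servicing one call
			fmt.Println("call", i, "stopped:", callCtx.Err())
			select { // clear the slot if shutdown didn't take it
			case <-cancelCurrent:
			default:
			}
		}
	}()

	time.Sleep(10 * time.Millisecond) // let one call start

	// Shutdown: cancel the loop *first*; only then cancel the in-flight
	// call. In the other order the loop could begin a new call after we
	// drained the channel, and that call would never be cancelled.
	cancelLoop()
	select {
	case cancel := <-cancelCurrent:
		cancel()
	default:
	}
	<-done
}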
