Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some helpers for acquiring locks #210

Merged
merged 2 commits into from
Jan 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 50 additions & 49 deletions answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"capnproto.org/go/capnp/v3/internal/errors"
"capnproto.org/go/capnp/v3/internal/syncutil"
)

// A Promise holds the result of an RPC call. Only one of Fulfill,
Expand Down Expand Up @@ -193,19 +194,19 @@ func (p *Promise) resolve(r Ptr, e error) {
if p.ongoingCalls > 0 {
p.callsStopped = make(chan struct{})
}
p.mu.Unlock()
res := resolution{p.method, r, e}
for path, row := range p.clients {
t := path.transform()
for i := range row {
row[i].promise.Fulfill(res.client(t))
row[i].promise = nil
syncutil.Without(&p.mu, func() {
lthibault marked this conversation as resolved.
Show resolved Hide resolved
res := resolution{p.method, r, e}
for path, row := range p.clients {
t := path.transform()
for i := range row {
row[i].promise.Fulfill(res.client(t))
row[i].promise = nil
}
}
}
if p.callsStopped != nil {
<-p.callsStopped
}
p.mu.Lock()
if p.callsStopped != nil {
<-p.callsStopped
}
})
}

// Move p into resolved state.
Expand Down Expand Up @@ -247,24 +248,24 @@ traversal:
case parent.isPendingResolution():
// Wait for resolution. Next traversal iteration will be resolved.
r := parent.resolved
parent.mu.Unlock()
if p.joined == nil {
p.joined = make(chan struct{})
}
p.mu.Unlock()
<-r
p.mu.Lock()
parent.mu.Lock()
syncutil.Without(&parent.mu, func() {
if p.joined == nil {
p.joined = make(chan struct{})
}
syncutil.Without(&p.mu, func() {
<-r
})
})
case parent.isPendingJoin():
j := parent.joined
parent.mu.Unlock()
if p.joined == nil {
p.joined = make(chan struct{})
}
p.mu.Unlock()
<-j
p.mu.Lock()
parent.mu.Lock()
syncutil.Without(&parent.mu, func() {
if p.joined == nil {
p.joined = make(chan struct{})
}
syncutil.Without(&p.mu, func() {
<-j
})
})
case parent.isResolved():
r, e := parent.result, parent.err
parent.mu.Unlock()
Expand All @@ -284,9 +285,9 @@ traversal:
if p.joined == nil {
p.joined = make(chan struct{})
}
p.mu.Unlock()
<-p.callsStopped
p.mu.Lock()
syncutil.Without(&p.mu, func() {
<-p.callsStopped
})
p.callsStopped = nil
}
if p.joined != nil {
Expand Down Expand Up @@ -445,12 +446,12 @@ traversal:
caller := p.caller
p.mu.Unlock()
ans, release := caller.PipelineSend(ctx, transform, s)
p.mu.Lock()
p.ongoingCalls--
if p.ongoingCalls == 0 && p.callsStopped != nil {
close(p.callsStopped)
}
p.mu.Unlock()
syncutil.With(&p.mu, func() {
p.ongoingCalls--
if p.ongoingCalls == 0 && p.callsStopped != nil {
close(p.callsStopped)
}
})
return ans, release
case p.isPendingResolution():
// Block new calls until resolved.
Expand Down Expand Up @@ -503,12 +504,12 @@ traversal:
caller := p.caller
p.mu.Unlock()
pcall := caller.PipelineRecv(ctx, transform, r)
p.mu.Lock()
p.ongoingCalls--
if p.ongoingCalls == 0 && p.callsStopped != nil {
close(p.callsStopped)
}
p.mu.Unlock()
syncutil.With(&p.mu, func() {
p.ongoingCalls--
if p.ongoingCalls == 0 && p.callsStopped != nil {
close(p.callsStopped)
}
})
return pcall
case p.isPendingResolution():
// Block new calls until resolved.
Expand Down Expand Up @@ -589,9 +590,9 @@ traversal:
switch {
case p.isPendingJoin():
j := p.joined
p.mu.Unlock()
<-j
p.mu.Lock()
syncutil.Without(&p.mu, func() {
<-j
})
case p.isJoined():
q := p.next
p.mu.Unlock()
Expand Down Expand Up @@ -619,9 +620,9 @@ traversal:
p.mu.Unlock()
return c
case p.isPendingResolution():
p.mu.Unlock()
<-p.resolved
p.mu.Lock()
syncutil.Without(&p.mu, func() {
<-p.resolved
})
fallthrough
case p.isResolved():
r := p.resolution()
Expand Down
13 changes: 7 additions & 6 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"

"capnproto.org/go/capnp/v3/flowcontrol"
"capnproto.org/go/capnp/v3/internal/syncutil"
)

// An Interface is a reference to a client in a message's capability table.
Expand Down Expand Up @@ -193,12 +194,12 @@ func (c *Client) startCall() (hook ClientHook, resolved, released bool, finish f
c.h.mu.Unlock()
savedHook := c.h
return savedHook.ClientHook, savedHook.isResolved(), false, func() {
savedHook.mu.Lock()
savedHook.calls--
if savedHook.refs == 0 && savedHook.calls == 0 {
close(savedHook.done)
}
savedHook.mu.Unlock()
syncutil.With(&savedHook.mu, func() {
savedHook.calls--
if savedHook.refs == 0 && savedHook.calls == 0 {
close(savedHook.done)
}
})
}
}

Expand Down
20 changes: 20 additions & 0 deletions internal/syncutil/with.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// misc. utilities for synchronization.
package syncutil

import (
"sync"
)

// Runs f while holding the lock
func With(l sync.Locker, f func()) {
l.Lock()
defer l.Unlock()
f()
}

// Runs f while not holding the lock
func Without(l sync.Locker, f func()) {
l.Unlock()
defer l.Lock()
f()
}
23 changes: 12 additions & 11 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"capnproto.org/go/capnp/v3"
"capnproto.org/go/capnp/v3/internal/errors"
"capnproto.org/go/capnp/v3/internal/syncutil"
rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
)

Expand Down Expand Up @@ -109,11 +110,11 @@ func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func() error, capnp
// already returned. The caller MUST NOT be holding onto ans.c.mu
// or the sender lock.
func (ans *answer) setPipelineCaller(pcall capnp.PipelineCaller) {
ans.c.mu.Lock()
if ans.flags&resultsReady == 0 {
ans.pcall = pcall
}
ans.c.mu.Unlock()
syncutil.With(&ans.c.mu, func() {
if ans.flags&resultsReady == 0 {
ans.pcall = pcall
}
})
}

// AllocResults allocates the results struct.
Expand Down Expand Up @@ -233,9 +234,9 @@ func (ans *answer) sendReturn() (releaseList, error) {
return nil, nil
}
rl, err := ans.destroy()
ans.c.mu.Unlock()
ans.releaseMsg()
ans.c.mu.Lock()
syncutil.Without(&ans.c.mu, func() {
ans.releaseMsg()
})
return rl, err
}

Expand Down Expand Up @@ -281,9 +282,9 @@ func (ans *answer) sendException(e error) releaseList {
// destroy will never return an error because sendException does
// create any exports.
rl, _ := ans.destroy()
ans.c.mu.Unlock()
ans.releaseMsg()
ans.c.mu.Lock()
syncutil.Without(&ans.c.mu, func() {
ans.releaseMsg()
})
return rl
}

Expand Down
45 changes: 23 additions & 22 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"capnproto.org/go/capnp/v3"
"capnproto.org/go/capnp/v3/internal/syncutil"
rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
)

Expand Down Expand Up @@ -105,33 +106,33 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer,
// Create call message.
msg, send, release, err := ic.c.transport.NewMessage(ctx)
if err != nil {
ic.c.mu.Lock()
ic.c.questions[q.id] = nil
ic.c.questionID.remove(uint32(q.id))
ic.c.mu.Unlock()
syncutil.With(&ic.c.mu, func() {
ic.c.questions[q.id] = nil
ic.c.questionID.remove(uint32(q.id))
})
return capnp.ErrorAnswer(s.Method, failedf("create message: %w", err)), func() {}
}
ic.c.mu.Lock()
ic.c.unlockSender() // Can't be holding either lock while calling PlaceArgs.
ic.c.mu.Unlock()
syncutil.With(&ic.c.mu, func() {
ic.c.unlockSender() // Can't be holding either lock while calling PlaceArgs.
})
err = ic.c.newImportCallMessage(msg, ic.id, q.id, s)
if err != nil {
ic.c.mu.Lock()
ic.c.questions[q.id] = nil
ic.c.questionID.remove(uint32(q.id))
ic.c.lockSender()
ic.c.mu.Unlock()
syncutil.With(&ic.c.mu, func() {
ic.c.questions[q.id] = nil
ic.c.questionID.remove(uint32(q.id))
ic.c.lockSender()
})
release()
ic.c.mu.Lock()
ic.c.unlockSender()
ic.c.mu.Unlock()
syncutil.With(&ic.c.mu, func() {
ic.c.unlockSender()
})
return capnp.ErrorAnswer(s.Method, err), func() {}
}

// Send call.
ic.c.mu.Lock()
ic.c.lockSender()
ic.c.mu.Unlock()
syncutil.With(&ic.c.mu, func() {
ic.c.lockSender()
})
err = send()
release()

Expand Down Expand Up @@ -198,10 +199,10 @@ func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questio
return failedf("place arguments: %w", err)
}
clients := extractCapTable(m)
c.mu.Lock()
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload, clients)
c.mu.Unlock()
syncutil.With(&c.mu, func() {
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload, clients)
})
releaseList(clients).release()
if err != nil {
return annotate(err, "build call message")
Expand Down
Loading