Skip to content

Commit

Permalink
Replace mutex-protected error in transport with atomic.Value.
Browse files Browse the repository at this point in the history
Improves legibility by removing ambiguity about what the mutex is
protecting, and rendering explicit the error lookup/setting semantics.
  • Loading branch information
lthibault committed Feb 22, 2021
1 parent 9b9ae57 commit b611000
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions rpc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package rpc
import (
"context"
"io"
"sync"
"sync/atomic"
"time"

capnp "zombiezen.com/go/capnproto2"
Expand Down Expand Up @@ -67,9 +67,7 @@ type Codec interface {
type transport struct {
c Codec
closed bool

mu sync.RWMutex
err error
err errorValue
}

// NewTransport creates a new transport that uses the supplied codec
Expand Down Expand Up @@ -101,10 +99,7 @@ func NewPackedStreamTransport(rwc io.ReadWriteCloser) Transport {
// It is safe to call NewMessage concurrently with RecvMessage.
func (s *transport) NewMessage(ctx context.Context) (_ rpccp.Message, send func() error, release capnp.ReleaseFunc, _ error) {
// Check if stream is broken
s.mu.RLock()
err := s.err
s.mu.RUnlock()
if err != nil {
if err := s.err.Load(); err != nil {
return rpccp.Message{}, nil, nil, err
}

Expand All @@ -125,19 +120,14 @@ func (s *transport) NewMessage(ctx context.Context) (_ rpccp.Message, send func(
}

// stream error?
s.mu.RLock()
err := s.err
s.mu.RUnlock()
if err != nil {
if err := s.err.Load(); err != nil {
return err
}

// ok, go!
if err = s.c.Encode(ctx, msg); err != nil {
if _, ok := err.(partialWriteError); ok {
s.mu.Lock()
s.err = errors.New(errors.Disconnected, "rpc stream transport", "broken due to partial write")
s.mu.Unlock()
s.err.Set(errors.New(errors.Disconnected, "rpc stream transport", "broken due to partial write"))
}

err = errors.New(errors.Failed, "rpc stream transport", "send: "+err.Error())
Expand Down Expand Up @@ -165,10 +155,7 @@ func (s *transport) SetPartialWriteTimeout(d time.Duration) {
//
// It is safe to call RecvMessage concurrently with NewMessage.
func (s *transport) RecvMessage(ctx context.Context) (rpccp.Message, capnp.ReleaseFunc, error) {
s.mu.RLock()
err := s.err
s.mu.RUnlock()
if err != nil {
if err := s.err.Load(); err != nil {
return rpccp.Message{}, nil, err
}

Expand Down Expand Up @@ -451,3 +438,17 @@ func isTimeout(e error) bool {
}

type partialWriteError struct{ error }

type errorValue atomic.Value

func (ev *errorValue) Load() error {
if err := (*atomic.Value)(ev).Load(); err != nil {
return err.(error)
}

return nil
}

func (ev *errorValue) Set(err error) {
(*atomic.Value)(ev).Store(err)
}

0 comments on commit b611000

Please sign in to comment.