From e0eb98e184b75cbd470629de3f84082a2ffe91c1 Mon Sep 17 00:00:00 2001
From: Louis Thibault <9452561+lthibault@users.noreply.github.com>
Date: Wed, 3 Mar 2021 11:43:22 -0500
Subject: [PATCH] Add option to use packed encoding in rpc.StreamTransport.
 (#161)

Changelog:

* Eschew rpc.StreamTransport in favor of generic codec-based transport
  (unexported).
* Add codec for generic stream transport.
* Add codec for stream transport with packed encoding.
* Add NewTransport and NewPackedStreamTransport constructors (see the
  usage sketch appended after the diff).
* Replace mutex-protected error in transport with atomic.Value.  This
  improves legibility by disambiguating what the mutex is locking, and
  by making error lookup/setting semantics explicit.
---
 rpc/transport.go      | 247 +++++++++++++++++++++++++++++-------------
 rpc/transport_test.go |  22 +++-
 2 files changed, 189 insertions(+), 80 deletions(-)

diff --git a/rpc/transport.go b/rpc/transport.go
index 8ed71577..2dfa8659 100644
--- a/rpc/transport.go
+++ b/rpc/transport.go
@@ -3,10 +3,10 @@ package rpc
 import (
 	"context"
 	"io"
-	"sync"
+	"sync/atomic"
 	"time"
 
-	"zombiezen.com/go/capnproto2"
+	capnp "zombiezen.com/go/capnproto2"
 	"zombiezen.com/go/capnproto2/internal/errors"
 	rpccp "zombiezen.com/go/capnproto2/std/capnp/rpc"
 )
@@ -52,20 +52,28 @@ type Transport interface {
 	Close() error
 }
 
-// StreamTransport serializes and deserializes unpacked Cap'n Proto
-// messages on a byte stream. StreamTransport adds no buffering beyond
-// what its underlying stream has.
-type StreamTransport struct {
-	cr ctxReader
-	wc io.WriteCloser
-
-	partialWriteTimeout time.Duration
-	closed              bool
+// A Codec is responsible for encoding and decoding messages from
+// a single logical stream.
+type Codec interface {
+	Encode(context.Context, *capnp.Message) error
+	Decode(context.Context) (*capnp.Message, error)
+	SetPartialWriteTimeout(time.Duration)
+	Close() error
+}
 
-	mu  sync.RWMutex
-	err error
+// A transport serializes and deserializes Cap'n Proto messages using a
+// Codec.  It adds no buffering beyond what is provided by the underlying
+// byte transfer mechanism.
+type transport struct {
+	c      Codec
+	closed bool
+	err    errorValue
 }
 
+// NewTransport creates a new transport that uses the supplied codec
+// to read and write messages across the wire.
+func NewTransport(c Codec) Transport { return &transport{c: c} }
+
 // NewStreamTransport creates a new transport that reads and writes to rwc.
 // Closing the transport will close rwc.
 //
@@ -74,24 +82,24 @@ type StreamTransport struct {
 // have these methods, then rwc.Close must be safe to call concurrently
 // with rwc.Read. Notably, this is not true of *os.File before Go 1.9
 // (see https://golang.org/issue/7970).
-func NewStreamTransport(rwc io.ReadWriteCloser) *StreamTransport {
-	return &StreamTransport{
-		cr: ctxReader{r: rwc},
-		wc: rwc,
+func NewStreamTransport(rwc io.ReadWriteCloser) Transport {
+	return NewTransport(newStreamCodec(rwc, basicEncoding{}))
+}
 
-		partialWriteTimeout: 30 * time.Second,
-	}
+// NewPackedStreamTransport creates a new transport that uses a packed
+// encoding.
+//
+// See: NewStreamTransport.
+func NewPackedStreamTransport(rwc io.ReadWriteCloser) Transport {
+	return NewTransport(newStreamCodec(rwc, packedEncoding{}))
 }
 
 // NewMessage allocates a new message to be sent.
 //
 // It is safe to call NewMessage concurrently with RecvMessage.
-func (s *StreamTransport) NewMessage(ctx context.Context) (_ rpccp.Message, send func() error, release capnp.ReleaseFunc, _ error) {
+func (s *transport) NewMessage(ctx context.Context) (_ rpccp.Message, send func() error, release capnp.ReleaseFunc, _ error) {
 	// Check if stream is broken
-	s.mu.RLock()
-	err := s.err
-	s.mu.RUnlock()
-	if err != nil {
+	if err := s.err.Load(); err != nil {
 		return rpccp.Message{}, nil, nil, err
 	}
 
@@ -104,37 +112,31 @@ func (s *StreamTransport) NewMessage(ctx context.Context) (_ rpccp.Message, send
 	if err != nil {
 		return rpccp.Message{}, nil, nil, errors.New(errors.Failed, "rpc stream transport", "new message: "+err.Error())
 	}
+
 	send = func() error {
-		select {
-		case <-ctx.Done():
+		// context expired?
+		if err := ctx.Err(); err != nil {
 			return errors.New(errors.Failed, "rpc stream transport", "send: "+ctx.Err().Error())
-		default:
 		}
-		s.mu.RLock()
-		err := s.err
-		s.mu.RUnlock()
-		if err != nil {
+
+		// stream error?
+		if err := s.err.Load(); err != nil {
 			return err
 		}
-		b, err := msg.Marshal()
-		if err != nil {
-			return errors.New(errors.Failed, "rpc stream transport", "send: "+err.Error())
-		}
-		n, err := writeCtx(ctx, s.wc, b, s.partialWriteTimeout)
-		if n > 0 && n < len(b) {
-			s.mu.Lock()
-			s.err = errors.New(errors.Disconnected, "rpc stream transport", "broken due to partial write")
-			s.mu.Unlock()
-		}
-		if err != nil {
-			return errors.New(errors.Failed, "rpc stream transport", "send: "+err.Error())
+
+		// ok, go!
+		if err = s.c.Encode(ctx, msg); err != nil {
+			if _, ok := err.(partialWriteError); ok {
+				s.err.Set(errors.New(errors.Disconnected, "rpc stream transport", "broken due to partial write"))
+			}
+
+			err = errors.New(errors.Failed, "rpc stream transport", "send: "+err.Error())
 		}
-		return nil
-	}
-	release = func() {
-		msg.Reset(nil)
+
+		return err
 	}
-	return rmsg, send, release, nil
+
+	return rmsg, send, func() { msg.Reset(nil) }, nil
 }
 
 // SetPartialWriteTimeout sets the timeout for completing the
@@ -145,22 +147,19 @@ func (s *StreamTransport) NewMessage(ctx context.Context) (_ rpccp.Message, send
 // Setting a shorter timeout may free up resources faster in the case of
 // an unresponsive remote peer, but may also make the transport respond
 // too aggressively to bursts of latency.
-func (s *StreamTransport) SetPartialWriteTimeout(d time.Duration) {
-	s.partialWriteTimeout = d
+func (s *transport) SetPartialWriteTimeout(d time.Duration) {
+	s.c.SetPartialWriteTimeout(d)
 }
 
 // RecvMessage reads the next message from the underlying reader.
 //
 // It is safe to call RecvMessage concurrently with NewMessage.
-func (s *StreamTransport) RecvMessage(ctx context.Context) (rpccp.Message, capnp.ReleaseFunc, error) {
-	s.mu.RLock()
-	err := s.err
-	s.mu.RUnlock()
-	if err != nil {
+func (s *transport) RecvMessage(ctx context.Context) (rpccp.Message, capnp.ReleaseFunc, error) {
+	if err := s.err.Load(); err != nil {
 		return rpccp.Message{}, nil, err
 	}
-	s.cr.ctx = ctx
-	msg, err := capnp.NewDecoder(&s.cr).Decode()
+
+	msg, err := s.c.Decode(ctx)
 	if err != nil {
 		return rpccp.Message{}, nil, errors.New(errors.Failed, "rpc stream transport", "receive: "+err.Error())
 	}
@@ -173,22 +172,79 @@ func (s *StreamTransport) RecvMessage(ctx context.Context) (rpccp.Message, capnp
 
 // Close closes the underlying ReadWriteCloser. It is not safe to call
 // Close concurrently with any other operations on the transport.
-func (s *StreamTransport) Close() error {
+func (s *transport) Close() error {
 	if s.closed {
 		return errors.New(errors.Disconnected, "rpc stream transport", "already closed")
 	}
 	s.closed = true
-	err := s.wc.Close()
-	s.cr.wait()
+	err := s.c.Close()
 	if err != nil {
 		return errors.New(errors.Failed, "rpc stream transport", "close: "+err.Error())
 	}
 	return nil
 }
 
+type streamCodec struct {
+	r   *ctxReader
+	dec *capnp.Decoder
+
+	wc  *ctxWriteCloser
+	enc *capnp.Encoder
+}
+
+func newStreamCodec(rwc io.ReadWriteCloser, f streamEncoding) *streamCodec {
+	c := &streamCodec{
+		r: &ctxReader{Reader: rwc},
+		wc: &ctxWriteCloser{
+			WriteCloser:         rwc,
+			partialWriteTimeout: 30 * time.Second,
+		},
+	}
+
+	c.dec = f.NewDecoder(c.r)
+	c.enc = f.NewEncoder(c.wc)
+
+	return c
+}
+
+func (c *streamCodec) Encode(ctx context.Context, m *capnp.Message) error {
+	c.wc.setWriteContext(ctx)
+	return c.enc.Encode(m)
+}
+
+func (c *streamCodec) Decode(ctx context.Context) (*capnp.Message, error) {
+	c.r.setReadContext(ctx)
+	return c.dec.Decode()
+}
+
+func (c *streamCodec) SetPartialWriteTimeout(d time.Duration) {
+	c.wc.partialWriteTimeout = d
+}
+
+func (c streamCodec) Close() error {
+	defer c.r.wait()
+
+	return c.wc.Close()
+}
+
+type streamEncoding interface {
+	NewEncoder(io.Writer) *capnp.Encoder
+	NewDecoder(io.Reader) *capnp.Decoder
+}
+
+type basicEncoding struct{}
+
+func (basicEncoding) NewEncoder(w io.Writer) *capnp.Encoder { return capnp.NewEncoder(w) }
+func (basicEncoding) NewDecoder(r io.Reader) *capnp.Decoder { return capnp.NewDecoder(r) }
+
+type packedEncoding struct{}
+
+func (packedEncoding) NewEncoder(w io.Writer) *capnp.Encoder { return capnp.NewPackedEncoder(w) }
+func (packedEncoding) NewDecoder(r io.Reader) *capnp.Decoder { return capnp.NewPackedDecoder(r) }
+
 // ctxReader adds timeouts and cancellation to a reader.
 type ctxReader struct {
-	r   io.Reader
+	io.Reader
 	ctx context.Context // set to change Context
 
 	// internal state
@@ -203,6 +259,8 @@ type readResult struct {
 	n   int
 	err error
 }
 
+func (cr *ctxReader) setReadContext(ctx context.Context) { cr.ctx = ctx }
+
 // Read reads into p. It makes a best effort to respect the Done signal
 // in cr.ctx.
@@ -240,7 +298,7 @@ func (cr *ctxReader) Read(p []byte) (int, error) {
 	default:
 	}
 	// Query timeout support.
-	rd, ok := cr.r.(interface {
+	rd, ok := cr.Reader.(interface {
 		SetReadDeadline(time.Time) error
 	})
 	if !ok {
@@ -265,7 +323,7 @@ func (cr *ctxReader) Read(p []byte) (int, error) {
 		case <-readDone:
 		}
 	}()
-	n, err := cr.r.Read(p)
+	n, err := cr.Reader.Read(p)
 	close(readDone)
 	<-listenDone
 	return n, err
@@ -281,7 +339,7 @@ func (cr *ctxReader) leakyRead(p []byte) (int, error) {
 		max = len(cr.buf)
 	}
 	go func() {
-		n, err := cr.r.Read(cr.buf[:max])
+		n, err := cr.Reader.Read(cr.buf[:max])
 		cr.result <- readResult{n, err}
 	}()
 	select {
@@ -294,7 +352,7 @@ func (cr *ctxReader) leakyRead(p []byte) (int, error) {
 	}
 }
 
-// wait waits until any goroutine started by leakyRead finishes.
+// wait until any goroutine started by leakyRead finishes.
 func (cr *ctxReader) wait() {
 	if cr.result == nil {
 		return
@@ -305,29 +363,46 @@ func (cr *ctxReader) wait() {
 	cr.err = r.err
 }
 
-// writeCtx writes bytes to a writer while making a best effort to
-// respect the Done signal of the Context. However, if allowPartial is
-// false, then once any bytes have been written to w, writeCtx will
-// ignore the Done signal to avoid partial writes.
-func writeCtx(ctx context.Context, w io.Writer, b []byte, partialTimeout time.Duration) (int, error) {
+type ctxWriteCloser struct {
+	io.WriteCloser
+	ctx                 context.Context
+	partialWriteTimeout time.Duration
+}
+
+// Write writes bytes to the underlying writer, making a best effort
+// to respect the Done signal of the Context.  Once any bytes have been
+// written, however, Write blocks for up to partialWriteTimeout to
+// finish the write, since a partial write would corrupt the stream.
+func (wc *ctxWriteCloser) Write(b []byte) (int, error) {
+	n, err := wc.write(b)
+	if n > 0 && n < len(b) {
+		err = partialWriteError{err}
+	}
+
+	return n, err
+}
+
+func (wc *ctxWriteCloser) setWriteContext(ctx context.Context) { wc.ctx = ctx }
+
+func (wc *ctxWriteCloser) write(b []byte) (int, error) {
 	select {
-	case <-ctx.Done():
+	case <-wc.ctx.Done():
 		// Early cancel.
-		return 0, ctx.Err()
+		return 0, wc.ctx.Err()
 	default:
 	}
 	// Check for timeout support.
-	wd, ok := w.(interface {
+	wd, ok := wc.WriteCloser.(interface {
 		SetWriteDeadline(time.Time) error
 	})
 	if !ok {
-		return w.Write(b)
+		return wc.WriteCloser.Write(b)
 	}
 	if err := wd.SetWriteDeadline(time.Now()); err != nil {
-		return w.Write(b)
+		return wc.WriteCloser.Write(b)
 	}
 	// Start separate goroutine to wait on Context.Done.
-	if d, ok := ctx.Deadline(); ok {
+	if d, ok := wc.ctx.Deadline(); ok {
 		wd.SetWriteDeadline(d)
 	} else {
 		wd.SetWriteDeadline(time.Time{})
@@ -337,21 +412,21 @@ func writeCtx(ctx context.Context, w io.Writer, b []byte, partialTimeout time.Du
 	go func() {
 		defer close(listenDone)
 		select {
-		case <-ctx.Done():
+		case <-wc.ctx.Done():
 			wd.SetWriteDeadline(time.Now()) // interrupt write
 		case <-writeDone:
 		}
 	}()
-	n, err := w.Write(b)
+	n, err := wc.WriteCloser.Write(b)
 	close(writeDone)
 	<-listenDone
-	if partialTimeout <= 0 || n == 0 || !isTimeout(err) {
+	if wc.partialWriteTimeout <= 0 || n == 0 || !isTimeout(err) {
 		return n, err
 	}
 	// Data has been written. Block with extra partial timeout, since
 	// partial writes are guaranteed protocol violations.
-	wd.SetWriteDeadline(time.Now().Add(partialTimeout))
-	nn, err := w.Write(b[n:])
+	wd.SetWriteDeadline(time.Now().Add(wc.partialWriteTimeout))
+	nn, err := wc.WriteCloser.Write(b[n:])
 	return n + nn, err
 }
 
@@ -361,3 +436,19 @@ func isTimeout(e error) bool {
 	})
 	return ok && te.Timeout()
 }
+
+type partialWriteError struct{ error }
+
+type errorValue atomic.Value
+
+func (ev *errorValue) Load() error {
+	if err := (*atomic.Value)(ev).Load(); err != nil {
+		return err.(error)
+	}
+
+	return nil
+}
+
+func (ev *errorValue) Set(err error) {
+	(*atomic.Value)(ev).Store(err)
+}
diff --git a/rpc/transport_test.go b/rpc/transport_test.go
index 296b1d0b..805bed70 100644
--- a/rpc/transport_test.go
+++ b/rpc/transport_test.go
@@ -3,11 +3,12 @@ package rpc_test
 import (
 	"context"
 	"errors"
+	"io"
 	"net"
 	"testing"
 	"time"
 
-	"zombiezen.com/go/capnproto2"
+	capnp "zombiezen.com/go/capnproto2"
 	"zombiezen.com/go/capnproto2/rpc"
 	rpccp "zombiezen.com/go/capnproto2/std/capnp/rpc"
 )
@@ -178,10 +179,25 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 rpc.Transport, err erro
 }
 
 func TestTCPStreamTransport(t *testing.T) {
+	t.Run("Unpacked", func(t *testing.T) {
+		t.Parallel()
+
+		testTCPStreamTransport(t, rpc.NewStreamTransport)
+	})
+
+	t.Run("Packed", func(t *testing.T) {
+		t.Parallel()
+
+		testTCPStreamTransport(t, rpc.NewPackedStreamTransport)
+	})
+}
+
+func testTCPStreamTransport(t *testing.T, newTransport func(io.ReadWriteCloser) rpc.Transport) {
 	type listenCall struct {
 		c   *net.TCPConn
 		err error
 	}
+
 	makePipe := func() (t1, t2 rpc.Transport, err error) {
 		host, err := net.LookupIP("localhost")
 		if err != nil {
@@ -214,11 +230,13 @@ func TestTCPStreamTransport(t *testing.T) {
 			l.Close()
 			return nil, nil, err
 		}
-		return rpc.NewStreamTransport(lc.c), rpc.NewStreamTransport(c2), nil
+		return newTransport(lc.c), newTransport(c2), nil
 	}
+
 	t.Run("ServerToClient", func(t *testing.T) {
 		testTransport(t, makePipe)
 	})
+
 	t.Run("ClientToServer", func(t *testing.T) {
 		testTransport(t, func() (t1, t2 rpc.Transport, err error) {
 			t2, t1, err = makePipe()
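
For reviewers who want to try the new constructors, here is a minimal usage sketch. It is illustrative only and not part of the patch: net.Pipe stands in for a real network connection, and the transports would normally be handed to the rpc connection layer rather than driven by hand.

package main

import (
	"net"

	"zombiezen.com/go/capnproto2/rpc"
)

func main() {
	// net.Pipe provides an in-memory io.ReadWriteCloser pair standing
	// in for a real network connection.
	c1, c2 := net.Pipe()

	// Both peers must agree on the framing: packed on both ends, or
	// unpacked (rpc.NewStreamTransport) on both ends.
	t1 := rpc.NewPackedStreamTransport(c1)
	t2 := rpc.NewPackedStreamTransport(c2)
	defer t1.Close()
	defer t2.Close()

	// t1 and t2 would normally be passed to the connection layer
	// (e.g. rpc.NewConn); they are only constructed here.
}

Mixing a packed transport with an unpacked one on the other end will fail to decode, so the choice of constructor is effectively part of the wire protocol.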
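Because NewTransport accepts any Codec, alternative framings can be plugged in without further changes to the rpc package. The sketch below is a hypothetical Codec, not part of this patch (the name simpleCodec and its behaviour are assumptions for illustration): it reuses the plain capnp encoder and decoder and ignores the Context and the partial-write timeout, both of which the in-tree streamCodec does handle.

package rpcexample

import (
	"context"
	"io"
	"time"

	capnp "zombiezen.com/go/capnproto2"
)

// simpleCodec is a hypothetical Codec implementation. It frames
// messages with the standard (unpacked) Cap'n Proto stream encoding
// and ignores the Context and the partial-write timeout.
type simpleCodec struct {
	enc *capnp.Encoder
	dec *capnp.Decoder
	rwc io.ReadWriteCloser
}

func newSimpleCodec(rwc io.ReadWriteCloser) *simpleCodec {
	return &simpleCodec{
		enc: capnp.NewEncoder(rwc),
		dec: capnp.NewDecoder(rwc),
		rwc: rwc,
	}
}

func (c *simpleCodec) Encode(_ context.Context, m *capnp.Message) error {
	return c.enc.Encode(m)
}

func (c *simpleCodec) Decode(context.Context) (*capnp.Message, error) {
	return c.dec.Decode()
}

func (c *simpleCodec) SetPartialWriteTimeout(time.Duration) {}

func (c *simpleCodec) Close() error { return c.rwc.Close() }

rpc.NewTransport(newSimpleCodec(conn)) would then yield a Transport; a production Codec should also honor cancellation and the partial-write timeout the way streamCodec does.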