Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling in RPC package #187

Merged
merged 4 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ go 1.16

require (
github.com/kylelemons/godebug v1.1.0
github.com/stretchr/testify v1.7.0 // indirect
github.com/stretchr/testify v1.7.0
lthibault marked this conversation as resolved.
Show resolved Hide resolved
github.com/tinylib/msgp v1.1.5
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
67 changes: 43 additions & 24 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,76 @@
// Package errors provides errors with codes and prefixes.
package errors

import "strconv"
import (
"errors"
"fmt"
"strconv"
)

// capnpError holds a Cap'n Proto exception.
type capnpError struct {
typ Type
prefix string
msg string
// Error holds a Cap'n Proto exception.
type Error struct {
Type Type
Prefix string
Cause error
}

// New creates a new error that formats as "<prefix>: <msg>".
// The type can be recovered using the TypeOf() function.
func New(typ Type, prefix, msg string) error {
return &capnpError{typ, prefix, msg}
func New(typ Type, prefix, msg string) Error {
return Error{typ, prefix, errors.New(msg)}
}

func (e *capnpError) Error() string {
if e.prefix == "" {
return e.msg
func (e Error) Error() string {
if e.Prefix == "" {
return e.Cause.Error()
}
return e.prefix + ": " + e.msg

return fmt.Sprintf("%s: %v", e.Prefix, e.Cause)
}

func (e *capnpError) GoString() string {
return "errors.New(" + e.typ.GoString() + ", " + strconv.Quote(e.prefix) + ", " + strconv.Quote(e.msg) + ")"
func (e Error) Unwrap() error { return e.Cause }

func (e Error) GoString() string {
return fmt.Sprintf("errors.Error{Type: %s, Prefix: %q, Cause: fmt.Errorf(%q)}",
e.Type.GoString(),
e.Prefix,
e.Cause)
lthibault marked this conversation as resolved.
Show resolved Hide resolved
}

// Annotate is creates a new error that formats as "<prefix>: <msg>: <e>".
// If e.Prefix == prefix, the prefix will not be duplicated.
// The returned Error.Type == e.Type.
func (e Error) Annotate(prefix, msg string) Error {
if prefix != e.Prefix {
return Error{e.Type, prefix, fmt.Errorf("%s: %w", msg, e)}
}

return Error{e.Type, prefix, fmt.Errorf("%s: %w", msg, e.Cause)}
}

// Annotate creates a new error that formats as "<prefix>: <msg>: <err>".
// If err has the same prefix, then the prefix won't be duplicated.
// The returned error's type will match err's type.
func Annotate(prefix, msg string, err error) error {
if err == nil {
panic("Annotate on nil error")
}
ce, ok := err.(*capnpError)
if !ok {
return &capnpError{Failed, prefix, msg + ": " + err.Error()}
panic("Annotate on nil error") // TODO: return nil?
}
if prefix != ce.prefix {
return &capnpError{ce.typ, prefix, msg + ": " + err.Error()}

if ce, ok := err.(Error); ok {
return ce.Annotate(prefix, msg)
}
return &capnpError{ce.typ, prefix, msg + ": " + ce.msg}

return Error{Failed, prefix, fmt.Errorf("%s: %w", msg, err)}
}

// TypeOf returns err's type if err was created by this package or
// Failed if it was not.
func TypeOf(err error) Type {
ce, ok := err.(*capnpError)
ce, ok := err.(Error)
if !ok {
return Failed
}
return ce.typ
return ce.Type
}

// Type indicates the type of error, mirroring those in rpc.capnp.
Expand Down
46 changes: 31 additions & 15 deletions internal/errors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,30 @@ package errors
import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
)

func TestUnwrap(t *testing.T) {
t.Parallel()

var (
errGeneric = errors.New("something went wrong")
err = Annotate("annotated", "test", errGeneric)
exc Error
)

assert.EqualError(t, errors.Unwrap(err), "test: something went wrong")
assert.ErrorIs(t, err, errGeneric)

assert.ErrorAs(t, err, &exc)
assert.Equal(t, "annotated", exc.Prefix)
assert.EqualError(t, exc.Cause, "test: something went wrong")
}

func TestErrorString(t *testing.T) {
t.Parallel()

tests := []struct {
typ Type
prefix string
Expand All @@ -17,14 +38,14 @@ func TestErrorString(t *testing.T) {
{Failed, "capnp", "goofed", "capnp: goofed"},
}
for _, test := range tests {
got := New(test.typ, test.prefix, test.msg).Error()
if got != test.want {
t.Errorf("New(%#v, %q, %q).Error() = %q; want %q", test.typ, test.prefix, test.msg, got, test.want)
}
err := New(test.typ, test.prefix, test.msg)
assert.EqualError(t, err, test.want)
}
}

func TestTypeOf(t *testing.T) {
t.Parallel()

tests := []struct {
err error
want Type
Expand All @@ -37,13 +58,13 @@ func TestTypeOf(t *testing.T) {
{New(Unimplemented, "capnp", "unimplemented error"), Unimplemented},
}
for _, test := range tests {
if got := TypeOf(test.err); got != test.want {
t.Errorf("TypeOf(%#v) = %#v; want %#v", test.err, got, test.want)
}
assert.Equal(t, test.want, TypeOf(test.err))
}
}

func TestAnnotate(t *testing.T) {
t.Parallel()

tests := []struct {
prefix string
msg string
Expand Down Expand Up @@ -112,13 +133,8 @@ func TestAnnotate(t *testing.T) {
},
}
for _, test := range tests {
got := Annotate(test.prefix, test.msg, test.err)
if got.Error() != test.want {
t.Errorf("Annotate(%q, %q, %#v).Error() = %q; %q", test.prefix, test.msg, test.err, got.Error(), test.want)
}
gotType := TypeOf(got)
if gotType != test.wantType {
t.Errorf("TypeOf(Annotate(%q, %q, %#v)) = %#v; %#v", test.prefix, test.msg, test.err, gotType, test.wantType)
}
err := Annotate(test.prefix, test.msg, test.err)
assert.EqualError(t, err, test.want)
assert.Equal(t, test.wantType, TypeOf(err))
}
}
24 changes: 12 additions & 12 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {
func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func() error, capnp.ReleaseFunc, error) {
msg, send, release, err := c.transport.NewMessage(ctx)
if err != nil {
return rpccp.Return{}, nil, nil, errorf("create return: %v", err)
return rpccp.Return{}, nil, nil, failedf("create return: %w", err)
}
ret, err := msg.NewReturn()
if err != nil {
release()
return rpccp.Return{}, nil, nil, errorf("create return: %v", err)
return rpccp.Return{}, nil, nil, failedf("create return: %w", err)
}
return ret, send, release, nil
}
Expand All @@ -121,14 +121,14 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
var err error
ans.results, err = ans.ret.NewResults()
if err != nil {
return capnp.Struct{}, errorf("alloc results: %v", err)
return capnp.Struct{}, failedf("alloc results: %w", err)
}
s, err := capnp.NewStruct(ans.results.Segment(), sz)
if err != nil {
return capnp.Struct{}, errorf("alloc results: %v", err)
return capnp.Struct{}, failedf("alloc results: %w", err)
}
if err := ans.results.SetContent(s.ToPtr()); err != nil {
return capnp.Struct{}, errorf("alloc results: %v", err)
return capnp.Struct{}, failedf("alloc results: %w", err)
}
return s, nil
}
Expand All @@ -145,11 +145,11 @@ func (ans *answer) setBootstrap(c *capnp.Client) error {
var err error
ans.results, err = ans.ret.NewResults()
if err != nil {
return errorf("alloc bootstrap results: %v", err)
return failedf("alloc bootstrap results: %w", err)
}
iface := capnp.NewInterface(ans.results.Segment(), 0)
if err := ans.results.SetContent(iface.ToPtr()); err != nil {
return errorf("alloc bootstrap results: %v", err)
return failedf("alloc bootstrap results: %w", err)
}
return nil
}
Expand Down Expand Up @@ -210,7 +210,7 @@ func (ans *answer) sendReturn(cstates []capnp.ClientState) (releaseList, error)
var err error
ans.exportRefs, err = ans.c.fillPayloadCapTable(ans.results, ans.resultCapTable, cstates)
if err != nil {
ans.c.report(annotate(err).errorf("send return"))
ans.c.report(annotate(err, "send return"))
// Continue. Don't fail to send return if cap table isn't fully filled.
}

Expand All @@ -220,7 +220,7 @@ func (ans *answer) sendReturn(cstates []capnp.ClientState) (releaseList, error)
fin := ans.flags&finishReceived != 0
ans.c.mu.Unlock()
if err := ans.sendMsg(); err != nil {
ans.c.reportf("send return: %v", err)
ans.c.report(failedf("send return: %w", err))
}
if fin {
ans.releaseMsg()
Expand Down Expand Up @@ -258,13 +258,13 @@ func (ans *answer) sendException(e error) releaseList {
fin := ans.flags&finishReceived != 0
ans.c.mu.Unlock()
if exc, err := ans.ret.NewException(); err != nil {
ans.c.reportf("send exception: %v", err)
ans.c.report(failedf("send exception: %w", err))
} else {
exc.SetType(rpccp.Exception_Type(errors.TypeOf(e)))
if err := exc.SetReason(e.Error()); err != nil {
ans.c.reportf("send exception: %v", err)
ans.c.report(failedf("send exception: %w", err))
} else if err := ans.sendMsg(); err != nil {
ans.c.reportf("send return: %v", err)
ans.c.report(failedf("send return: %w", err))
}
}
if fin {
Expand Down
53 changes: 53 additions & 0 deletions rpc/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package rpc

import (
goerr "errors"
"fmt"

"capnproto.org/go/capnp/v3/internal/errors"
)

const prefix = "rpc"

var (
// Base errors
ErrConnClosed = goerr.New("connection closed")
ErrNotACapability = goerr.New("not a capability")
ErrCapTablePopulated = goerr.New("capability table already populated")

// RPC exceptions
ExcClosed = disconnected(ErrConnClosed)
ExcAlreadyClosed = failed(goerr.New("close on closed connection"))
)

func failedf(format string, args ...interface{}) errors.Error {
return failed(fmt.Errorf(format, args...))
}

func failed(err error) errors.Error {
return exception(errors.Failed, err)
}

func disconnectedf(format string, args ...interface{}) errors.Error {
return disconnected(fmt.Errorf(format, args...))
}

func disconnected(err error) errors.Error {
return exception(errors.Disconnected, err)
}

func unimplementedf(format string, args ...interface{}) errors.Error {
return unimplemented(fmt.Errorf(format, args...))
}

func unimplemented(err error) errors.Error {
return exception(errors.Unimplemented, err)
}

func annotate(err error, msg string) error {
return errors.Annotate(prefix, msg, err)
}

func exception(t errors.Type, err error) errors.Error {
return errors.Error{Type: t, Prefix: prefix, Cause: err}
}
14 changes: 7 additions & 7 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *Conn) findExport(id exportID) *expent {
func (c *Conn) releaseExport(id exportID, count uint32) (*capnp.Client, error) {
ent := c.findExport(id)
if ent == nil {
return nil, errorf("unknown export ID %d", id)
return nil, failedf("unknown export ID %d", id)
}
switch {
case count == ent.wireRefs:
Expand All @@ -43,7 +43,7 @@ func (c *Conn) releaseExport(id exportID, count uint32) (*capnp.Client, error) {
c.exportID.remove(uint32(id))
return client, nil
case count > ent.wireRefs:
return nil, errorf("export ID %d released too many references", id)
return nil, failedf("export ID %d released too many references", id)
default:
ent.wireRefs -= count
return nil, nil
Expand Down Expand Up @@ -133,7 +133,7 @@ func (c *Conn) fillPayloadCapTable(payload rpccp.Payload, clients []*capnp.Clien
}
list, err := payload.NewCapTable(int32(len(clients)))
if err != nil {
return nil, errorf("payload capability table: %v", err)
return nil, failedf("payload capability table: %w", err)
}
var refs map[exportID]uint32
for i, client := range clients {
Expand Down Expand Up @@ -247,19 +247,19 @@ type senderLoopback struct {
func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error {
d, err := msg.NewDisembargo()
if err != nil {
return errorf("build disembargo: %v", err)
return failedf("build disembargo: %w", err)
}
tgt, err := d.NewTarget()
if err != nil {
return errorf("build disembargo: %v", err)
return failedf("build disembargo: %w", err)
}
pa, err := tgt.NewPromisedAnswer()
if err != nil {
return errorf("build disembargo: %v", err)
return failedf("build disembargo: %w", err)
}
oplist, err := pa.NewTransform(int32(len(sl.transform)))
if err != nil {
return errorf("build disembargo: %v", err)
return failedf("build disembargo: %w", err)
}

d.Context().SetSenderLoopback(uint32(sl.id))
Expand Down
Loading