Skip to content

Commit

Permalink
Add Marshaler and Unmarshaler interfaces
Browse files Browse the repository at this point in the history
Add Marshaler and Unmarshaler interfaces, which allow for custom encoding/decoding
working similarly to encoding/json Marshaler and Unmarshaler:

```
type Marshaler interface {
	MarshalXDR(e *Encoder) (int, error)
}

type Unmarshaler interface {
	UnmarshalXDR(e *Decoder) (int, error)
}
```

The Decoder and Encoder automatically call the relevant methods if found.

Both implementation by value and by pointer are supported.

A new error (`ErrCustomUnaddressableByPointer`) will be returned if
a by-pointer implementation is found but the value passed is unaddressable
(not allowing to call the method by pointer).
  • Loading branch information
2opremio committed Nov 10, 2021
1 parent 8017fc4 commit 535624d
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 13 deletions.
76 changes: 74 additions & 2 deletions xdr3/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ import (

const maxInt32 = int(^uint32(0) >> 1)

var errMaxSlice = "data exceeds max slice limit"
var errIODecode = "%s while decoding %d bytes"
var (
errMaxSlice = "data exceeds max slice limit"
errIODecode = "%s while decoding %d bytes"
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
)

/*
Unmarshal parses XDR-encoded data into the value pointed to by v reading from
Expand Down Expand Up @@ -74,6 +77,9 @@ UnmarshalError is returned with a human readable description as well as
an ErrorCode value for further inspection from sophisticated callers. Some
potential issues are unsupported Go types, attempting to decode a value which is
too large to fit into a specified Go type, and exceeding max slice limitations.
If any of the types encountered comply with the Unmarshaler interface, the UnmarshalXDR
method will be used instead of reflection (similarly to UnmarshalJSON in the encoding/json package)
*/
func Unmarshal(r io.Reader, v interface{}) (int, error) {
d := Decoder{r: r}
Expand Down Expand Up @@ -770,6 +776,37 @@ func (d *Decoder) decodeInterface(v reflect.Value) (int, error) {
return d.decode(ve, 0)
}

func getUnmarshaler(ve reflect.Value) Unmarshaler {
// Check the type first to avoid an unnecessary allocation when casting to interface
if ve.Type().Implements(unmarshalerType) && ve.CanInterface() {
return ve.Interface().(Unmarshaler)
}
return nil
}

func getUnmarshalerFromNonPtr(ve reflect.Value) (Unmarshaler, error) {
if u := getUnmarshaler(ve); u != nil {
return u, nil
}

// Implements Unmarshaller by address?
// Check the type first, to avoid an unnecessary allocation when casting to interface
// TODO: this check is redundant if decoding got here through a pointer type
if reflect.PtrTo(ve.Type()).Implements(unmarshalerType) {
if !ve.CanAddr() {
msg := fmt.Sprintf("unaddressable value of %s implementing MarshalXDR() by pointer", ve.Type().Name())
err := marshalError("encode", ErrCustomUnaddressableByPointer, msg, nil, nil)
return nil, err
}
ptr := ve.Addr()
if ptr.CanInterface() {
return ptr.Interface().(Unmarshaler), nil
}
}

return nil, nil
}

// decode is the main workhorse for unmarshalling via reflection. It uses
// the passed reflection value to choose the XDR primitives to decode from
// the encapsulated reader. It is a recursive function,
Expand All @@ -782,6 +819,19 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int) (int, error) {
return 0, err
}

// Check for Unmarshaler interface
if ve.Kind() != reflect.Ptr {
// For pointer types this will be checked again later on.
// (this is because pointers require decoding whether they are nil)
u, err := getUnmarshalerFromNonPtr(ve)
if err != nil {
return 0, err
}
if u != nil {
return u.UnmarshalXDR(d)
}
}

// Handle time.Time values by decoding them as an RFC3339 formatted
// string with nanosecond precision. Check the type string rather
// than doing a full blown conversion to interface and type assertion
Expand Down Expand Up @@ -1013,6 +1063,13 @@ func (d *Decoder) decodePtr(v reflect.Value) (int, error) {
return n, err
}

// We intentionally check for Marshaler before deferencing the pointer.
// This is because, for unaddressable values, you cannot recover the
// pointer later on, which would make impossible to invoke by-address marshalers.
if u := getUnmarshaler(v); u != nil {
n2, err := u.UnmarshalXDR(d)
return n + n2, err
}
n2, err := d.decode(v.Elem(), 0)
return n + n2, err
}
Expand All @@ -1038,6 +1095,10 @@ func (d *Decoder) Decode(v interface{}) (int, error) {
nil)
}

if unmarshaler, ok := v.(Unmarshaler); ok {
return unmarshaler.UnmarshalXDR(d)
}

vv := reflect.ValueOf(v)
if vv.Kind() != reflect.Ptr {
msg := fmt.Sprintf("can't unmarshal to non-pointer '%v' - use "+
Expand All @@ -1055,9 +1116,20 @@ func (d *Decoder) Decode(v interface{}) (int, error) {
return d.decode(vv.Elem(), 0)
}

// Read reads from the internal reader. This method can be useful for implementing UnmarshalXDR
func (d *Decoder) Read(p []byte) (int, error) {
return d.r.Read(p)
}

// NewDecoder returns a Decoder that can be used to manually decode XDR data
// from a provided reader. Typically, Unmarshal should be used instead of
// manually creating a Decoder.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}

// Unmarshaler is the interface implemented by types that can unmarshal themselves from valid XDR.
// The supplied decoder should be used to perform the decoding.
type Unmarshaler interface {
UnmarshalXDR(e *Decoder) (int, error)
}
52 changes: 52 additions & 0 deletions xdr3/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,51 @@ type subTest struct {
B uint8
}

type customDecodedType struct {
A uint8
}

const magicCustomDecoderCount = 123456

func (c *customDecodedType) UnmarshalXDR(e *Decoder) (int, error) {
i, _, err := e.DecodeInt()
c.A = uint8(i)
// return magic number to indicate that the custom unmarshalling code was called
return magicCustomDecoderCount, err
}

type customDecodedTypeByValue struct {
A uint8
}

// make sure that (odd) method implementations by value also work
func (c customDecodedTypeByValue) UnmarshalXDR(e *Decoder) (int, error) {
// there is no point in consuming the input since it can be saved (the method invocation is by value)
i, _, err := e.DecodeInt()
// return magic number to indicate that the custom unmarshalling code was called
return magicCustomDecoderCount + int(i), err
}

type customDecodedSubType struct {
A uint8
B customDecodedType
}

type customDecodedSubTypeByValue struct {
A uint8
B customDecodedTypeByValue
}

type customDecodedSubTypeWithPointer struct {
A uint8
B *customDecodedType
}

type customDecodedSubTypeByValueWithPointer struct {
A uint8
B *customDecodedTypeByValue
}

// allTypesTest is used to allow testing of the Unmarshal function into struct
// fields of all supported types.
type allTypesTest struct {
Expand Down Expand Up @@ -388,6 +433,13 @@ func TestUnmarshal(t *testing.T) {
{[]byte{0x00, 0x00, 0x00}, opaqueStruct{}, 3, &UnmarshalError{ErrorCode: ErrIO}},
{[]byte{0x00, 0x00, 0x00, 0x00, 0x00}, opaqueStruct{}, 5, &UnmarshalError{ErrorCode: ErrIO}},

// Unmarshaler interface
{[]byte{0x00, 0x00, 0x00, 0xFF}, customDecodedType{255}, magicCustomDecoderCount, nil},
{[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x10}, customDecodedSubType{255, customDecodedType{16}}, magicCustomDecoderCount + 4, nil},
{[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeByValue{255, customDecodedTypeByValue{0}}, (magicCustomDecoderCount + 16) + 4, nil},
{[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeWithPointer{255, &customDecodedType{16}}, magicCustomDecoderCount + 4 + 4, nil},
{[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeByValueWithPointer{255, &customDecodedTypeByValue{0}}, (magicCustomDecoderCount + 16) + 4 + 4, nil},

// Expected errors
{nil, nilInterface, 0, &UnmarshalError{ErrorCode: ErrNilInterface}},
{nil, &nilInterface, 0, &UnmarshalError{ErrorCode: ErrIO}},
Expand Down
77 changes: 76 additions & 1 deletion xdr3/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ import (
"time"
)

var errIOEncode = "%s while encoding %d bytes"
var (
errIOEncode = "%s while encoding %d bytes"
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
)

/*
Marshal writes the XDR encoding of v to writer w and returns the number of bytes
Expand Down Expand Up @@ -71,6 +74,9 @@ returned with a human readable description as well as an ErrorCode value for
further inspection from sophisticated callers. Some potential issues are
unsupported Go types, attempting to encode more opaque data than can be
represented by a single opaque XDR entry, and exceeding max slice limitations.
If any of the types encountered comply with the Marshaler interface, the MarshalXDR
method will be used instead of reflection (similarly to MarshalJSON in the encoding/json package)
*/
func Marshal(w io.Writer, v interface{}) (int, error) {
enc := Encoder{w: w}
Expand Down Expand Up @@ -609,6 +615,38 @@ func (enc *Encoder) encodeInterface(v reflect.Value) (int, error) {
return enc.encode(ve)
}

func getMarshaler(ve reflect.Value) Marshaler {
if ve.Type().Implements(marshalerType) && ve.CanInterface() {
return ve.Interface().(Marshaler)
}
return nil
}

func getMarshalerFromNonPtr(ve reflect.Value) (Marshaler, error) {
if m := getMarshaler(ve); m != nil {
return m, nil
}

// Implements it by address?
// TODO: this check is redundant if encoding got here through a pointer type
if reflect.PtrTo(ve.Type()).Implements(marshalerType) {
if !ve.CanAddr() {
msg := fmt.Sprintf("unaddressable value of %s implementing MarshalXDR() by pointer", ve.Type().Name())
err := marshalError("encode", ErrCustomUnaddressableByPointer, msg, nil, nil)
return nil, err
}
ptr := ve.Addr()
if ptr.CanInterface() {
iface := ve.Interface()
if u, ok := iface.(Marshaler); ok {
return u, nil
}
}
}

return nil, nil
}

// encode is the main workhorse for marshalling via reflection. It uses
// the passed reflection value to choose the XDR primitives to encode into
// the encapsulated writer and returns the number of bytes written. It is a
Expand All @@ -623,6 +661,19 @@ func (enc *Encoder) encode(ve reflect.Value) (int, error) {
return n, err
}

// Check for marshaler interface
if ve.Kind() != reflect.Ptr {
// For pointer types this will be checked again later on.
// (this is because pointers require encoding whether they are nil)
m, err := getMarshalerFromNonPtr(ve)
if err != nil {
return n, err
}
if m != nil {
return m.MarshalXDR(enc)
}
}

if ve.Kind() == reflect.Ptr {
if ve.IsNil() {
return enc.EncodeBool(false)
Expand All @@ -635,6 +686,15 @@ func (enc *Encoder) encode(ve reflect.Value) (int, error) {
return n, err
}

// We intentionally check for Unmarshaler before deferencing the pointer.
// This is because, for unaddressable values, you cannot recover the
// pointer later on, which would make impossible to invoke by-address marshalers.
if u := getMarshaler(ve); u != nil {
n2, err = u.MarshalXDR(enc)
n += n2
return n, err
}

n2, err = enc.encode(ve.Elem())
n += n2
return n, err
Expand Down Expand Up @@ -735,6 +795,10 @@ func (enc *Encoder) Encode(v interface{}) (int, error) {
return 0, err
}

if marshaler, ok := v.(Marshaler); ok {
return marshaler.MarshalXDR(enc)
}

vv := reflect.ValueOf(v)
vve := vv
for vve.Kind() == reflect.Ptr {
Expand All @@ -751,6 +815,11 @@ func (enc *Encoder) Encode(v interface{}) (int, error) {
return enc.encode(vve)
}

// Write writes to the internal writer. This method can be useful for implementing MarshalXDR
func (enc *Encoder) Write(p []byte) (int, error) {
return enc.w.Write(p)
}

// NewEncoder returns an object that can be used to manually choose fields to
// XDR encode to the passed writer w. Typically, Marshal should be used instead
// of manually creating an Encoder. An Encoder, along with several of its
Expand All @@ -760,3 +829,9 @@ func (enc *Encoder) Encode(v interface{}) (int, error) {
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}

// Marshaler is the interface implemented by types that can marshal themselves into valid XDR.
// The supplied encoder should be used to perform the encoding.
type Marshaler interface {
MarshalXDR(e *Encoder) (int, error)
}
53 changes: 53 additions & 0 deletions xdr3/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,49 @@ import (
. "github.com/stellar/go-xdr/xdr3"
)

type customEncodedType struct {
A uint8
}

const magicCustomEncoderValue = 0xfafb

func (c customEncodedType) MarshalXDR(e *Encoder) (int, error) {
// ignore the real value and output a magic number to make sure we were called
i, err := e.EncodeInt(int32(magicCustomEncoderValue))
return i, err
}

type customEncodedTypeByPointer struct {
A uint8
}

// make sure that interface implementations by pointer also work
func (c *customEncodedTypeByPointer) MarshalXDR(e *Encoder) (int, error) {
// ignore the real value and output a magic number to make sure we were called
i, err := e.EncodeInt(int32(magicCustomEncoderValue))
return i, err
}

type customEncodedSubType struct {
A uint8
B customEncodedType
}

type customEncodedSubTypeByPointer struct {
A uint8
B customEncodedTypeByPointer
}

type customEncodedSubTypeWithPointer struct {
A uint8
B *customEncodedType
}

type customEncodedSubTypeWithPointerByPointer struct {
A uint8
B *customEncodedTypeByPointer
}

// testExpectedMRet is a convenience method to test an expected number of bytes
// written and error for a marshal.
func testExpectedMRet(t *testing.T, name string, n, wantN int, err, wantErr error) bool {
Expand Down Expand Up @@ -307,6 +350,16 @@ func TestMarshal(t *testing.T) {
[]byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01},
8, &MarshalError{ErrorCode: ErrIO}},

// Marshaler interface
{customEncodedType{255}, []byte{0x00, 0x00, 0xFA, 0xFB}, 4, nil},
{&customEncodedTypeByPointer{255}, []byte{0x00, 0x00, 0xFA, 0xFB}, 4, nil},
{customEncodedSubType{255, customEncodedType{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0xFA, 0xFB}, 8, nil},
{customEncodedSubTypeWithPointer{255, &customEncodedType{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xFA, 0xFB}, 12, nil},
{customEncodedSubTypeWithPointerByPointer{255, &customEncodedTypeByPointer{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xFA, 0xFB}, 12, nil},
// Expected failures
{customEncodedTypeByPointer{255}, []byte{}, 0, &MarshalError{ErrorCode: ErrCustomUnaddressableByPointer}},
{customEncodedSubTypeByPointer{255, customEncodedTypeByPointer{16}}, []byte{0x00, 0x00, 0x00, 0xFF}, 4, &MarshalError{ErrorCode: ErrCustomUnaddressableByPointer}},

// Expected errors
{nilInterface, []byte{}, 0, &MarshalError{ErrorCode: ErrNilInterface}},
{&nilInterface, []byte{}, 0, &MarshalError{ErrorCode: ErrNilInterface}},
Expand Down
Loading

0 comments on commit 535624d

Please sign in to comment.