diff --git a/xdr3/decode.go b/xdr3/decode.go index 7608af6..be4534b 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -27,8 +27,11 @@ import ( const maxInt32 = int(^uint32(0) >> 1) -var errMaxSlice = "data exceeds max slice limit" -var errIODecode = "%s while decoding %d bytes" +var ( + errMaxSlice = "data exceeds max slice limit" + errIODecode = "%s while decoding %d bytes" + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) /* Unmarshal parses XDR-encoded data into the value pointed to by v reading from @@ -74,6 +77,9 @@ UnmarshalError is returned with a human readable description as well as an ErrorCode value for further inspection from sophisticated callers. Some potential issues are unsupported Go types, attempting to decode a value which is too large to fit into a specified Go type, and exceeding max slice limitations. + +If any of the types encountered comply with the Unmarshaler interface, the UnmarshalXDR +method will be used instead of reflection (similarly to UnmarshalJSON in the encoding/json package) */ func Unmarshal(r io.Reader, v interface{}) (int, error) { d := Decoder{r: r} @@ -770,6 +776,37 @@ func (d *Decoder) decodeInterface(v reflect.Value) (int, error) { return d.decode(ve, 0) } +func getUnmarshaler(ve reflect.Value) Unmarshaler { + // Check the type first to avoid an unnecessary allocation when casting to interface + if ve.Type().Implements(unmarshalerType) && ve.CanInterface() { + return ve.Interface().(Unmarshaler) + } + return nil +} + +func getUnmarshalerFromNonPtr(ve reflect.Value) (Unmarshaler, error) { + if u := getUnmarshaler(ve); u != nil { + return u, nil + } + + // Implements Unmarshaller by address? + // Check the type first, to avoid an unnecessary allocation when casting to interface + // TODO: this check is redundant if decoding got here through a pointer type + if reflect.PtrTo(ve.Type()).Implements(unmarshalerType) { + if !ve.CanAddr() { + msg := fmt.Sprintf("unaddressable value of %s implementing MarshalXDR() by pointer", ve.Type().Name()) + err := marshalError("encode", ErrCustomUnaddressableByPointer, msg, nil, nil) + return nil, err + } + ptr := ve.Addr() + if ptr.CanInterface() { + return ptr.Interface().(Unmarshaler), nil + } + } + + return nil, nil +} + // decode is the main workhorse for unmarshalling via reflection. It uses // the passed reflection value to choose the XDR primitives to decode from // the encapsulated reader. It is a recursive function, @@ -782,6 +819,19 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int) (int, error) { return 0, err } + // Check for Unmarshaler interface + if ve.Kind() != reflect.Ptr { + // For pointer types this will be checked again later on. + // (this is because pointers require decoding whether they are nil) + u, err := getUnmarshalerFromNonPtr(ve) + if err != nil { + return 0, err + } + if u != nil { + return u.UnmarshalXDR(d) + } + } + // Handle time.Time values by decoding them as an RFC3339 formatted // string with nanosecond precision. Check the type string rather // than doing a full blown conversion to interface and type assertion @@ -1013,6 +1063,13 @@ func (d *Decoder) decodePtr(v reflect.Value) (int, error) { return n, err } + // We intentionally check for Marshaler before deferencing the pointer. + // This is because, for unaddressable values, you cannot recover the + // pointer later on, which would make impossible to invoke by-address marshalers. + if u := getUnmarshaler(v); u != nil { + n2, err := u.UnmarshalXDR(d) + return n + n2, err + } n2, err := d.decode(v.Elem(), 0) return n + n2, err } @@ -1038,6 +1095,10 @@ func (d *Decoder) Decode(v interface{}) (int, error) { nil) } + if unmarshaler, ok := v.(Unmarshaler); ok { + return unmarshaler.UnmarshalXDR(d) + } + vv := reflect.ValueOf(v) if vv.Kind() != reflect.Ptr { msg := fmt.Sprintf("can't unmarshal to non-pointer '%v' - use "+ @@ -1055,9 +1116,20 @@ func (d *Decoder) Decode(v interface{}) (int, error) { return d.decode(vv.Elem(), 0) } +// Read reads from the internal reader. This method can be useful for implementing UnmarshalXDR +func (d *Decoder) Read(p []byte) (int, error) { + return d.r.Read(p) +} + // NewDecoder returns a Decoder that can be used to manually decode XDR data // from a provided reader. Typically, Unmarshal should be used instead of // manually creating a Decoder. func NewDecoder(r io.Reader) *Decoder { return &Decoder{r: r} } + +// Unmarshaler is the interface implemented by types that can unmarshal themselves from valid XDR. +// The supplied decoder should be used to perform the decoding. +type Unmarshaler interface { + UnmarshalXDR(e *Decoder) (int, error) +} diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index bc575f6..2f5ec1f 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -34,6 +34,51 @@ type subTest struct { B uint8 } +type customDecodedType struct { + A uint8 +} + +const magicCustomDecoderCount = 123456 + +func (c *customDecodedType) UnmarshalXDR(e *Decoder) (int, error) { + i, _, err := e.DecodeInt() + c.A = uint8(i) + // return magic number to indicate that the custom unmarshalling code was called + return magicCustomDecoderCount, err +} + +type customDecodedTypeByValue struct { + A uint8 +} + +// make sure that (odd) method implementations by value also work +func (c customDecodedTypeByValue) UnmarshalXDR(e *Decoder) (int, error) { + // there is no point in consuming the input since it can be saved (the method invocation is by value) + i, _, err := e.DecodeInt() + // return magic number to indicate that the custom unmarshalling code was called + return magicCustomDecoderCount + int(i), err +} + +type customDecodedSubType struct { + A uint8 + B customDecodedType +} + +type customDecodedSubTypeByValue struct { + A uint8 + B customDecodedTypeByValue +} + +type customDecodedSubTypeWithPointer struct { + A uint8 + B *customDecodedType +} + +type customDecodedSubTypeByValueWithPointer struct { + A uint8 + B *customDecodedTypeByValue +} + // allTypesTest is used to allow testing of the Unmarshal function into struct // fields of all supported types. type allTypesTest struct { @@ -388,6 +433,13 @@ func TestUnmarshal(t *testing.T) { {[]byte{0x00, 0x00, 0x00}, opaqueStruct{}, 3, &UnmarshalError{ErrorCode: ErrIO}}, {[]byte{0x00, 0x00, 0x00, 0x00, 0x00}, opaqueStruct{}, 5, &UnmarshalError{ErrorCode: ErrIO}}, + // Unmarshaler interface + {[]byte{0x00, 0x00, 0x00, 0xFF}, customDecodedType{255}, magicCustomDecoderCount, nil}, + {[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x10}, customDecodedSubType{255, customDecodedType{16}}, magicCustomDecoderCount + 4, nil}, + {[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeByValue{255, customDecodedTypeByValue{0}}, (magicCustomDecoderCount + 16) + 4, nil}, + {[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeWithPointer{255, &customDecodedType{16}}, magicCustomDecoderCount + 4 + 4, nil}, + {[]byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10}, customDecodedSubTypeByValueWithPointer{255, &customDecodedTypeByValue{0}}, (magicCustomDecoderCount + 16) + 4 + 4, nil}, + // Expected errors {nil, nilInterface, 0, &UnmarshalError{ErrorCode: ErrNilInterface}}, {nil, &nilInterface, 0, &UnmarshalError{ErrorCode: ErrIO}}, diff --git a/xdr3/encode.go b/xdr3/encode.go index 32a67cc..488c6eb 100644 --- a/xdr3/encode.go +++ b/xdr3/encode.go @@ -24,7 +24,10 @@ import ( "time" ) -var errIOEncode = "%s while encoding %d bytes" +var ( + errIOEncode = "%s while encoding %d bytes" + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() +) /* Marshal writes the XDR encoding of v to writer w and returns the number of bytes @@ -71,6 +74,9 @@ returned with a human readable description as well as an ErrorCode value for further inspection from sophisticated callers. Some potential issues are unsupported Go types, attempting to encode more opaque data than can be represented by a single opaque XDR entry, and exceeding max slice limitations. + +If any of the types encountered comply with the Marshaler interface, the MarshalXDR +method will be used instead of reflection (similarly to MarshalJSON in the encoding/json package) */ func Marshal(w io.Writer, v interface{}) (int, error) { enc := Encoder{w: w} @@ -609,6 +615,38 @@ func (enc *Encoder) encodeInterface(v reflect.Value) (int, error) { return enc.encode(ve) } +func getMarshaler(ve reflect.Value) Marshaler { + if ve.Type().Implements(marshalerType) && ve.CanInterface() { + return ve.Interface().(Marshaler) + } + return nil +} + +func getMarshalerFromNonPtr(ve reflect.Value) (Marshaler, error) { + if m := getMarshaler(ve); m != nil { + return m, nil + } + + // Implements it by address? + // TODO: this check is redundant if encoding got here through a pointer type + if reflect.PtrTo(ve.Type()).Implements(marshalerType) { + if !ve.CanAddr() { + msg := fmt.Sprintf("unaddressable value of %s implementing MarshalXDR() by pointer", ve.Type().Name()) + err := marshalError("encode", ErrCustomUnaddressableByPointer, msg, nil, nil) + return nil, err + } + ptr := ve.Addr() + if ptr.CanInterface() { + iface := ve.Interface() + if u, ok := iface.(Marshaler); ok { + return u, nil + } + } + } + + return nil, nil +} + // encode is the main workhorse for marshalling via reflection. It uses // the passed reflection value to choose the XDR primitives to encode into // the encapsulated writer and returns the number of bytes written. It is a @@ -623,6 +661,19 @@ func (enc *Encoder) encode(ve reflect.Value) (int, error) { return n, err } + // Check for marshaler interface + if ve.Kind() != reflect.Ptr { + // For pointer types this will be checked again later on. + // (this is because pointers require encoding whether they are nil) + m, err := getMarshalerFromNonPtr(ve) + if err != nil { + return n, err + } + if m != nil { + return m.MarshalXDR(enc) + } + } + if ve.Kind() == reflect.Ptr { if ve.IsNil() { return enc.EncodeBool(false) @@ -635,6 +686,15 @@ func (enc *Encoder) encode(ve reflect.Value) (int, error) { return n, err } + // We intentionally check for Unmarshaler before deferencing the pointer. + // This is because, for unaddressable values, you cannot recover the + // pointer later on, which would make impossible to invoke by-address marshalers. + if u := getMarshaler(ve); u != nil { + n2, err = u.MarshalXDR(enc) + n += n2 + return n, err + } + n2, err = enc.encode(ve.Elem()) n += n2 return n, err @@ -735,6 +795,10 @@ func (enc *Encoder) Encode(v interface{}) (int, error) { return 0, err } + if marshaler, ok := v.(Marshaler); ok { + return marshaler.MarshalXDR(enc) + } + vv := reflect.ValueOf(v) vve := vv for vve.Kind() == reflect.Ptr { @@ -751,6 +815,11 @@ func (enc *Encoder) Encode(v interface{}) (int, error) { return enc.encode(vve) } +// Write writes to the internal writer. This method can be useful for implementing MarshalXDR +func (enc *Encoder) Write(p []byte) (int, error) { + return enc.w.Write(p) +} + // NewEncoder returns an object that can be used to manually choose fields to // XDR encode to the passed writer w. Typically, Marshal should be used instead // of manually creating an Encoder. An Encoder, along with several of its @@ -760,3 +829,9 @@ func (enc *Encoder) Encode(v interface{}) (int, error) { func NewEncoder(w io.Writer) *Encoder { return &Encoder{w: w} } + +// Marshaler is the interface implemented by types that can marshal themselves into valid XDR. +// The supplied encoder should be used to perform the encoding. +type Marshaler interface { + MarshalXDR(e *Encoder) (int, error) +} diff --git a/xdr3/encode_test.go b/xdr3/encode_test.go index bda0873..a235f0f 100644 --- a/xdr3/encode_test.go +++ b/xdr3/encode_test.go @@ -26,6 +26,49 @@ import ( . "github.com/stellar/go-xdr/xdr3" ) +type customEncodedType struct { + A uint8 +} + +const magicCustomEncoderValue = 0xfafb + +func (c customEncodedType) MarshalXDR(e *Encoder) (int, error) { + // ignore the real value and output a magic number to make sure we were called + i, err := e.EncodeInt(int32(magicCustomEncoderValue)) + return i, err +} + +type customEncodedTypeByPointer struct { + A uint8 +} + +// make sure that interface implementations by pointer also work +func (c *customEncodedTypeByPointer) MarshalXDR(e *Encoder) (int, error) { + // ignore the real value and output a magic number to make sure we were called + i, err := e.EncodeInt(int32(magicCustomEncoderValue)) + return i, err +} + +type customEncodedSubType struct { + A uint8 + B customEncodedType +} + +type customEncodedSubTypeByPointer struct { + A uint8 + B customEncodedTypeByPointer +} + +type customEncodedSubTypeWithPointer struct { + A uint8 + B *customEncodedType +} + +type customEncodedSubTypeWithPointerByPointer struct { + A uint8 + B *customEncodedTypeByPointer +} + // testExpectedMRet is a convenience method to test an expected number of bytes // written and error for a marshal. func testExpectedMRet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { @@ -307,6 +350,16 @@ func TestMarshal(t *testing.T) { []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}, 8, &MarshalError{ErrorCode: ErrIO}}, + // Marshaler interface + {customEncodedType{255}, []byte{0x00, 0x00, 0xFA, 0xFB}, 4, nil}, + {&customEncodedTypeByPointer{255}, []byte{0x00, 0x00, 0xFA, 0xFB}, 4, nil}, + {customEncodedSubType{255, customEncodedType{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0xFA, 0xFB}, 8, nil}, + {customEncodedSubTypeWithPointer{255, &customEncodedType{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xFA, 0xFB}, 12, nil}, + {customEncodedSubTypeWithPointerByPointer{255, &customEncodedTypeByPointer{16}}, []byte{0x00, 0x00, 0x00, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xFA, 0xFB}, 12, nil}, + // Expected failures + {customEncodedTypeByPointer{255}, []byte{}, 0, &MarshalError{ErrorCode: ErrCustomUnaddressableByPointer}}, + {customEncodedSubTypeByPointer{255, customEncodedTypeByPointer{16}}, []byte{0x00, 0x00, 0x00, 0xFF}, 4, &MarshalError{ErrorCode: ErrCustomUnaddressableByPointer}}, + // Expected errors {nilInterface, []byte{}, 0, &MarshalError{ErrorCode: ErrNilInterface}}, {&nilInterface, []byte{}, 0, &MarshalError{ErrorCode: ErrNilInterface}}, diff --git a/xdr3/error.go b/xdr3/error.go index 53ffbd5..cfe1095 100644 --- a/xdr3/error.go +++ b/xdr3/error.go @@ -70,20 +70,25 @@ const ( // ErrBadUnionValue indicates a union's value is not populated when it should // be ErrBadUnionValue + + // ErrCustomUnaddressableByPointer indicates that an unaddresable value that implements Marshaler/Unmarshaler + // by pointer was passed and, as a result, the custom marshaler/unmarshaler cannot be invoked + ErrCustomUnaddressableByPointer ) // Map of ErrorCode values back to their constant names for pretty printing. var errorCodeStrings = map[ErrorCode]string{ - ErrBadArguments: "ErrBadArguments", - ErrUnsupportedType: "ErrUnsupportedType", - ErrBadEnumValue: "ErrBadEnumValue", - ErrNotSettable: "ErrNotSettable", - ErrOverflow: "ErrOverflow", - ErrNilInterface: "ErrNilInterface", - ErrIO: "ErrIO", - ErrParseTime: "ErrParseTime", - ErrBadUnionSwitch: "ErrBadUnionSwitch", - ErrBadUnionValue: "ErrBadUnionValue", + ErrBadArguments: "ErrBadArguments", + ErrUnsupportedType: "ErrUnsupportedType", + ErrBadEnumValue: "ErrBadEnumValue", + ErrNotSettable: "ErrNotSettable", + ErrOverflow: "ErrOverflow", + ErrNilInterface: "ErrNilInterface", + ErrIO: "ErrIO", + ErrParseTime: "ErrParseTime", + ErrBadUnionSwitch: "ErrBadUnionSwitch", + ErrBadUnionValue: "ErrBadUnionValue", + ErrCustomUnaddressableByPointer: "ErrCustomUnaddressableByPointer", } // String returns the ErrorCode as a human-readable name.