diff --git a/xdr3/decode.go b/xdr3/decode.go index 80a1d9b..e3627df 100644 --- a/xdr3/decode.go +++ b/xdr3/decode.go @@ -941,6 +941,26 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int) (int, error) { return 0, err } +func setPtrToNil(v *reflect.Value) error { + if v.Kind() != reflect.Ptr { + msg := fmt.Sprintf("value is not a pointer: '%v'", + v.Type().String()) + err := unmarshalError("decodePtr", ErrBadArguments, msg, + nil, nil) + return err + } + if !v.CanSet() { + msg := fmt.Sprintf("pointer value cannot be changed for '%v'", + v.Type().String()) + err := unmarshalError("decodePtr", ErrNotSettable, msg, + nil, nil) + return err + } + + v.Set(reflect.Zero(v.Type())) + return nil +} + func allocPtrIfNil(v *reflect.Value) error { if v.Kind() != reflect.Ptr { msg := fmt.Sprintf("value is not a pointer: '%v'", @@ -969,7 +989,12 @@ func (d *Decoder) decodePtr(v reflect.Value) (int, error) { present, n, err := d.DecodeBool() - if err != nil || !present { + if err != nil { + return n, err + } + + if !present { + err = setPtrToNil(&v) return n, err } diff --git a/xdr3/decode_test.go b/xdr3/decode_test.go index 281c88d..bc575f6 100644 --- a/xdr3/decode_test.go +++ b/xdr3/decode_test.go @@ -93,6 +93,10 @@ func (u aUnion) ArmForSwitch(sw int32) (string, bool) { return "-", false } +type structWithPointer struct { + Data *string +} + // testExpectedURet is a convenience method to test an expected number of bytes // read and error for an unmarshal. func testExpectedURet(t *testing.T, name string, n, wantN int, err, wantErr error) bool { @@ -1054,6 +1058,38 @@ func TestPaddedReads(t *testing.T) { } } +func TestDecodeNilPointerIntoExistingObjectWithNotNilPointer(t *testing.T) { + var buf bytes.Buffer + data := "data" + _, err := Marshal(&buf, structWithPointer{Data: &data}) + if err != nil { + t.Error("unexpected error") + } + + var s structWithPointer + _, err = Unmarshal(&buf, &s) + if err != nil { + t.Error("unexpected error") + } + + // Note: + // 1. structWithPointer.Data is nil. + // 2. We unmarshal into previously used object. + _, err = Marshal(&buf, structWithPointer{}) + if err != nil { + t.Error("unexpected error") + } + + _, err = Unmarshal(&buf, &s) + if err != nil { + t.Error("unexpected error") + } + + if s.Data != nil { + t.Error("Data should be nil") + } +} + func TestDecodeUnionIntoExistingObject(t *testing.T) { var buf bytes.Buffer var idata int32 = 1