Skip to content

Commit

Permalink
Fix decoding nil pointer to the previously used object (stellar#13)
Browse files Browse the repository at this point in the history
This commit fixes a bug in `Decoder.decodePtr`. When decoding a struct
with a pointer field into previously used object (ex. in a loop with
object created outside of the loop) the previous non-`nil` pointer value
was not overwritten if the new value was `nil`. See
`TestDecodeNilPointerIntoExistingObjectWithNotNilPointer` for more
context.
  • Loading branch information
bartekn authored Oct 28, 2020
1 parent f0e124a commit f80a23d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
27 changes: 26 additions & 1 deletion xdr3/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,26 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int) (int, error) {
return 0, err
}

func setPtrToNil(v *reflect.Value) error {
if v.Kind() != reflect.Ptr {
msg := fmt.Sprintf("value is not a pointer: '%v'",
v.Type().String())
err := unmarshalError("decodePtr", ErrBadArguments, msg,
nil, nil)
return err
}
if !v.CanSet() {
msg := fmt.Sprintf("pointer value cannot be changed for '%v'",
v.Type().String())
err := unmarshalError("decodePtr", ErrNotSettable, msg,
nil, nil)
return err
}

v.Set(reflect.Zero(v.Type()))
return nil
}

func allocPtrIfNil(v *reflect.Value) error {
if v.Kind() != reflect.Ptr {
msg := fmt.Sprintf("value is not a pointer: '%v'",
Expand Down Expand Up @@ -969,7 +989,12 @@ func (d *Decoder) decodePtr(v reflect.Value) (int, error) {

present, n, err := d.DecodeBool()

if err != nil || !present {
if err != nil {
return n, err
}

if !present {
err = setPtrToNil(&v)
return n, err
}

Expand Down
36 changes: 36 additions & 0 deletions xdr3/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ func (u aUnion) ArmForSwitch(sw int32) (string, bool) {
return "-", false
}

type structWithPointer struct {
Data *string
}

// testExpectedURet is a convenience method to test an expected number of bytes
// read and error for an unmarshal.
func testExpectedURet(t *testing.T, name string, n, wantN int, err, wantErr error) bool {
Expand Down Expand Up @@ -1054,6 +1058,38 @@ func TestPaddedReads(t *testing.T) {
}
}

func TestDecodeNilPointerIntoExistingObjectWithNotNilPointer(t *testing.T) {
var buf bytes.Buffer
data := "data"
_, err := Marshal(&buf, structWithPointer{Data: &data})
if err != nil {
t.Error("unexpected error")
}

var s structWithPointer
_, err = Unmarshal(&buf, &s)
if err != nil {
t.Error("unexpected error")
}

// Note:
// 1. structWithPointer.Data is nil.
// 2. We unmarshal into previously used object.
_, err = Marshal(&buf, structWithPointer{})
if err != nil {
t.Error("unexpected error")
}

_, err = Unmarshal(&buf, &s)
if err != nil {
t.Error("unexpected error")
}

if s.Data != nil {
t.Error("Data should be nil")
}
}

func TestDecodeUnionIntoExistingObject(t *testing.T) {
var buf bytes.Buffer
var idata int32 = 1
Expand Down

0 comments on commit f80a23d

Please sign in to comment.