Skip to content

Commit

Permalink
Fix decoding union into a previously used object (#14)
Browse files Browse the repository at this point in the history
This commit fixes a bug in `Decoder.decodeUnion`. When decoding a union
previously used object (ex. in a loop with object created outside of the
loop) the previous arm value was not zeroed (set to `nil`) if it was a
different union type. See `TestDecodeUnionIntoExistingObject` for more
context.
  • Loading branch information
bartekn authored Oct 28, 2020
1 parent 71a1e6d commit f0e124a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
12 changes: 12 additions & 0 deletions xdr3/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,12 +509,24 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int) (
return n, nil
}

func setUnionArmsToNil(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
if f.Kind() != reflect.Ptr {
continue
}
v.Set(reflect.Zero(v.Type()))
}
}

// decodeUnion
func (d *Decoder) decodeUnion(v reflect.Value) (int, error) {
// we should have already checked that v is a union
// prior to this call, so we panic if v is not a union
u := v.Interface().(Union)

setUnionArmsToNil(v)

i, n, err := d.DecodeInt()
if err != nil {
return n, err
Expand Down
44 changes: 44 additions & 0 deletions xdr3/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,3 +1053,47 @@ func TestPaddedReads(t *testing.T) {
t.Error("expected error when unmarshaling varopaque with non-zero padding byte, got none")
}
}

func TestDecodeUnionIntoExistingObject(t *testing.T) {
var buf bytes.Buffer
var idata int32 = 1
sdata := "data"
_, err := Marshal(&buf, aUnion{
Type: 0,
Data: &idata,
})
if err != nil {
t.Error("unexpected error")
}

var s aUnion
_, err = Unmarshal(&buf, &s)
if err != nil {
t.Error("unexpected error")
}

_, err = Marshal(&buf, aUnion{
Type: 1,
Text: &sdata,
})
if err != nil {
t.Error("unexpected error")
}

_, err = Unmarshal(&buf, &s)
if err != nil {
t.Error("unexpected error")
}

if s.Data != nil {
t.Error("Data should be nil")
}

if s.Type != 1 {
t.Error("Type does not match")
}

if *s.Text != sdata {
t.Error("Text does not match")
}
}

0 comments on commit f0e124a

Please sign in to comment.