diff --git a/_generated/def.go b/_generated/def.go index 0360729a..a3d97b9e 100644 --- a/_generated/def.go +++ b/_generated/def.go @@ -278,3 +278,24 @@ type NonMsgStructTags struct { } type EmptyStruct struct{} + +type StructByteSlice struct { + ABytes []byte `msg:",allownil"` + AString []string `msg:",allownil"` + ABool []bool `msg:",allownil"` + AInt []int `msg:",allownil"` + AInt8 []int8 `msg:",allownil"` + AInt16 []int16 `msg:",allownil"` + AInt32 []int32 `msg:",allownil"` + AInt64 []int64 `msg:",allownil"` + AUint []uint `msg:",allownil"` + AUint8 []uint8 `msg:",allownil"` + AUint16 []uint16 `msg:",allownil"` + AUint32 []uint32 `msg:",allownil"` + AUint64 []uint64 `msg:",allownil"` + AFloat32 []float32 `msg:",allownil"` + AFloat64 []float64 `msg:",allownil"` + AComplex64 []complex64 `msg:",allownil"` + AComplex128 []complex128 `msg:",allownil"` + AStruct []Fixed `msg:",allownil"` +} diff --git a/_generated/gen_test.go b/_generated/gen_test.go index 8fa2b7a8..b85eaf64 100644 --- a/_generated/gen_test.go +++ b/_generated/gen_test.go @@ -165,3 +165,73 @@ func TestIssue168(t *testing.T) { t.Fatalf("got back %+v", test) } } + +func TestIssue362(t *testing.T) { + in := StructByteSlice{ + ABytes: make([]byte, 0), + AString: make([]string, 0), + ABool: make([]bool, 0), + AInt: make([]int, 0), + AInt8: make([]int8, 0), + AInt16: make([]int16, 0), + AInt32: make([]int32, 0), + AInt64: make([]int64, 0), + AUint: make([]uint, 0), + AUint8: make([]uint8, 0), + AUint16: make([]uint16, 0), + AUint32: make([]uint32, 0), + AUint64: make([]uint64, 0), + AFloat32: make([]float32, 0), + AFloat64: make([]float64, 0), + AComplex64: make([]complex64, 0), + AComplex128: make([]complex128, 0), + AStruct: make([]Fixed, 0), + } + + b, err := in.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + var dst StructByteSlice + _, err = dst.UnmarshalMsg(b) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(in, dst) { + t.Fatalf("mismatch %#v != %#v", in, dst) + } + dst2 := StructByteSlice{} + dec := msgp.NewReader(bytes.NewReader(b)) + err = dst2.DecodeMsg(dec) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(in, dst2) { + t.Fatalf("mismatch %#v != %#v", in, dst2) + } + + // Encode with nil + zero := StructByteSlice{} + b, err = zero.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + // Decode into dst that now has values... + _, err = dst.UnmarshalMsg(b) + if err != nil { + t.Fatal(err) + } + // All should be nil now. + if !reflect.DeepEqual(zero, dst) { + t.Fatalf("mismatch %#v != %#v", zero, dst) + } + dec = msgp.NewReader(bytes.NewReader(b)) + err = dst2.DecodeMsg(dec) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(zero, dst2) { + t.Fatalf("mismatch %#v != %#v", zero, dst2) + } +} diff --git a/gen/decode.go b/gen/decode.go index 90e7cf4f..ad21173e 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -153,6 +153,7 @@ func (d *decodeGen) gBase(b *BaseElem) { vname := b.Varname() // e.g. "z.FieldOne" bname := b.BaseName() // e.g. "Float64" + checkNil := vname // Name of var to check for nil // handle special cases // for object type. @@ -161,8 +162,10 @@ func (d *decodeGen) gBase(b *BaseElem) { if b.Convert { lowered := b.ToBase() + "(" + vname + ")" d.p.printf("\n%s, err = dc.ReadBytes(%s)", tmp, lowered) + checkNil = tmp } else { d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname) + checkNil = vname } case IDENT: if b.Convert { @@ -182,6 +185,11 @@ func (d *decodeGen) gBase(b *BaseElem) { } d.p.wrapErrCheck(d.ctx.ArgsStr()) + if checkNil != "" && b.AllowNil() { + // Ensure that 0 sized slices are allocated. + d.p.printf("\nif %s == nil {\n%s = make([]byte, 0)\n}", checkNil, checkNil) + } + // close block for 'tmp' if b.Convert && b.Value != IDENT { if b.ShimMode == Cast { diff --git a/gen/elem.go b/gen/elem.go index d397cbed..3ace7764 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -554,6 +554,7 @@ type BaseElem struct { Convert bool // should we do an explicit conversion? mustinline bool // must inline; not printable needsref bool // needs reference for shim + allowNil *bool // Override from parent. } func (s *BaseElem) Printable() bool { return !s.mustinline } @@ -568,7 +569,17 @@ func (s *BaseElem) Alias(typ string) { } } -func (s *BaseElem) AllowNil() bool { return s.Value == Bytes } +func (s *BaseElem) AllowNil() bool { + if s.allowNil == nil { + return s.Value == Bytes + } + return *s.allowNil +} + +// SetIsAllowNil will override allownil when tag has been parsed. +func (s *BaseElem) SetIsAllowNil(b bool) { + s.allowNil = &b +} func (s *BaseElem) SetVarname(a string) { // extensions whose parents diff --git a/gen/unmarshal.go b/gen/unmarshal.go index c2ef89bd..75f8a467 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -161,6 +161,11 @@ func (u *unmarshalGen) gBase(b *BaseElem) { } u.p.wrapErrCheck(u.ctx.ArgsStr()) + if b.Value == Bytes && b.AllowNil() { + // Ensure that 0 sized slices are allocated. + u.p.printf("\nif %s == nil {\n%s = make([]byte, 0)\n}", refname, refname) + } + // close 'tmp' block if b.Convert && b.Value != IDENT { if b.ShimMode == Cast {