Skip to content

Commit

Permalink
allownil: Allocate 0 length slices
Browse files Browse the repository at this point in the history
When `allownil` is enabled, always allocate zero length slices.

This ensures roundtrips with 0-length slices are not returned as nil.

Replaces tinylib#304

Adds tests. Bonus: Don't shell out to test issue 94.
  • Loading branch information
klauspost committed Oct 27, 2023
1 parent 4c45222 commit dac711c
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 12 deletions.
57 changes: 57 additions & 0 deletions _generated/allownil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestAllownil(t *testing.T) {
tt := &NamedStructAN{
A: []string{},
B: nil,
}
var buf bytes.Buffer

err := msgp.Encode(&buf, tt)
if err != nil {
t.Fatal(err)
}
in := buf.Bytes()

for _, tnew := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
err = msgp.Decode(bytes.NewBuffer(in), tnew)
if err != nil {
t.Error(err)
}

if !reflect.DeepEqual(tt, tnew) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tnew)
t.Fatal("objects not equal")
}
}

in, err = tt.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
for _, tanother := range []*NamedStructAN{{}, {A: []string{}}, {B: []string{}}} {
var left []byte
left, err = tanother.UnmarshalMsg(in)
if err != nil {
t.Error(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left", len(left))
}

if !reflect.DeepEqual(tt, tanother) {
t.Logf("in: %#v", tt)
t.Logf("out: %#v", tanother)
t.Fatal("objects not equal")
}
}
}
10 changes: 0 additions & 10 deletions _generated/issue94.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@ import (

//go:generate msgp

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.

//go:generate ./search.sh $GOFILE timetostr

//msgp:shim time.Time as:string using:timetostr/strtotime
type T struct {
T time.Time
Expand Down
25 changes: 25 additions & 0 deletions _generated/issue94_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package _generated

import (
"bytes"
"os"
"testing"
)

// Issue 94: shims were not propogated recursively,
// which caused shims that weren't at the top level
// to be silently ignored.
//
// The following line will generate an error after
// the code is generated if the generated code doesn't
// have the right identifier in it.
func TestIssue94(t *testing.T) {
b, err := os.ReadFile("issue94_gen.go")
if err != nil {
t.Fatal(err)
}
const want = "timetostr"
if !bytes.Contains(b, []byte(want)) {
t.Errorf("generated code did not contain %q", want)
}
}
6 changes: 5 additions & 1 deletion gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,11 @@ func (d *decodeGen) gSlice(s *Slice) {
sz := randIdent()
d.p.declare(sz, u32)
d.assignAndCheck(sz, arrayHeader)
d.p.resizeSlice(sz, s)
if s.AllowNil() {
d.p.resizeSliceNoNil(sz, s)
} else {
d.p.resizeSlice(sz, s)
}
d.p.rangeBlock(d.ctx, s.Index, s.Varname(), d, s.Els)
}

Expand Down
7 changes: 7 additions & 0 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ func (p *printer) resizeSlice(size string, s *Slice) {
p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

// resizeSliceNoNil will resize a slice and will not allow nil slices.
func (p *printer) resizeSliceNoNil(size string, s *Slice) {
p.printf("\nif %[1]s != nil && cap(%[1]s) >= int(%[2]s) {", s.Varname(), size)
p.printf("\n%[1]s = (%[1]s)[:%[2]s]", s.Varname(), size)
p.printf("\n} else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
}

func (p *printer) arrayCheck(want string, got string) {
p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want)
}
Expand Down
6 changes: 5 additions & 1 deletion gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ func (u *unmarshalGen) gSlice(s *Slice) {
sz := randIdent()
u.p.declare(sz, u32)
u.assignAndCheck(sz, arrayHeader)
u.p.resizeSlice(sz, s)
if s.AllowNil() {
u.p.resizeSliceNoNil(sz, s)
} else {
u.p.resizeSlice(sz, s)
}
u.p.rangeBlock(u.ctx, s.Index, s.Varname(), u, s.Els)
}

Expand Down

0 comments on commit dac711c

Please sign in to comment.