Skip to content

Commit

Permalink
Add UnmarshalPrefix
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Andersen <[email protected]>
  • Loading branch information
immesys committed Apr 28, 2023
1 parent 6dfffb1 commit 5f79917
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
49 changes: 49 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"bytes"
"crypto/rand"
"io"
"reflect"
"testing"
Expand Down Expand Up @@ -211,6 +212,54 @@ func BenchmarkUnmarshal(b *testing.B) {
}
}

func BenchmarkUnmarshalPrefix(b *testing.B) {
trailingData := make([]byte, 12)
rand.Read(trailingData)
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if _, err := UnmarshalPrefix(data, vPtr); err != nil {
b.Fatal("UnmarshalPrefix:", err)
}
}
})
}
}
}

func BenchmarkUnmarshalPrefixViaDecoder(b *testing.B) {
trailingData := make([]byte, 12)
rand.Read(trailingData)
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if err := NewDecoder(bytes.NewReader(data)).Decode(vPtr); err != nil {
b.Fatal("UnmarshalDecoder:", err)
}
}
})
}
}
}

func BenchmarkDecode(b *testing.B) {
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
Expand Down
42 changes: 42 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,23 @@ import (
//
// Unmarshal supports CBOR tag 55799 (self-describe CBOR), tag 0 and 1 (time),
// and tag 2 and 3 (bignum).
//
// Unmarshal will return an ExtraneousDataError error if, upon decoding a valid
// message, there are remaining bytes at the end of the slice. See UnmarshalPrefix
// if you expect data to contain more than one CBOR-encoded item.
func Unmarshal(data []byte, v interface{}) error {
return defaultDecMode.Unmarshal(data, v)
}

// UnmarshalPrefix will unmarshal the first CBOR-encoded item found in data, in the
// same way as Unmarshal. Any remaining bytes after a single valid item
// has been parsed will be returned in rest.
//
// See Unmarshal for more information
func UnmarshalPrefix(data []byte, v interface{}) (rest []byte, err error) {
return defaultDecMode.UnmarshalPrefix(data, v)
}

// Valid checks whether the CBOR data is complete and well-formed.
func Valid(data []byte) error {
return defaultDecMode.Valid(data)
Expand Down Expand Up @@ -561,6 +574,35 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
return d.value(v)
}

// UnmarshalPrefix parses the first CBOR-encoded item in data into the value
// pointed to by v using dm decoding mode. If v is nil, not a pointer, or a nil pointer,
// UnmarshalPrefix returns an error. Any remaining bytes after a single valid item
// has been parsed will be returned in rest.
//
// See the documentation for UnmarshalPrefix for details.
func (dm *decMode) UnmarshalPrefix(data []byte, v interface{}) (rest []byte, err error) {
d := decoder{data: data, dm: dm}

// check valid
off := d.off // Save offset before data validation
err = d.valid(true) // allow extra data after valid data item
d.off = off // Restore offset

// If it is valid, parse the value. This is structured like this to allow
// better test coverage
if err == nil {
err = d.value(v)
}

// If either valid or value returned an error, do not return rest bytes
if err != nil {
return nil, err
}

// Return the rest of the data slice (which might be len 0)
return data[d.off:], nil
}

// Valid checks whether the CBOR data is complete and well-formed.
func (dm *decMode) Valid(data []byte) error {
d := decoder{data: data, dm: dm}
Expand Down
58 changes: 58 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
Expand Down Expand Up @@ -5981,3 +5982,60 @@ func TestUnmarshalToDefaultMapType(t *testing.T) {
})
}
}

func TestUnmarshalPrefixNoTrailing(t *testing.T) {
for _, tc := range unmarshalTests {
var v interface{}
if rest, err := UnmarshalPrefix(tc.cborData, &v); err != nil {
t.Errorf("UnmarshalPrefix(0x%x) returned error %v", tc.cborData, err)
} else {
if len(rest) != 0 {
t.Errorf("UnmarshalPrefix(0x%x) returned rest %x (want [])", tc.cborData, rest)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("Unmarshal(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalPrefixTrailing(t *testing.T) {
trailingData := make([]byte, 12)
rand.Read(trailingData)
for _, tc := range unmarshalTests {
data := make([]byte, 0, len(tc.cborData)+len(trailingData))
data = append(data, tc.cborData...)
data = append(data, trailingData...)
var v interface{}
if rest, err := UnmarshalPrefix(data, &v); err != nil {
t.Errorf("UnmarshalPrefix(0x%x) returned error %v", data, err)
} else {
if !bytes.Equal(trailingData, rest) {
t.Errorf("UnmarshalPrefix(0x%x) returned rest %x (want %x)", data, rest, trailingData)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalPrefixInvalidItem(t *testing.T) {
// UnmarshalPrefix should not return "rest" if the item was not well-formed
invalidCBOR := hexDecode("83FF20030102")
var v interface{}
rest, err := UnmarshalPrefix(invalidCBOR, &v)
if rest != nil {
t.Errorf("UnmarshalPrefix(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
}
}

0 comments on commit 5f79917

Please sign in to comment.