Skip to content

Commit

Permalink
Merge pull request #398 from immesys/immesys/unmarshal-prefix
Browse files Browse the repository at this point in the history
Add UnmarshalFirst
  • Loading branch information
fxamacker authored May 6, 2023
2 parents 3343ed2 + d3a8ced commit 21a6738
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
48 changes: 48 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,54 @@ func BenchmarkUnmarshal(b *testing.B) {
}
}

func BenchmarkUnmarshalFirst(b *testing.B) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if _, err := UnmarshalFirst(data, vPtr); err != nil {
b.Fatal("UnmarshalFirst:", err)
}
}
})
}
}
}

func BenchmarkUnmarshalFirstViaDecoder(b *testing.B) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if err := NewDecoder(bytes.NewReader(data)).Decode(vPtr); err != nil {
b.Fatal("UnmarshalDecoder:", err)
}
}
})
}
}
}

func BenchmarkDecode(b *testing.B) {
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
Expand Down
44 changes: 44 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,25 @@ import (
//
// Unmarshal supports CBOR tag 55799 (self-describe CBOR), tag 0 and 1 (time),
// and tag 2 and 3 (bignum).
//
// Unmarshal returns ExtraneousDataError error (without decoding into v)
// if there are any remaining bytes following the first valid CBOR data item.
// See UnmarshalFirst, if you want to unmarshal only the first
// CBOR data item without ExtraneousDataError caused by remaining bytes.
func Unmarshal(data []byte, v interface{}) error {
return defaultDecMode.Unmarshal(data, v)
}

// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
// using default decoding options. Any remaining bytes are returned in rest.
//
// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
//
// See the documentation for Unmarshal for details.
func UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
return defaultDecMode.UnmarshalFirst(data, v)
}

// Valid checks whether data is a well-formed encoded CBOR data item and
// that it complies with default restrictions such as MaxNestedLevels,
// MaxArrayElements, MaxMapPairs, etc.
Expand Down Expand Up @@ -604,6 +619,35 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
return d.value(v)
}

// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
// using dm decoding mode. Any remaining bytes are returned in rest.
//
// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
//
// See the documentation for Unmarshal for details.
func (dm *decMode) UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
d := decoder{data: data, dm: dm}

// check well-formedness.
off := d.off // Save offset before data validation
err = d.wellformed(true) // allow extra data after well-formed data item
d.off = off // Restore offset

// If it is well-formed, parse the value. This is structured like this to allow
// better test coverage
if err == nil {
err = d.value(v)
}

// If either wellformed or value returned an error, do not return rest bytes
if err != nil {
return nil, err
}

// Return the rest of the data slice (which might be len 0)
return d.data[d.off:], nil
}

// Valid checks whether data is a well-formed encoded CBOR data item and
// that it complies with configurable restrictions such as MaxNestedLevels,
// MaxArrayElements, MaxMapPairs, etc.
Expand Down
57 changes: 57 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5981,3 +5981,60 @@ func TestUnmarshalToDefaultMapType(t *testing.T) {
})
}
}

func TestUnmarshalFirstNoTrailing(t *testing.T) {
for _, tc := range unmarshalTests {
var v interface{}
if rest, err := UnmarshalFirst(tc.cborData, &v); err != nil {
t.Errorf("UnmarshalFirst(0x%x) returned error %v", tc.cborData, err)
} else {
if len(rest) != 0 {
t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want [])", tc.cborData, rest)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalfirstTrailing(t *testing.T) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, tc := range unmarshalTests {
data := make([]byte, 0, len(tc.cborData)+len(trailingData))
data = append(data, tc.cborData...)
data = append(data, trailingData...)
var v interface{}
if rest, err := UnmarshalFirst(data, &v); err != nil {
t.Errorf("UnmarshalFirst(0x%x) returned error %v", data, err)
} else {
if !bytes.Equal(trailingData, rest) {
t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want %x)", data, rest, trailingData)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalFirstInvalidItem(t *testing.T) {
// UnmarshalFirst should not return "rest" if the item was not well-formed
invalidCBOR := hexDecode("83FF20030102")
var v interface{}
rest, err := UnmarshalFirst(invalidCBOR, &v)
if rest != nil {
t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
}
}

0 comments on commit 21a6738

Please sign in to comment.