From 5f799179147509643b93f055dcc60e8a6cb080f8 Mon Sep 17 00:00:00 2001
From: Michael Andersen
Date: Fri, 28 Apr 2023 10:16:53 -0700
Subject: [PATCH] Add UnmarshalPrefix

Signed-off-by: Michael Andersen
---
 bench_test.go  | 49 ++++++++++++++++++++++++++++++++++++++++
 decode.go      | 43 +++++++++++++++++++++++++++++++++++
 decode_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 153 insertions(+)

diff --git a/bench_test.go b/bench_test.go
index 761bff91..4f9fd778 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -5,6 +5,7 @@ package cbor
 
 import (
 	"bytes"
+	"crypto/rand"
 	"io"
 	"reflect"
 	"testing"
@@ -211,6 +212,54 @@ func BenchmarkUnmarshal(b *testing.B) {
 	}
 }
 
+func BenchmarkUnmarshalPrefix(b *testing.B) {
+	trailingData := make([]byte, 12)
+	rand.Read(trailingData)
+	for _, bm := range decodeBenchmarks {
+		for _, t := range bm.decodeToTypes {
+			name := "CBOR " + bm.name + " to Go " + t.String()
+			if t.Kind() == reflect.Struct {
+				name = "CBOR " + bm.name + " to Go " + t.Kind().String()
+			}
+			data := make([]byte, 0, len(bm.cborData)+len(trailingData))
+			data = append(data, bm.cborData...)
+			data = append(data, trailingData...)
+			b.Run(name, func(b *testing.B) {
+				for i := 0; i < b.N; i++ {
+					vPtr := reflect.New(t).Interface()
+					if _, err := UnmarshalPrefix(data, vPtr); err != nil {
+						b.Fatal("UnmarshalPrefix:", err)
+					}
+				}
+			})
+		}
+	}
+}
+
+func BenchmarkUnmarshalPrefixViaDecoder(b *testing.B) {
+	trailingData := make([]byte, 12)
+	rand.Read(trailingData)
+	for _, bm := range decodeBenchmarks {
+		for _, t := range bm.decodeToTypes {
+			name := "CBOR " + bm.name + " to Go " + t.String()
+			if t.Kind() == reflect.Struct {
+				name = "CBOR " + bm.name + " to Go " + t.Kind().String()
+			}
+			data := make([]byte, 0, len(bm.cborData)+len(trailingData))
+			data = append(data, bm.cborData...)
+			data = append(data, trailingData...)
+			b.Run(name, func(b *testing.B) {
+				for i := 0; i < b.N; i++ {
+					vPtr := reflect.New(t).Interface()
+					if err := NewDecoder(bytes.NewReader(data)).Decode(vPtr); err != nil {
+						b.Fatal("UnmarshalDecoder:", err)
+					}
+				}
+			})
+		}
+	}
+}
+
 func BenchmarkDecode(b *testing.B) {
 	for _, bm := range decodeBenchmarks {
 		for _, t := range bm.decodeToTypes {
diff --git a/decode.go b/decode.go
index 343da9ea..95ca66e7 100644
--- a/decode.go
+++ b/decode.go
@@ -95,10 +95,23 @@ import (
 //
 // Unmarshal supports CBOR tag 55799 (self-describe CBOR), tag 0 and 1 (time),
 // and tag 2 and 3 (bignum).
+//
+// Unmarshal returns an ExtraneousDataError if, upon decoding a valid
+// message, there are remaining bytes at the end of the slice. See UnmarshalPrefix
+// if you expect data to contain more than one CBOR-encoded item.
 func Unmarshal(data []byte, v interface{}) error {
 	return defaultDecMode.Unmarshal(data, v)
 }
 
+// UnmarshalPrefix will unmarshal the first CBOR-encoded item found in data, in the
+// same way as Unmarshal. Any remaining bytes after a single valid item
+// has been parsed will be returned in rest.
+//
+// See Unmarshal for more information.
+func UnmarshalPrefix(data []byte, v interface{}) (rest []byte, err error) {
+	return defaultDecMode.UnmarshalPrefix(data, v)
+}
+
 // Valid checks whether the CBOR data is complete and well-formed.
 func Valid(data []byte) error {
 	return defaultDecMode.Valid(data)
@@ -561,6 +574,36 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
 	return d.value(v)
 }
 
+// UnmarshalPrefix parses the first CBOR-encoded item in data into the value
+// pointed to by v using dm decoding mode. If v is nil, not a pointer, or a nil pointer,
+// UnmarshalPrefix returns an error. Any remaining bytes after a single valid item
+// has been parsed will be returned in rest. rest is a subslice of data (no copy
+// is made); it is nil whenever a non-nil error is returned.
+//
+// See the documentation for UnmarshalPrefix for details.
+func (dm *decMode) UnmarshalPrefix(data []byte, v interface{}) (rest []byte, err error) {
+	d := decoder{data: data, dm: dm}
+
+	// check valid
+	off := d.off        // Save offset before data validation
+	err = d.valid(true) // allow extra data after valid data item
+	d.off = off         // Restore offset
+
+	// If it is valid, parse the value. This is structured like this to allow
+	// better test coverage
+	if err == nil {
+		err = d.value(v)
+	}
+
+	// If either valid or value returned an error, do not return rest bytes
+	if err != nil {
+		return nil, err
+	}
+
+	// Return the rest of the data slice (which might be len 0)
+	return data[d.off:], nil
+}
+
 // Valid checks whether the CBOR data is complete and well-formed.
 func (dm *decMode) Valid(data []byte) error {
 	d := decoder{data: data, dm: dm}
diff --git a/decode_test.go b/decode_test.go
index 0e585d3e..18d7615e 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -5,6 +5,7 @@ package cbor
 
 import (
 	"bytes"
+	"crypto/rand"
 	"encoding/binary"
 	"encoding/hex"
 	"errors"
@@ -5981,3 +5982,63 @@ func TestUnmarshalToDefaultMapType(t *testing.T) {
 		})
 	}
 }
+
+func TestUnmarshalPrefixNoTrailing(t *testing.T) {
+	for _, tc := range unmarshalTests {
+		var v interface{}
+		if rest, err := UnmarshalPrefix(tc.cborData, &v); err != nil {
+			t.Errorf("UnmarshalPrefix(0x%x) returned error %v", tc.cborData, err)
+		} else {
+			if len(rest) != 0 {
+				t.Errorf("UnmarshalPrefix(0x%x) returned rest %x (want [])", tc.cborData, rest)
+			}
+			// Check the value as well, although this is covered by other tests
+			if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
+				if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
+					t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+				}
+			} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
+				t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+			}
+		}
+	}
+}
+
+func TestUnmarshalPrefixTrailing(t *testing.T) {
+	trailingData := make([]byte, 12)
+	rand.Read(trailingData)
+	for _, tc := range unmarshalTests {
+		data := make([]byte, 0, len(tc.cborData)+len(trailingData))
+		data = append(data, tc.cborData...)
+		data = append(data, trailingData...)
+		var v interface{}
+		if rest, err := UnmarshalPrefix(data, &v); err != nil {
+			t.Errorf("UnmarshalPrefix(0x%x) returned error %v", data, err)
+		} else {
+			if !bytes.Equal(trailingData, rest) {
+				t.Errorf("UnmarshalPrefix(0x%x) returned rest %x (want %x)", data, rest, trailingData)
+			}
+			// Check the value as well, although this is covered by other tests
+			if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
+				if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
+					t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+				}
+			} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
+				t.Errorf("UnmarshalPrefix(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+			}
+		}
+	}
+}
+
+func TestUnmarshalPrefixInvalidItem(t *testing.T) {
+	// UnmarshalPrefix should not return "rest" if the item was not well-formed
+	invalidCBOR := hexDecode("83FF20030102")
+	var v interface{}
+	rest, err := UnmarshalPrefix(invalidCBOR, &v)
+	if err == nil {
+		t.Errorf("UnmarshalPrefix(0x%x) returned nil error, want non-nil error", invalidCBOR)
+	}
+	if rest != nil {
+		t.Errorf("UnmarshalPrefix(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
+	}
+}