diff --git a/codec/codec.go b/codec/codec.go index 6ee79966718..8d6ab445c61 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -14,7 +14,6 @@ var ( ErrMaxSliceLenExceeded = errors.New("max slice length exceeded") ErrDoesNotImplementInterface = errors.New("does not implement interface") ErrUnexportedField = errors.New("unexported field") - ErrExtraSpace = errors.New("trailing buffer space") ErrMarshalZeroLength = errors.New("can't marshal zero length value") ErrUnmarshalZeroLength = errors.New("can't unmarshal zero length value") ) @@ -22,7 +21,7 @@ var ( // Codec marshals and unmarshals type Codec interface { MarshalInto(interface{}, *wrappers.Packer) error - Unmarshal([]byte, interface{}) error + UnmarshalFrom(*wrappers.Packer, interface{}) error // Returns the size, in bytes, of [value] when it's marshaled Size(value interface{}) (int, error) diff --git a/codec/codectest/codectest.go b/codec/codectest/codectest.go index 528782fe913..95e4ba473ec 100644 --- a/codec/codectest/codectest.go +++ b/codec/codectest/codectest.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/utils/wrappers" + codecpkg "github.com/ava-labs/avalanchego/codec" ) @@ -75,6 +77,7 @@ var ( {"Slice Length Overflow", TestSliceLengthOverflow}, {"Map", TestMap}, {"Can Marshal Large Slices", TestCanMarshalLargeSlices}, + {"Implements UnmarshalFrom", TestImplementsUnmarshalFrom}, } MultipleTagsTests = []NamedTest{ @@ -1153,3 +1156,34 @@ func FuzzStructUnmarshal(codec codecpkg.GeneralCodec, f *testing.F) { require.Len(bytes, size) }) } + +func TestImplementsUnmarshalFrom(t testing.TB, codec codecpkg.GeneralCodec) { + require := require.New(t) + + p := wrappers.Packer{MaxSize: 1024} + p.PackFixedBytes([]byte{0, 1, 2}) // pack 3 extra bytes prefix + + mySlice := []bool{true, false, true, true} + + require.NoError(codec.MarshalInto(mySlice, &p)) + + p.PackFixedBytes([]byte{7, 7, 7}) // pack 3 extra bytes suffix + + bytesLen, err := codec.Size(mySlice) + require.NoError(err) + require.Equal(3+bytesLen+3, p.Offset) + + p = wrappers.Packer{Bytes: p.Bytes, MaxSize: p.MaxSize, Offset: 3} + + var sliceUnmarshaled []bool + require.NoError(codec.UnmarshalFrom(&p, &sliceUnmarshaled)) + require.Equal(mySlice, sliceUnmarshaled) + require.Equal( + wrappers.Packer{ + Bytes: p.Bytes, + MaxSize: p.MaxSize, + Offset: 11, + }, + p, + ) +} diff --git a/codec/manager.go b/codec/manager.go index 00de14e5c9a..8230a5430c9 100644 --- a/codec/manager.go +++ b/codec/manager.go @@ -32,6 +32,7 @@ var ( ErrCantPackVersion = errors.New("couldn't pack codec version") ErrCantUnpackVersion = errors.New("couldn't unpack codec version") ErrDuplicatedVersion = errors.New("duplicated codec version") + ErrExtraSpace = errors.New("trailing buffer space") ) var _ Manager = (*manager)(nil) @@ -157,5 +158,17 @@ func (m *manager) Unmarshal(bytes []byte, dest interface{}) (uint16, error) { if !exists { return version, ErrUnknownVersion } - return version, c.Unmarshal(p.Bytes[p.Offset:], dest) + + if err := c.UnmarshalFrom(&p, dest); err != nil { + return version, err + } + if p.Offset != len(bytes) { + return version, fmt.Errorf("%w: read %d provided %d", + ErrExtraSpace, + p.Offset, + len(bytes), + ) + } + + return version, nil } diff --git a/codec/reflectcodec/type_codec.go b/codec/reflectcodec/type_codec.go index 901ff2bc906..649b691810b 100644 --- a/codec/reflectcodec/type_codec.go +++ b/codec/reflectcodec/type_codec.go @@ -496,31 +496,18 @@ func (c *genericCodec) marshal( } } -// Unmarshal unmarshals [bytes] into [dest], where [dest] must be a pointer or +// UnmarshalFrom unmarshals [p.Bytes] into [dest], where [dest] must be a pointer or // interface -func (c *genericCodec) Unmarshal(bytes []byte, dest interface{}) error { +func (c *genericCodec) UnmarshalFrom(p *wrappers.Packer, dest interface{}) error { if dest == nil { return codec.ErrUnmarshalNil } - p := wrappers.Packer{ - Bytes: bytes, - } destPtr := reflect.ValueOf(dest) if destPtr.Kind() != reflect.Ptr { return errNeedPointer } - if err := c.unmarshal(&p, destPtr.Elem(), nil /*=typeStack*/); err != nil { - return err - } - if p.Offset != len(bytes) { - return fmt.Errorf("%w: read %d provided %d", - codec.ErrExtraSpace, - p.Offset, - len(bytes), - ) - } - return nil + return c.unmarshal(p, destPtr.Elem(), nil /*=typeStack*/) } // Unmarshal from p.Bytes into [value]. [value] must be addressable.