diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index 51984d141a1d..7baa37159981 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -23,6 +23,7 @@ const ( minVarIntLen = 1 boolLen = 1 idLen = hashing.HashLen + minCodecVersionLen = minVarIntLen minSerializedPathLen = minVarIntLen minByteSliceLen = minVarIntLen minDeletedKeyLen = minByteSliceLen @@ -30,11 +31,11 @@ const ( minProofPathLen = minVarIntLen minKeyValueLen = 2 * minByteSliceLen minProofNodeLen = minSerializedPathLen + minMaybeByteSliceLen + minVarIntLen - minProofLen = minProofPathLen + minByteSliceLen - minChangeProofLen = boolLen + 2*minProofPathLen + 2*minVarIntLen - minRangeProofLen = 2*minProofPathLen + minVarIntLen - minDBNodeLen = minMaybeByteSliceLen + minVarIntLen - minHashValuesLen = minVarIntLen + minMaybeByteSliceLen + minSerializedPathLen + minProofLen = minCodecVersionLen + minProofPathLen + minByteSliceLen + minChangeProofLen = minCodecVersionLen + +boolLen + 2*minProofPathLen + 2*minVarIntLen + minRangeProofLen = minCodecVersionLen + +2*minProofPathLen + minVarIntLen + minDBNodeLen = minCodecVersionLen + minMaybeByteSliceLen + minVarIntLen + minHashValuesLen = minCodecVersionLen + minVarIntLen + minMaybeByteSliceLen + minSerializedPathLen minProofNodeChildLen = minVarIntLen + idLen minChildLen = minVarIntLen + minSerializedPathLen + idLen ) @@ -60,6 +61,7 @@ var ( errNonZeroNibblePadding = errors.New("nibbles should be padded with 0s") errExtraSpace = errors.New("trailing buffer space") errNegativeSliceLength = errors.New("negative slice length") + errInvalidCodecVersion = errors.New("invalid codec version") ) // EncoderDecoder defines the interface needed by merkleDB to marshal @@ -69,7 +71,6 @@ type EncoderDecoder interface { Decoder } -// TODO actually encode the version and remove version from the interface type Encoder interface { EncodeProof(version uint16, p *Proof) ([]byte, error) EncodeChangeProof(version uint16, p *ChangeProof) ([]byte, error) @@ -107,10 +108,13 @@ func (c *codecImpl) EncodeProof(version uint16, proof *Proof) ([]byte, error) { } if version != codecVersion { - return nil, errUnknownVersion + return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) } buf := &bytes.Buffer{} + if err := c.encodeInt(buf, int(version)); err != nil { + return nil, err + } if err := c.encodeProofPath(buf, proof.Path); err != nil { return nil, err } @@ -129,14 +133,17 @@ func (c *codecImpl) EncodeChangeProof(version uint16, proof *ChangeProof) ([]byt } if version != codecVersion { - return nil, errUnknownVersion + return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) } buf := &bytes.Buffer{} + + if err := c.encodeInt(buf, int(version)); err != nil { + return nil, err + } if err := c.encodeBool(buf, proof.HadRootsInHistory); err != nil { return nil, err } - if err := c.encodeProofPath(buf, proof.StartProof); err != nil { return nil, err } @@ -169,10 +176,13 @@ func (c *codecImpl) EncodeRangeProof(version uint16, proof *RangeProof) ([]byte, } if version != codecVersion { - return nil, errUnknownVersion + return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) } buf := &bytes.Buffer{} + if err := c.encodeInt(buf, int(version)); err != nil { + return nil, err + } if err := c.encodeProofPath(buf, proof.StartProof); err != nil { return nil, err } @@ -197,10 +207,13 @@ func (c *codecImpl) encodeDBNode(version uint16, n *dbNode) ([]byte, error) { } if version != codecVersion { - return nil, errUnknownVersion + return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) } buf := &bytes.Buffer{} + if err := c.encodeInt(buf, int(version)); err != nil { + return nil, err + } if err := c.encodeMaybeByteSlice(buf, n.value); err != nil { return nil, err } @@ -231,10 +244,15 @@ func (c *codecImpl) encodeHashValues(version uint16, hv *hashValues) ([]byte, er } if version != codecVersion { - return nil, errUnknownVersion + return nil, fmt.Errorf("%w: %d", errUnknownVersion, version) } buf := &bytes.Buffer{} + + if err := c.encodeInt(buf, int(version)); err != nil { + return nil, err + } + length := len(hv.Children) if err := c.encodeInt(buf, length); err != nil { return nil, err @@ -273,7 +291,13 @@ func (c *codecImpl) DecodeProof(b []byte, proof *Proof) (uint16, error) { err error src = bytes.NewReader(b) ) - + gotCodecVersion, err := c.decodeInt(src) + if err != nil { + return 0, err + } + if codecVersion != gotCodecVersion { + return 0, fmt.Errorf("%w: %d", errInvalidCodecVersion, gotCodecVersion) + } if proof.Path, err = c.decodeProofPath(src); err != nil { return 0, err } @@ -302,6 +326,13 @@ func (c *codecImpl) DecodeChangeProof(b []byte, proof *ChangeProof) (uint16, err err error ) + gotCodecVersion, err := c.decodeInt(src) + if err != nil { + return 0, err + } + if gotCodecVersion != codecVersion { + return 0, fmt.Errorf("%w: %d", errInvalidCodecVersion, gotCodecVersion) + } if proof.HadRootsInHistory, err = c.decodeBool(src); err != nil { return 0, err } @@ -363,7 +394,13 @@ func (c *codecImpl) DecodeRangeProof(b []byte, proof *RangeProof) (uint16, error src = bytes.NewReader(b) err error ) - + gotCodecVersion, err := c.decodeInt(src) + if err != nil { + return 0, err + } + if codecVersion != gotCodecVersion { + return 0, fmt.Errorf("%w: %d", errInvalidCodecVersion, gotCodecVersion) + } if proof.StartProof, err = c.decodeProofPath(src); err != nil { return 0, err } @@ -406,6 +443,14 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) (uint16, error) { err error ) + gotCodecVersion, err := c.decodeInt(src) + if err != nil { + return 0, err + } + if codecVersion != gotCodecVersion { + return 0, fmt.Errorf("%w: %d", errInvalidCodecVersion, gotCodecVersion) + } + if n.value, err = c.decodeMaybeByteSlice(src); err != nil { return 0, err }