From f80cb929b2866900f9fcbe73b0260319d0e85e6e Mon Sep 17 00:00:00 2001 From: Dan Laine Date: Thu, 14 Dec 2023 09:13:24 -0500 Subject: [PATCH] `merkledb` -- dynamic root (#2177) Signed-off-by: Dan Laine Co-authored-by: dboehm-avalabs --- scripts/mocks.mockgen.txt | 1 - x/merkledb/codec.go | 32 ++- x/merkledb/codec_test.go | 16 +- x/merkledb/db.go | 132 ++++++----- x/merkledb/db_test.go | 53 ++++- x/merkledb/history.go | 20 +- x/merkledb/history_test.go | 23 +- x/merkledb/metrics_test.go | 6 +- x/merkledb/mock_db.go | 406 +++++++++++++++++----------------- x/merkledb/node.go | 5 + x/merkledb/proof.go | 27 +-- x/merkledb/proof_test.go | 152 ++++++++++--- x/merkledb/trie.go | 10 +- x/merkledb/trie_test.go | 167 ++++++++------ x/merkledb/trieview.go | 245 ++++++++++---------- x/sync/client_test.go | 23 +- x/sync/g_db/db_client.go | 8 + x/sync/manager.go | 23 ++ x/sync/network_server.go | 4 + x/sync/network_server_test.go | 20 +- x/sync/sync_test.go | 29 ++- 21 files changed, 842 insertions(+), 560 deletions(-) diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 3ec7849d6e0b..714c3b5932e4 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -43,5 +43,4 @@ github.com/ava-labs/avalanchego/vms/registry=VMGetter=vms/registry/mock_vm_gette github.com/ava-labs/avalanchego/vms/registry=VMRegisterer=vms/registry/mock_vm_registerer.go github.com/ava-labs/avalanchego/vms/registry=VMRegistry=vms/registry/mock_vm_registry.go github.com/ava-labs/avalanchego/vms=Factory,Manager=vms/mock_manager.go -github.com/ava-labs/avalanchego/x/merkledb=MerkleDB=x/merkledb/mock_db.go github.com/ava-labs/avalanchego/x/sync=Client=x/sync/mock_client.go diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go index 973dd5888ab6..c14534d9cead 100644 --- a/x/merkledb/codec.go +++ b/x/merkledb/codec.go @@ -66,11 +66,13 @@ type encoder interface { // Returns the bytes that will be hashed to generate [n]'s ID. // Assumes [n] is non-nil. encodeHashValues(n *node) []byte + encodeKey(key Key) []byte } type decoder interface { // Assumes [n] is non-nil. decodeDBNode(bytes []byte, n *dbNode) error + decodeKey(bytes []byte) (Key, error) } func newCodec() encoderDecoder { @@ -98,7 +100,6 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte { estimatedLen = estimatedValueLen + minVarIntLen + estimatedNodeChildLen*numChildren buf = bytes.NewBuffer(make([]byte, 0, estimatedLen)) ) - c.encodeMaybeByteSlice(buf, n.value) c.encodeUint(buf, uint64(numChildren)) // Note we insert children in order of increasing index @@ -108,7 +109,7 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte { for _, index := range keys { entry := n.children[index] c.encodeUint(buf, uint64(index)) - c.encodeKey(buf, entry.compressedKey) + c.encodeKeyToBuffer(buf, entry.compressedKey) _, _ = buf.Write(entry.id[:]) c.encodeBool(buf, entry.hasValue) } @@ -134,7 +135,7 @@ func (c *codecImpl) encodeHashValues(n *node) []byte { _, _ = buf.Write(entry.id[:]) } c.encodeMaybeByteSlice(buf, n.valueDigest) - c.encodeKey(buf, n.key) + c.encodeKeyToBuffer(buf, n.key) return buf.Bytes() } @@ -172,7 +173,7 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error { } previousChild = index - compressedKey, err := c.decodeKey(src) + compressedKey, err := c.decodeKeyFromReader(src) if err != nil { return err } @@ -330,12 +331,31 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) { return id, err } -func (c *codecImpl) encodeKey(dst *bytes.Buffer, key Key) { +func (c *codecImpl) encodeKey(key Key) []byte { + estimatedLen := binary.MaxVarintLen64 + len(key.Bytes()) + dst := bytes.NewBuffer(make([]byte, 0, estimatedLen)) + c.encodeKeyToBuffer(dst, key) + return dst.Bytes() +} + +func (c *codecImpl) encodeKeyToBuffer(dst *bytes.Buffer, key Key) { c.encodeUint(dst, uint64(key.length)) _, _ = dst.Write(key.Bytes()) } -func (c *codecImpl) decodeKey(src *bytes.Reader) (Key, error) { +func (c *codecImpl) decodeKey(b []byte) (Key, error) { + src := bytes.NewReader(b) + key, err := c.decodeKeyFromReader(src) + if err != nil { + return Key{}, err + } + if src.Len() != 0 { + return Key{}, errExtraSpace + } + return key, err +} + +func (c *codecImpl) decodeKeyFromReader(src *bytes.Reader) (Key, error) { if minKeyLen > src.Len() { return Key{}, io.ErrUnexpectedEOF } diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go index 699db9a4bd81..1f463ca50858 100644 --- a/x/merkledb/codec_test.go +++ b/x/merkledb/codec_test.go @@ -81,21 +81,14 @@ func FuzzCodecKey(f *testing.F) { ) { require := require.New(t) codec := codec.(*codecImpl) - reader := bytes.NewReader(b) - startLen := reader.Len() - got, err := codec.decodeKey(reader) + got, err := codec.decodeKey(b) if err != nil { t.SkipNow() } - endLen := reader.Len() - numRead := startLen - endLen // Encoding [got] should be the same as [b]. - var buf bytes.Buffer - codec.encodeKey(&buf, got) - bufBytes := buf.Bytes() - require.Len(bufBytes, numRead) - require.Equal(b[:numRead], bufBytes) + gotBytes := codec.encodeKey(got) + require.Equal(b, gotBytes) }, ) } @@ -248,7 +241,6 @@ func FuzzEncodeHashValues(f *testing.F) { func TestCodecDecodeKeyLengthOverflowRegression(t *testing.T) { codec := codec.(*codecImpl) - bytes := bytes.NewReader(binary.AppendUvarint(nil, math.MaxInt)) - _, err := codec.decodeKey(bytes) + _, err := codec.decodeKey(binary.AppendUvarint(nil, math.MaxInt)) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } diff --git a/x/merkledb/db.go b/x/merkledb/db.go index bcc8a2a803f2..823f754d2f3c 100644 --- a/x/merkledb/db.go +++ b/x/merkledb/db.go @@ -51,11 +51,11 @@ var ( intermediateNodePrefix = []byte{2} cleanShutdownKey = []byte(string(metadataPrefix) + "cleanShutdown") + rootDBKey = []byte(string(metadataPrefix) + "root") hadCleanShutdown = []byte{1} didNotHaveCleanShutdown = []byte{0} - errSameRoot = errors.New("start and end root are the same") - errNoNewSentinel = errors.New("there was no updated sentinel node in change list") + errSameRoot = errors.New("start and end root are the same") ) type ChangeProofer interface { @@ -64,6 +64,9 @@ type ChangeProofer interface { // Returns at most [maxLength] key/value pairs. // Returns [ErrInsufficientHistory] if this node has insufficient history // to generate the proof. + // Returns ErrEmptyProof if [endRootID] is ids.Empty. + // Note that [endRootID] == ids.Empty means the trie is empty + // (i.e. we don't need a change proof.) // Returns [ErrNoEndRoot], which wraps [ErrInsufficientHistory], if the // history doesn't contain the [endRootID]. GetChangeProof( @@ -102,6 +105,9 @@ type RangeProofer interface { // [start, end] when the root of the trie was [rootID]. // If [start] is Nothing, there's no lower bound on the range. // If [end] is Nothing, there's no upper bound on the range. + // Returns ErrEmptyProof if [rootID] is ids.Empty. + // Note that [rootID] == ids.Empty means the trie is empty + // (i.e. we don't need a range proof.) GetRangeProofAtRoot( ctx context.Context, rootID ids.ID, @@ -203,11 +209,11 @@ type merkleDB struct { debugTracer trace.Tracer infoTracer trace.Tracer - // The sentinel node of this trie. - // It is the node with a nil key and is the ancestor of all nodes in the trie. - // If it has a value or has multiple children, it is also the root of the trie. - sentinelNode *node - rootID ids.ID + // The root of this trie. + // Nothing if the trie is empty. + root maybe.Maybe[*node] + + rootID ids.ID // Valid children of this trie. childViews []*trieView @@ -270,6 +276,9 @@ func newDatabase( // add current root to history (has no changes) trieDB.history.record(&changeSummary{ rootID: trieDB.rootID, + rootChange: change[maybe.Maybe[*node]]{ + after: trieDB.root, + }, values: map[Key]*change[maybe.Maybe[[]byte]]{}, nodes: map[Key]*change[*node]{}, }) @@ -297,7 +306,8 @@ func newDatabase( // Deletes every intermediate node and rebuilds them by re-adding every key/value. // TODO: make this more efficient by only clearing out the stale portions of the trie. func (db *merkleDB) rebuild(ctx context.Context, cacheSize int) error { - db.sentinelNode = newNode(Key{}) + db.root = maybe.Nothing[*node]() + db.rootID = ids.Empty // Delete intermediate nodes. if err := database.ClearPrefix(db.baseDB, intermediateNodePrefix, rebuildIntermediateDeletionWriteSize); err != nil { @@ -591,13 +601,6 @@ func (db *merkleDB) getMerkleRoot() ids.ID { return db.rootID } -// isSentinelNodeTheRoot returns true if the passed in sentinel node has a value and or multiple child nodes -// When this is true, the root of the trie is the sentinel node -// When this is false, the root of the trie is the sentinel node's single child -func isSentinelNodeTheRoot(sentinel *node) bool { - return sentinel.valueDigest.HasValue() || len(sentinel.children) != 1 -} - func (db *merkleDB) GetProof(ctx context.Context, key []byte) (*Proof, error) { db.commitLock.RLock() defer db.commitLock.RUnlock() @@ -606,7 +609,6 @@ func (db *merkleDB) GetProof(ctx context.Context, key []byte) (*Proof, error) { } // Assumes [db.commitLock] is read locked. -// Assumes [db.lock] is not held func (db *merkleDB) getProof(ctx context.Context, key []byte) (*Proof, error) { if db.closed { return nil, database.ErrClosed @@ -654,11 +656,13 @@ func (db *merkleDB) getRangeProofAtRoot( end maybe.Maybe[[]byte], maxLength int, ) (*RangeProof, error) { - if db.closed { + switch { + case db.closed: return nil, database.ErrClosed - } - if maxLength <= 0 { + case maxLength <= 0: return nil, fmt.Errorf("%w but was %d", ErrInvalidMaxLength, maxLength) + case rootID == ids.Empty: + return nil, ErrEmptyProof } historicalView, err := db.getHistoricalViewForRange(rootID, start, end) @@ -676,11 +680,13 @@ func (db *merkleDB) GetChangeProof( end maybe.Maybe[[]byte], maxLength int, ) (*ChangeProof, error) { - if start.HasValue() && end.HasValue() && bytes.Compare(start.Value(), end.Value()) == 1 { + switch { + case start.HasValue() && end.HasValue() && bytes.Compare(start.Value(), end.Value()) == 1: return nil, ErrStartAfterEnd - } - if startRootID == endRootID { + case startRootID == endRootID: return nil, errSameRoot + case endRootID == ids.Empty: + return nil, ErrEmptyProof } db.commitLock.RLock() @@ -941,13 +947,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e return nil } - sentinelChange, ok := changes.nodes[Key{}] - if !ok { - return errNoNewSentinel - } - currentValueNodeBatch := db.valueNodeDB.NewBatch() - _, nodesSpan := db.infoTracer.Start(ctx, "MerkleDB.commitChanges.writeNodes") for key, nodeChange := range changes.nodes { shouldAddIntermediate := nodeChange.after != nil && !nodeChange.after.hasValue() @@ -983,12 +983,18 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e return err } - // Only modify in-memory state after the commit succeeds - // so that we don't need to clean up on error. - db.sentinelNode = sentinelChange.after - db.rootID = changes.rootID db.history.record(changes) - return nil + + // Update root in database. + db.root = changes.rootChange.after + db.rootID = changes.rootID + + if db.root.IsNothing() { + return db.baseDB.Delete(rootDBKey) + } + + rootKey := codec.encodeKey(db.root.Value().key) + return db.baseDB.Put(rootDBKey, rootKey) } // moveChildViewsToDB removes any child views from the trieToCommit and moves them to the db @@ -1024,7 +1030,7 @@ func (db *merkleDB) VerifyChangeProof( case start.HasValue() && end.HasValue() && bytes.Compare(start.Value(), end.Value()) > 0: return ErrStartAfterEnd case proof.Empty(): - return ErrNoMerkleProof + return ErrEmptyProof case end.HasValue() && len(proof.KeyChanges) == 0 && len(proof.EndProof) == 0: // We requested an end proof but didn't get one. return ErrNoEndProof @@ -1166,37 +1172,41 @@ func (db *merkleDB) invalidateChildrenExcept(exception *trieView) { } } +// If the root is on disk, set [db.root] to it. +// Otherwise leave [db.root] as Nothing. func (db *merkleDB) initializeRoot() error { - // Not sure if the sentinel node exists or if it had a value, - // so check under both prefixes - var err error - db.sentinelNode, err = db.intermediateNodeDB.Get(Key{}) + rootKeyBytes, err := db.baseDB.Get(rootDBKey) + if err != nil { + if !errors.Is(err, database.ErrNotFound) { + return err + } + // Root isn't on disk. + return nil + } - if errors.Is(err, database.ErrNotFound) { - // Didn't find the sentinel in the intermediateNodeDB, check the valueNodeDB - db.sentinelNode, err = db.valueNodeDB.Get(Key{}) + // Root is on disk. + rootKey, err := codec.decodeKey(rootKeyBytes) + if err != nil { + return err } + // First, see if root is an intermediate node. + var root *node + root, err = db.getEditableNode(rootKey, false /* hasValue */) if err != nil { if !errors.Is(err, database.ErrNotFound) { return err } - // Sentinel node doesn't exist in either database prefix. - // Make a new one and store it in the intermediateNodeDB - db.sentinelNode = newNode(Key{}) - if err := db.intermediateNodeDB.Put(Key{}, db.sentinelNode); err != nil { + // The root must be a value node. + root, err = db.getEditableNode(rootKey, true /* hasValue */) + if err != nil { return err } } - db.rootID = db.sentinelNode.calculateID(db.metrics) - if !isSentinelNodeTheRoot(db.sentinelNode) { - // If the sentinel node is not the root, the trie's root is the sentinel node's only child - for _, childEntry := range db.sentinelNode.children { - db.rootID = childEntry.id - } - } + db.rootID = root.calculateID(db.metrics) + db.root = maybe.Some(root) return nil } @@ -1204,7 +1214,6 @@ func (db *merkleDB) initializeRoot() error { // If [start] is Nothing, there's no lower bound on the range. // If [end] is Nothing, there's no upper bound on the range. // Assumes [db.commitLock] is read locked. -// Assumes [db.lock] isn't held. func (db *merkleDB) getHistoricalViewForRange( rootID ids.ID, start maybe.Maybe[[]byte], @@ -1273,12 +1282,17 @@ func (db *merkleDB) getNode(key Key, hasValue bool) (*node, error) { switch { case db.closed: return nil, database.ErrClosed - case key == Key{}: - return db.sentinelNode, nil + case db.root.HasValue() && key == db.root.Value().key: + return db.root.Value(), nil case hasValue: return db.valueNodeDB.Get(key) + default: + return db.intermediateNodeDB.Get(key) } - return db.intermediateNodeDB.Get(key) +} + +func (db *merkleDB) getRoot() maybe.Maybe[*node] { + return db.root } func (db *merkleDB) Clear() error { @@ -1297,13 +1311,13 @@ func (db *merkleDB) Clear() error { } // Clear root - db.sentinelNode = newNode(Key{}) - db.rootID = db.sentinelNode.calculateID(db.metrics) + db.root = maybe.Nothing[*node]() + db.rootID = ids.Empty // Clear history db.history = newTrieHistory(db.history.maxHistoryLen) db.history.record(&changeSummary{ - rootID: db.getMerkleRoot(), + rootID: db.rootID, values: map[Key]*change[maybe.Maybe[[]byte]]{}, nodes: map[Key]*change[*node]{}, }) diff --git a/x/merkledb/db_test.go b/x/merkledb/db_test.go index 1cbce5a7792d..e14149ae1471 100644 --- a/x/merkledb/db_test.go +++ b/x/merkledb/db_test.go @@ -31,8 +31,6 @@ import ( const defaultHistoryLength = 300 -var emptyKey Key - // newDB returns a new merkle database with the underlying type so that tests can access unexported fields func newDB(ctx context.Context, db database.Database, config Config) (*merkleDB, error) { db, err := New(ctx, db, config) @@ -153,7 +151,7 @@ func Test_MerkleDB_DB_Load_Root_From_DB(t *testing.T) { require.NoError(db.Close()) - // reloading the db, should set the root back to the one that was saved to [baseDB] + // reloading the db should set the root back to the one that was saved to [baseDB] db, err = New( context.Background(), baseDB, @@ -804,8 +802,8 @@ func TestMerkleDBClear(t *testing.T) { iter := db.NewIterator() defer iter.Release() require.False(iter.Next()) - require.Equal(emptyRootID, db.getMerkleRoot()) - require.Equal(emptyKey, db.sentinelNode.key) + require.Equal(ids.Empty, db.getMerkleRoot()) + require.True(db.root.IsNothing()) // Assert caches are empty. require.Zero(db.valueNodeDB.nodeCache.Len()) @@ -948,6 +946,10 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, token } rangeProof, err := db.GetRangeProofAtRoot(context.Background(), root, start, end, maxProofLen) + if root == ids.Empty { + require.ErrorIs(err, ErrEmptyProof) + continue + } require.NoError(err) require.LessOrEqual(len(rangeProof.KeyValues), maxProofLen) @@ -981,6 +983,10 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, token require.ErrorIs(err, errSameRoot) continue } + if root == ids.Empty { + require.ErrorIs(err, ErrEmptyProof) + continue + } require.NoError(err) require.LessOrEqual(len(changeProof.KeyChanges), maxProofLen) @@ -1242,3 +1248,40 @@ func insertRandomKeyValues( } } } + +func TestGetRangeProofAtRootEmptyRootID(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + _, err = db.getRangeProofAtRoot( + context.Background(), + ids.Empty, + maybe.Nothing[[]byte](), + maybe.Nothing[[]byte](), + 10, + ) + require.ErrorIs(err, ErrEmptyProof) +} + +func TestGetChangeProofEmptyRootID(t *testing.T) { + require := require.New(t) + + db, err := getBasicDB() + require.NoError(err) + + require.NoError(db.Put([]byte("key"), []byte("value"))) + + rootID := db.getMerkleRoot() + + _, err = db.GetChangeProof( + context.Background(), + rootID, + ids.Empty, + maybe.Nothing[[]byte](), + maybe.Nothing[[]byte](), + 10, + ) + require.ErrorIs(err, ErrEmptyProof) +} diff --git a/x/merkledb/history.go b/x/merkledb/history.go index 3717f9ef9b96..1f55b93c5d58 100644 --- a/x/merkledb/history.go +++ b/x/merkledb/history.go @@ -54,15 +54,20 @@ type changeSummaryAndInsertNumber struct { // Tracks all the node and value changes that resulted in the rootID. type changeSummary struct { + // The ID of the trie after these changes. rootID ids.ID - nodes map[Key]*change[*node] - values map[Key]*change[maybe.Maybe[[]byte]] + // The root before/after this change. + // Set in [calculateNodeIDs]. + rootChange change[maybe.Maybe[*node]] + nodes map[Key]*change[*node] + values map[Key]*change[maybe.Maybe[[]byte]] } func newChangeSummary(estimatedSize int) *changeSummary { return &changeSummary{ - nodes: make(map[Key]*change[*node], estimatedSize), - values: make(map[Key]*change[maybe.Maybe[[]byte]], estimatedSize), + nodes: make(map[Key]*change[*node], estimatedSize), + values: make(map[Key]*change[maybe.Maybe[[]byte]], estimatedSize), + rootChange: change[maybe.Maybe[*node]]{}, } } @@ -250,6 +255,13 @@ func (th *trieHistory) getChangesToGetToRoot(rootID ids.ID, start maybe.Maybe[[] for i := mostRecentChangeIndex; i > lastRootChangeIndex; i-- { changes, _ := th.history.Index(i) + if i == mostRecentChangeIndex { + combinedChanges.rootChange.before = changes.rootChange.after + } + if i == lastRootChangeIndex+1 { + combinedChanges.rootChange.after = changes.rootChange.before + } + for key, changedNode := range changes.nodes { combinedChanges.nodes[key] = &change[*node]{ after: changedNode.before, diff --git a/x/merkledb/history_test.go b/x/merkledb/history_test.go index 3c8e8700d567..6af39e0e08e3 100644 --- a/x/merkledb/history_test.go +++ b/x/merkledb/history_test.go @@ -36,8 +36,7 @@ func Test_History_Simple(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -338,8 +337,7 @@ func Test_History_RepeatedRoot(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -381,8 +379,7 @@ func Test_History_ExcessDeletes(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -414,8 +411,7 @@ func Test_History_DontIncludeAllNodes(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -443,7 +439,7 @@ func Test_History_Branching2Nodes(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -471,8 +467,7 @@ func Test_History_Branching3Nodes(t *testing.T) { origProof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(origProof) - - origRootID := db.getMerkleRoot() + origRootID := db.rootID require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() @@ -657,6 +652,9 @@ func TestHistoryGetChangesToRoot(t *testing.T) { for i := 0; i < maxHistoryLen; i++ { // Fill the history changes = append(changes, &changeSummary{ rootID: ids.GenerateTestID(), + rootChange: change[maybe.Maybe[*node]]{ + before: maybe.Some(&node{}), + }, nodes: map[Key]*change[*node]{ ToKey([]byte{byte(i)}): { before: &node{}, @@ -692,7 +690,8 @@ func TestHistoryGetChangesToRoot(t *testing.T) { name: "most recent change", rootID: changes[maxHistoryLen-1].rootID, validateFunc: func(require *require.Assertions, got *changeSummary) { - require.Equal(newChangeSummary(defaultPreallocationSize), got) + expected := newChangeSummary(defaultPreallocationSize) + require.Equal(expected, got) }, }, { diff --git a/x/merkledb/metrics_test.go b/x/merkledb/metrics_test.go index 3bf5a9480a54..304c3027133b 100644 --- a/x/merkledb/metrics_test.go +++ b/x/merkledb/metrics_test.go @@ -34,20 +34,20 @@ func Test_Metrics_Basic_Usage(t *testing.T) { require.Equal(t, int64(1), db.metrics.(*mockMetrics).keyReadCount) require.Equal(t, int64(1), db.metrics.(*mockMetrics).keyWriteCount) - require.Equal(t, int64(2), db.metrics.(*mockMetrics).hashCount) + require.Equal(t, int64(1), db.metrics.(*mockMetrics).hashCount) require.NoError(t, db.Delete([]byte("key"))) require.Equal(t, int64(1), db.metrics.(*mockMetrics).keyReadCount) require.Equal(t, int64(2), db.metrics.(*mockMetrics).keyWriteCount) - require.Equal(t, int64(3), db.metrics.(*mockMetrics).hashCount) + require.Equal(t, int64(1), db.metrics.(*mockMetrics).hashCount) _, err = db.Get([]byte("key2")) require.ErrorIs(t, err, database.ErrNotFound) require.Equal(t, int64(2), db.metrics.(*mockMetrics).keyReadCount) require.Equal(t, int64(2), db.metrics.(*mockMetrics).keyWriteCount) - require.Equal(t, int64(3), db.metrics.(*mockMetrics).hashCount) + require.Equal(t, int64(1), db.metrics.(*mockMetrics).hashCount) } func Test_Metrics_Initialize(t *testing.T) { diff --git a/x/merkledb/mock_db.go b/x/merkledb/mock_db.go index a4d1d6b6d6f3..e07354e6b3e8 100644 --- a/x/merkledb/mock_db.go +++ b/x/merkledb/mock_db.go @@ -1,6 +1,3 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - // Code generated by MockGen. DO NOT EDIT. // Source: github.com/ava-labs/avalanchego/x/merkledb (interfaces: MerkleDB) @@ -8,439 +5,452 @@ package merkledb import ( - context "context" - reflect "reflect" - - database "github.com/ava-labs/avalanchego/database" - ids "github.com/ava-labs/avalanchego/ids" - maybe "github.com/ava-labs/avalanchego/utils/maybe" - gomock "go.uber.org/mock/gomock" + reflect "reflect" + context "context" + database "github.com/ava-labs/avalanchego/database" + ids "github.com/ava-labs/avalanchego/ids" + maybe "github.com/ava-labs/avalanchego/utils/maybe" + gomock "go.uber.org/mock/gomock" ) // MockMerkleDB is a mock of MerkleDB interface. type MockMerkleDB struct { - ctrl *gomock.Controller - recorder *MockMerkleDBMockRecorder + ctrl *gomock.Controller + recorder *MockMerkleDBMockRecorder } // MockMerkleDBMockRecorder is the mock recorder for MockMerkleDB. type MockMerkleDBMockRecorder struct { - mock *MockMerkleDB + mock *MockMerkleDB } // NewMockMerkleDB creates a new mock instance. func NewMockMerkleDB(ctrl *gomock.Controller) *MockMerkleDB { - mock := &MockMerkleDB{ctrl: ctrl} - mock.recorder = &MockMerkleDBMockRecorder{mock} - return mock + mock := &MockMerkleDB{ctrl: ctrl} + mock.recorder = &MockMerkleDBMockRecorder{mock} + return mock } // EXPECT returns an object that allows the caller to indicate expected use. func (m *MockMerkleDB) EXPECT() *MockMerkleDBMockRecorder { - return m.recorder + return m.recorder } // Clear mocks base method. func (m *MockMerkleDB) Clear() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Clear") - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clear") + ret0, _ := ret[0].(error) + return ret0 } // Clear indicates an expected call of Clear. func (mr *MockMerkleDBMockRecorder) Clear() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockMerkleDB)(nil).Clear)) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockMerkleDB)(nil).Clear)) } // Close mocks base method. func (m *MockMerkleDB) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 } // Close indicates an expected call of Close. func (mr *MockMerkleDBMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMerkleDB)(nil).Close)) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMerkleDB)(nil).Close)) } // CommitChangeProof mocks base method. func (m *MockMerkleDB) CommitChangeProof(arg0 context.Context, arg1 *ChangeProof) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CommitChangeProof", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CommitChangeProof", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // CommitChangeProof indicates an expected call of CommitChangeProof. func (mr *MockMerkleDBMockRecorder) CommitChangeProof(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).CommitChangeProof), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).CommitChangeProof), arg0, arg1) } // CommitRangeProof mocks base method. func (m *MockMerkleDB) CommitRangeProof(arg0 context.Context, arg1, arg2 maybe.Maybe[[]uint8], arg3 *RangeProof) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CommitRangeProof", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CommitRangeProof", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 } // CommitRangeProof indicates an expected call of CommitRangeProof. func (mr *MockMerkleDBMockRecorder) CommitRangeProof(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitRangeProof", reflect.TypeOf((*MockMerkleDB)(nil).CommitRangeProof), arg0, arg1, arg2, arg3) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitRangeProof", reflect.TypeOf((*MockMerkleDB)(nil).CommitRangeProof), arg0, arg1, arg2, arg3) } // Compact mocks base method. func (m *MockMerkleDB) Compact(arg0, arg1 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Compact", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Compact", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Compact indicates an expected call of Compact. func (mr *MockMerkleDBMockRecorder) Compact(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Compact", reflect.TypeOf((*MockMerkleDB)(nil).Compact), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Compact", reflect.TypeOf((*MockMerkleDB)(nil).Compact), arg0, arg1) } // Delete mocks base method. func (m *MockMerkleDB) Delete(arg0 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 } // Delete indicates an expected call of Delete. func (mr *MockMerkleDBMockRecorder) Delete(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockMerkleDB)(nil).Delete), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockMerkleDB)(nil).Delete), arg0) } // Get mocks base method. func (m *MockMerkleDB) Get(arg0 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Get indicates an expected call of Get. func (mr *MockMerkleDBMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockMerkleDB)(nil).Get), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockMerkleDB)(nil).Get), arg0) } // GetChangeProof mocks base method. func (m *MockMerkleDB) GetChangeProof(arg0 context.Context, arg1, arg2 ids.ID, arg3, arg4 maybe.Maybe[[]uint8], arg5 int) (*ChangeProof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChangeProof", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].(*ChangeProof) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChangeProof", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].(*ChangeProof) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetChangeProof indicates an expected call of GetChangeProof. func (mr *MockMerkleDBMockRecorder) GetChangeProof(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).GetChangeProof), arg0, arg1, arg2, arg3, arg4, arg5) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).GetChangeProof), arg0, arg1, arg2, arg3, arg4, arg5) } // GetMerkleRoot mocks base method. func (m *MockMerkleDB) GetMerkleRoot(arg0 context.Context) (ids.ID, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMerkleRoot", arg0) - ret0, _ := ret[0].(ids.ID) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMerkleRoot", arg0) + ret0, _ := ret[0].(ids.ID) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetMerkleRoot indicates an expected call of GetMerkleRoot. func (mr *MockMerkleDBMockRecorder) GetMerkleRoot(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMerkleRoot", reflect.TypeOf((*MockMerkleDB)(nil).GetMerkleRoot), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMerkleRoot", reflect.TypeOf((*MockMerkleDB)(nil).GetMerkleRoot), arg0) } // GetProof mocks base method. func (m *MockMerkleDB) GetProof(arg0 context.Context, arg1 []byte) (*Proof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetProof", arg0, arg1) - ret0, _ := ret[0].(*Proof) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProof", arg0, arg1) + ret0, _ := ret[0].(*Proof) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetProof indicates an expected call of GetProof. func (mr *MockMerkleDBMockRecorder) GetProof(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProof", reflect.TypeOf((*MockMerkleDB)(nil).GetProof), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProof", reflect.TypeOf((*MockMerkleDB)(nil).GetProof), arg0, arg1) } // GetRangeProof mocks base method. func (m *MockMerkleDB) GetRangeProof(arg0 context.Context, arg1, arg2 maybe.Maybe[[]uint8], arg3 int) (*RangeProof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRangeProof", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(*RangeProof) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRangeProof", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*RangeProof) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetRangeProof indicates an expected call of GetRangeProof. func (mr *MockMerkleDBMockRecorder) GetRangeProof(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeProof", reflect.TypeOf((*MockMerkleDB)(nil).GetRangeProof), arg0, arg1, arg2, arg3) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeProof", reflect.TypeOf((*MockMerkleDB)(nil).GetRangeProof), arg0, arg1, arg2, arg3) } // GetRangeProofAtRoot mocks base method. func (m *MockMerkleDB) GetRangeProofAtRoot(arg0 context.Context, arg1 ids.ID, arg2, arg3 maybe.Maybe[[]uint8], arg4 int) (*RangeProof, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRangeProofAtRoot", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(*RangeProof) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRangeProofAtRoot", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(*RangeProof) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetRangeProofAtRoot indicates an expected call of GetRangeProofAtRoot. func (mr *MockMerkleDBMockRecorder) GetRangeProofAtRoot(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeProofAtRoot", reflect.TypeOf((*MockMerkleDB)(nil).GetRangeProofAtRoot), arg0, arg1, arg2, arg3, arg4) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeProofAtRoot", reflect.TypeOf((*MockMerkleDB)(nil).GetRangeProofAtRoot), arg0, arg1, arg2, arg3, arg4) } // GetValue mocks base method. func (m *MockMerkleDB) GetValue(arg0 context.Context, arg1 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValue", arg0, arg1) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", arg0, arg1) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetValue indicates an expected call of GetValue. func (mr *MockMerkleDBMockRecorder) GetValue(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockMerkleDB)(nil).GetValue), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockMerkleDB)(nil).GetValue), arg0, arg1) } // GetValues mocks base method. func (m *MockMerkleDB) GetValues(arg0 context.Context, arg1 [][]byte) ([][]byte, []error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValues", arg0, arg1) - ret0, _ := ret[0].([][]byte) - ret1, _ := ret[1].([]error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValues", arg0, arg1) + ret0, _ := ret[0].([][]byte) + ret1, _ := ret[1].([]error) + return ret0, ret1 } // GetValues indicates an expected call of GetValues. func (mr *MockMerkleDBMockRecorder) GetValues(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValues", reflect.TypeOf((*MockMerkleDB)(nil).GetValues), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValues", reflect.TypeOf((*MockMerkleDB)(nil).GetValues), arg0, arg1) } // Has mocks base method. func (m *MockMerkleDB) Has(arg0 []byte) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Has", arg0) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Has", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Has indicates an expected call of Has. func (mr *MockMerkleDBMockRecorder) Has(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockMerkleDB)(nil).Has), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockMerkleDB)(nil).Has), arg0) } // HealthCheck mocks base method. func (m *MockMerkleDB) HealthCheck(arg0 context.Context) (interface{}, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HealthCheck", arg0) - ret0, _ := ret[0].(interface{}) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HealthCheck", arg0) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 } // HealthCheck indicates an expected call of HealthCheck. func (mr *MockMerkleDBMockRecorder) HealthCheck(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HealthCheck", reflect.TypeOf((*MockMerkleDB)(nil).HealthCheck), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HealthCheck", reflect.TypeOf((*MockMerkleDB)(nil).HealthCheck), arg0) } // NewBatch mocks base method. func (m *MockMerkleDB) NewBatch() database.Batch { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewBatch") - ret0, _ := ret[0].(database.Batch) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewBatch") + ret0, _ := ret[0].(database.Batch) + return ret0 } // NewBatch indicates an expected call of NewBatch. func (mr *MockMerkleDBMockRecorder) NewBatch() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewBatch", reflect.TypeOf((*MockMerkleDB)(nil).NewBatch)) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewBatch", reflect.TypeOf((*MockMerkleDB)(nil).NewBatch)) } // NewIterator mocks base method. func (m *MockMerkleDB) NewIterator() database.Iterator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewIterator") - ret0, _ := ret[0].(database.Iterator) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewIterator") + ret0, _ := ret[0].(database.Iterator) + return ret0 } // NewIterator indicates an expected call of NewIterator. func (mr *MockMerkleDBMockRecorder) NewIterator() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIterator", reflect.TypeOf((*MockMerkleDB)(nil).NewIterator)) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIterator", reflect.TypeOf((*MockMerkleDB)(nil).NewIterator)) } // NewIteratorWithPrefix mocks base method. func (m *MockMerkleDB) NewIteratorWithPrefix(arg0 []byte) database.Iterator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewIteratorWithPrefix", arg0) - ret0, _ := ret[0].(database.Iterator) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewIteratorWithPrefix", arg0) + ret0, _ := ret[0].(database.Iterator) + return ret0 } // NewIteratorWithPrefix indicates an expected call of NewIteratorWithPrefix. func (mr *MockMerkleDBMockRecorder) NewIteratorWithPrefix(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithPrefix", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithPrefix), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithPrefix", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithPrefix), arg0) } // NewIteratorWithStart mocks base method. func (m *MockMerkleDB) NewIteratorWithStart(arg0 []byte) database.Iterator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewIteratorWithStart", arg0) - ret0, _ := ret[0].(database.Iterator) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewIteratorWithStart", arg0) + ret0, _ := ret[0].(database.Iterator) + return ret0 } // NewIteratorWithStart indicates an expected call of NewIteratorWithStart. func (mr *MockMerkleDBMockRecorder) NewIteratorWithStart(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithStart", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithStart), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithStart", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithStart), arg0) } // NewIteratorWithStartAndPrefix mocks base method. func (m *MockMerkleDB) NewIteratorWithStartAndPrefix(arg0, arg1 []byte) database.Iterator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewIteratorWithStartAndPrefix", arg0, arg1) - ret0, _ := ret[0].(database.Iterator) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewIteratorWithStartAndPrefix", arg0, arg1) + ret0, _ := ret[0].(database.Iterator) + return ret0 } // NewIteratorWithStartAndPrefix indicates an expected call of NewIteratorWithStartAndPrefix. func (mr *MockMerkleDBMockRecorder) NewIteratorWithStartAndPrefix(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithStartAndPrefix", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithStartAndPrefix), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewIteratorWithStartAndPrefix", reflect.TypeOf((*MockMerkleDB)(nil).NewIteratorWithStartAndPrefix), arg0, arg1) } // NewView mocks base method. func (m *MockMerkleDB) NewView(arg0 context.Context, arg1 ViewChanges) (TrieView, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewView", arg0, arg1) - ret0, _ := ret[0].(TrieView) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewView", arg0, arg1) + ret0, _ := ret[0].(TrieView) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewView indicates an expected call of NewView. func (mr *MockMerkleDBMockRecorder) NewView(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewView", reflect.TypeOf((*MockMerkleDB)(nil).NewView), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewView", reflect.TypeOf((*MockMerkleDB)(nil).NewView), arg0, arg1) } // PrefetchPath mocks base method. func (m *MockMerkleDB) PrefetchPath(arg0 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PrefetchPath", arg0) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrefetchPath", arg0) + ret0, _ := ret[0].(error) + return ret0 } // PrefetchPath indicates an expected call of PrefetchPath. func (mr *MockMerkleDBMockRecorder) PrefetchPath(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrefetchPath", reflect.TypeOf((*MockMerkleDB)(nil).PrefetchPath), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrefetchPath", reflect.TypeOf((*MockMerkleDB)(nil).PrefetchPath), arg0) } // PrefetchPaths mocks base method. func (m *MockMerkleDB) PrefetchPaths(arg0 [][]byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PrefetchPaths", arg0) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrefetchPaths", arg0) + ret0, _ := ret[0].(error) + return ret0 } // PrefetchPaths indicates an expected call of PrefetchPaths. func (mr *MockMerkleDBMockRecorder) PrefetchPaths(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrefetchPaths", reflect.TypeOf((*MockMerkleDB)(nil).PrefetchPaths), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrefetchPaths", reflect.TypeOf((*MockMerkleDB)(nil).PrefetchPaths), arg0) } // Put mocks base method. func (m *MockMerkleDB) Put(arg0, arg1 []byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Put indicates an expected call of Put. func (mr *MockMerkleDBMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockMerkleDB)(nil).Put), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockMerkleDB)(nil).Put), arg0, arg1) } // VerifyChangeProof mocks base method. func (m *MockMerkleDB) VerifyChangeProof(arg0 context.Context, arg1 *ChangeProof, arg2, arg3 maybe.Maybe[[]uint8], arg4 ids.ID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "VerifyChangeProof", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyChangeProof", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 } // VerifyChangeProof indicates an expected call of VerifyChangeProof. func (mr *MockMerkleDBMockRecorder) VerifyChangeProof(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).VerifyChangeProof), arg0, arg1, arg2, arg3, arg4) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyChangeProof", reflect.TypeOf((*MockMerkleDB)(nil).VerifyChangeProof), arg0, arg1, arg2, arg3, arg4) } // getEditableNode mocks base method. func (m *MockMerkleDB) getEditableNode(arg0 Key, arg1 bool) (*node, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getEditableNode", arg0, arg1) - ret0, _ := ret[0].(*node) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getEditableNode", arg0, arg1) + ret0, _ := ret[0].(*node) + ret1, _ := ret[1].(error) + return ret0, ret1 } // getEditableNode indicates an expected call of getEditableNode. func (mr *MockMerkleDBMockRecorder) getEditableNode(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getEditableNode", reflect.TypeOf((*MockMerkleDB)(nil).getEditableNode), arg0, arg1) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getEditableNode", reflect.TypeOf((*MockMerkleDB)(nil).getEditableNode), arg0, arg1) +} + +// getRoot mocks base method. +func (m *MockMerkleDB) getRoot() maybe.Maybe[*node] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getRoot") + ret0, _ := ret[0].(maybe.Maybe[*node]) + return ret0 +} + +// getRoot indicates an expected call of getRoot. +func (mr *MockMerkleDBMockRecorder) getRoot() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getRoot", reflect.TypeOf((*MockMerkleDB)(nil).getRoot)) } // getValue mocks base method. func (m *MockMerkleDB) getValue(arg0 Key) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getValue", arg0) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "getValue", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 } // getValue indicates an expected call of getValue. func (mr *MockMerkleDBMockRecorder) getValue(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getValue", reflect.TypeOf((*MockMerkleDB)(nil).getValue), arg0) + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getValue", reflect.TypeOf((*MockMerkleDB)(nil).getValue), arg0) } diff --git a/x/merkledb/node.go b/x/merkledb/node.go index 9a63ef82c4a7..d15eb6ae7e14 100644 --- a/x/merkledb/node.go +++ b/x/merkledb/node.go @@ -105,10 +105,15 @@ func (n *node) setValueDigest() { // Assumes [child]'s key is valid as a child of [n]. // That is, [n.key] is a prefix of [child.key]. func (n *node) addChild(childNode *node, tokenSize int) { + n.addChildWithID(childNode, tokenSize, ids.Empty) +} + +func (n *node) addChildWithID(childNode *node, tokenSize int, childID ids.ID) { n.setChildEntry( childNode.key.Token(n.key.length, tokenSize), &child{ compressedKey: childNode.key.Skip(n.key.length + tokenSize), + id: childID, hasValue: childNode.hasValue(), }, ) diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index 39ceff3d3157..940c172deecf 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -31,11 +31,13 @@ var ( ErrNonIncreasingValues = errors.New("keys sent are not in increasing order") ErrStateFromOutsideOfRange = errors.New("state key falls outside of the start->end range") ErrNonIncreasingProofNodes = errors.New("each proof node key must be a strict prefix of the next") + ErrExtraProofNodes = errors.New("extra proof nodes in path") + ErrDataInMissingRootProof = errors.New("there should be no state or deleted keys in a change proof that had a missing root") + ErrEmptyProof = errors.New("proof is empty") ErrNoMerkleProof = errors.New("empty key response must include merkle proof") ErrShouldJustBeRoot = errors.New("end proof should only contain root") ErrNoStartProof = errors.New("no start proof") ErrNoEndProof = errors.New("no end proof") - ErrNoProof = errors.New("proof has no nodes") ErrProofNodeNotForKey = errors.New("the provided node has a key that is not a prefix of the specified key") ErrProofValueDoesntMatch = errors.New("the provided value does not match the proof node for the provided key's value") ErrProofNodeHasUnincludedValue = errors.New("the provided proof has a value for a key within the range that is not present in the provided key/values") @@ -121,7 +123,7 @@ func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode) error { type Proof struct { // Nodes in the proof path from root --> target key // (or node that would be where key is if it doesn't exist). - // Must always be non-empty (i.e. have the root node). + // Always contains at least the root. Path []ProofNode // This is a proof that [key] exists/doesn't exist. Key Key @@ -137,8 +139,9 @@ type Proof struct { func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID, tokenSize int) error { // Make sure the proof is well-formed. if len(proof.Path) == 0 { - return ErrNoProof + return ErrEmptyProof } + if err := verifyProofPath(proof.Path, maybe.Some(proof.Key)); err != nil { return err } @@ -249,16 +252,12 @@ type RangeProof struct { // they are also in [EndProof]. StartProof []ProofNode - // If no upper range bound was given, [KeyValues] is empty, - // and [StartProof] is non-empty, this is empty. - // - // If no upper range bound was given, [KeyValues] is empty, - // and [StartProof] is empty, this is the root. + // If no upper range bound was given and [KeyValues] is empty, this is empty. // - // If an upper range bound was given and [KeyValues] is empty, - // this is a proof for the upper range bound. + // If no upper range bound was given and [KeyValues] is non-empty, this is + // a proof for the largest key in [KeyValues]. // - // Otherwise, this is a proof for the largest key in [KeyValues]. + // Otherwise this is a proof for the upper range bound. EndProof []ProofNode // This proof proves that the key-value pairs in [KeyValues] are in the trie. @@ -287,11 +286,9 @@ func (proof *RangeProof) Verify( case start.HasValue() && end.HasValue() && bytes.Compare(start.Value(), end.Value()) > 0: return ErrStartAfterEnd case len(proof.KeyValues) == 0 && len(proof.StartProof) == 0 && len(proof.EndProof) == 0: - return ErrNoMerkleProof - case end.IsNothing() && len(proof.KeyValues) == 0 && len(proof.StartProof) > 0 && len(proof.EndProof) != 0: + return ErrEmptyProof + case end.IsNothing() && len(proof.KeyValues) == 0 && len(proof.EndProof) != 0: return ErrUnexpectedEndProof - case end.IsNothing() && len(proof.KeyValues) == 0 && len(proof.StartProof) == 0 && len(proof.EndProof) != 1: - return ErrShouldJustBeRoot case len(proof.EndProof) == 0 && (end.HasValue() || len(proof.KeyValues) > 0): return ErrNoEndProof } diff --git a/x/merkledb/proof_test.go b/x/merkledb/proof_test.go index fbee117d4e68..e00326a56408 100644 --- a/x/merkledb/proof_test.go +++ b/x/merkledb/proof_test.go @@ -24,7 +24,7 @@ import ( func Test_Proof_Empty(t *testing.T) { proof := &Proof{} err := proof.Verify(context.Background(), ids.Empty, 4) - require.ErrorIs(t, err, ErrNoProof) + require.ErrorIs(t, err, ErrEmptyProof) } func Test_Proof_Simple(t *testing.T) { @@ -59,6 +59,13 @@ func Test_Proof_Verify_Bad_Data(t *testing.T) { malform: func(proof *Proof) {}, expectedErr: nil, }, + { + name: "empty", + malform: func(proof *Proof) { + proof.Path = nil + }, + expectedErr: ErrEmptyProof, + }, { name: "odd length key path with value", malform: func(proof *Proof) { @@ -150,7 +157,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) { context.Background(), maybe.Some([]byte{1}), maybe.Some([]byte{5, 5}), - db.getMerkleRoot(), + db.rootID, db.tokenSize, )) @@ -160,7 +167,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) { context.Background(), maybe.Some([]byte{1}), maybe.Some([]byte{5, 5}), - db.getMerkleRoot(), + db.rootID, db.tokenSize, ) require.ErrorIs(err, ErrInvalidProof) @@ -179,6 +186,15 @@ func Test_RangeProof_Verify_Bad_Data(t *testing.T) { malform: func(proof *RangeProof) {}, expectedErr: nil, }, + { + name: "empty", + malform: func(proof *RangeProof) { + proof.KeyValues = nil + proof.StartProof = nil + proof.EndProof = nil + }, + expectedErr: ErrEmptyProof, + }, { name: "StartProof: last proof node has missing value", malform: func(proof *RangeProof) { @@ -276,6 +292,8 @@ func Test_Proof(t *testing.T) { require.Equal(ToKey([]byte("key")), proof.Path[0].Key) require.Equal(maybe.Some([]byte("value")), proof.Path[0].ValueOrHash) + require.Equal(ToKey([]byte("key0")).Take(28), proof.Path[1].Key) + require.True(proof.Path[1].ValueOrHash.IsNothing()) // intermediate node require.Equal(ToKey([]byte("key1")), proof.Path[2].Key) require.Equal(maybe.Some([]byte("value1")), proof.Path[2].ValueOrHash) @@ -283,10 +301,9 @@ func Test_Proof(t *testing.T) { require.NoError(err) require.NoError(proof.Verify(context.Background(), expectedRootID, dbTrie.tokenSize)) - proof.Path[0].ValueOrHash = maybe.Some([]byte("value2")) - + proof.Path[0].Key = ToKey([]byte("key1")) err = proof.Verify(context.Background(), expectedRootID, dbTrie.tokenSize) - require.ErrorIs(err, ErrInvalidProof) + require.ErrorIs(err, ErrProofNodeNotForKey) } func Test_RangeProof_Syntactic_Verify(t *testing.T) { @@ -307,11 +324,11 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { expectedErr: ErrStartAfterEnd, }, { - name: "empty", // Also tests start can be > end if end is nil + name: "empty", start: maybe.Some([]byte{1}), end: maybe.Nothing[[]byte](), proof: &RangeProof{}, - expectedErr: ErrNoMerkleProof, + expectedErr: ErrEmptyProof, }, { name: "unexpected end proof", @@ -323,15 +340,6 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, expectedErr: ErrUnexpectedEndProof, }, - { - name: "should just be root", - start: maybe.Nothing[[]byte](), - end: maybe.Nothing[[]byte](), - proof: &RangeProof{ - EndProof: []ProofNode{{}, {}}, - }, - expectedErr: ErrShouldJustBeRoot, - }, { name: "no end proof; has end bound", start: maybe.Some([]byte{1}), @@ -501,7 +509,9 @@ func Test_RangeProof(t *testing.T) { require.Equal([]byte{2}, proof.KeyValues[1].Value) require.Equal([]byte{3}, proof.KeyValues[2].Value) + require.Len(proof.EndProof, 2) require.Equal([]byte{0}, proof.EndProof[0].Key.Bytes()) + require.Len(proof.EndProof[0].Children, 5) // 0,1,2,3,4 require.Equal([]byte{3}, proof.EndProof[1].Key.Bytes()) // only a single node here since others are duplicates in endproof @@ -511,7 +521,7 @@ func Test_RangeProof(t *testing.T) { context.Background(), maybe.Some([]byte{1}), maybe.Some([]byte{3, 5}), - db.getMerkleRoot(), + db.rootID, db.tokenSize, )) } @@ -522,6 +532,8 @@ func Test_RangeProof_BadBounds(t *testing.T) { db, err := getBasicDB() require.NoError(err) + require.NoError(db.Put(nil, nil)) + // non-nil start/end proof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte{4}), maybe.Some([]byte{3}), 50) require.ErrorIs(err, ErrStartAfterEnd) @@ -556,14 +568,14 @@ func Test_RangeProof_NilStart(t *testing.T) { require.Equal([]byte("value1"), proof.KeyValues[0].Value) require.Equal([]byte("value2"), proof.KeyValues[1].Value) - require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key) + require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key, db.tokenSize) require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[0].Key) require.NoError(proof.Verify( context.Background(), maybe.Nothing[[]byte](), maybe.Some([]byte("key35")), - db.getMerkleRoot(), + db.rootID, db.tokenSize, )) } @@ -573,10 +585,16 @@ func Test_RangeProof_NilEnd(t *testing.T) { db, err := getBasicDB() require.NoError(err) + writeBasicBatch(t, db) require.NoError(err) - proof, err := db.GetRangeProof(context.Background(), maybe.Some([]byte{1}), maybe.Nothing[[]byte](), 2) + proof, err := db.GetRangeProof( // Should have keys [1], [2] + context.Background(), + maybe.Some([]byte{1}), + maybe.Nothing[[]byte](), + 2, + ) require.NoError(err) require.NotNil(proof) @@ -590,14 +608,14 @@ func Test_RangeProof_NilEnd(t *testing.T) { require.Equal([]byte{1}, proof.StartProof[0].Key.Bytes()) - require.Equal([]byte{0}, proof.EndProof[0].Key.Bytes()) + require.Equal(db.root.Value().key, proof.EndProof[0].Key) require.Equal([]byte{2}, proof.EndProof[1].Key.Bytes()) require.NoError(proof.Verify( context.Background(), maybe.Some([]byte{1}), maybe.Nothing[[]byte](), - db.getMerkleRoot(), + db.rootID, db.tokenSize, )) } @@ -633,29 +651,68 @@ func Test_RangeProof_EmptyValues(t *testing.T) { require.Equal(ToKey([]byte("key1")), proof.StartProof[0].Key) require.Len(proof.EndProof, 2) - require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key) - require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[0].Key) + require.Equal(ToKey([]byte("key1")).Take(28), proof.EndProof[0].Key, db.tokenSize) // root + require.Equal(ToKey([]byte("key2")), proof.EndProof[1].Key, db.tokenSize) require.NoError(proof.Verify( context.Background(), maybe.Some([]byte("key1")), maybe.Some([]byte("key2")), - db.getMerkleRoot(), + db.rootID, db.tokenSize, )) } func Test_ChangeProof_Missing_History_For_EndRoot(t *testing.T) { require := require.New(t) + seed := time.Now().UnixNano() + t.Logf("Seed: %d", seed) + rand := rand.New(rand.NewSource(seed)) // #nosec G404 db, err := getBasicDB() require.NoError(err) - startRoot, err := db.GetMerkleRoot(context.Background()) - require.NoError(err) - _, err = db.GetChangeProof(context.Background(), startRoot, ids.Empty, maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 50) + roots := []ids.ID{} + for i := 0; i < defaultHistoryLength+1; i++ { + key := make([]byte, 16) + _, _ = rand.Read(key) + require.NoError(db.Put(key, nil)) + root, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + roots = append(roots, root) + } + + _, err = db.GetChangeProof( + context.Background(), + roots[len(roots)-1], + ids.GenerateTestID(), + maybe.Nothing[[]byte](), + maybe.Nothing[[]byte](), + 50, + ) require.ErrorIs(err, ErrNoEndRoot) require.ErrorIs(err, ErrInsufficientHistory) + + _, err = db.GetChangeProof( + context.Background(), + roots[0], + roots[len(roots)-1], + maybe.Nothing[[]byte](), + maybe.Nothing[[]byte](), + 50, + ) + require.NotErrorIs(err, ErrNoEndRoot) + require.ErrorIs(err, ErrInsufficientHistory) + + _, err = db.GetChangeProof( + context.Background(), + roots[1], + roots[len(roots)-1], + maybe.Nothing[[]byte](), + maybe.Nothing[[]byte](), + 50, + ) + require.NoError(err) } func Test_ChangeProof_BadBounds(t *testing.T) { @@ -816,13 +873,26 @@ func Test_ChangeProof_Verify_Bad_Data(t *testing.T) { dbClone, err := getBasicDB() require.NoError(err) - proof, err := db.GetChangeProof(context.Background(), startRoot, endRoot, maybe.Some([]byte{2}), maybe.Some([]byte{3, 0}), 50) + proof, err := db.GetChangeProof( + context.Background(), + startRoot, + endRoot, + maybe.Some([]byte{2}), + maybe.Some([]byte{3, 0}), + 50, + ) require.NoError(err) require.NotNil(proof) tt.malform(proof) - err = dbClone.VerifyChangeProof(context.Background(), proof, maybe.Some([]byte{2}), maybe.Some([]byte{3, 0}), db.getMerkleRoot()) + err = dbClone.VerifyChangeProof( + context.Background(), + proof, + maybe.Some([]byte{2}), + maybe.Some([]byte{3, 0}), + db.getMerkleRoot(), + ) require.ErrorIs(err, tt.expectedErr) }) } @@ -850,7 +920,7 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { proof: &ChangeProof{}, start: maybe.Nothing[[]byte](), end: maybe.Nothing[[]byte](), - expectedErr: ErrNoMerkleProof, + expectedErr: ErrEmptyProof, }, { name: "no end proof", @@ -1627,6 +1697,9 @@ func FuzzRangeProofInvariants(f *testing.F) { if maxProofLen == 0 { t.SkipNow() } + if numKeyValues == 0 { + t.SkipNow() + } // Make sure proof bounds are valid if len(endBytes) != 0 && bytes.Compare(startBytes, endBytes) > 0 { @@ -1657,15 +1730,19 @@ func FuzzRangeProofInvariants(f *testing.F) { end = maybe.Some(endBytes) } + rootID, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + rangeProof, err := db.GetRangeProof( context.Background(), start, end, int(maxProofLen), ) - require.NoError(err) - - rootID, err := db.GetMerkleRoot(context.Background()) + if rootID == ids.Empty { + require.ErrorIs(err, ErrEmptyProof) + return + } require.NoError(err) require.NoError(rangeProof.Verify( @@ -1761,10 +1838,15 @@ func FuzzProofVerification(f *testing.F) { deletePortion, ) + if db.getMerkleRoot() == ids.Empty { + return + } + proof, err := db.GetProof( context.Background(), key, ) + require.NoError(err) rootID, err := db.GetMerkleRoot(context.Background()) diff --git a/x/merkledb/trie.go b/x/merkledb/trie.go index d4b01d2de29a..1e6870df27cb 100644 --- a/x/merkledb/trie.go +++ b/x/merkledb/trie.go @@ -12,13 +12,15 @@ import ( ) type MerkleRootGetter interface { - // GetMerkleRoot returns the merkle root of the Trie + // GetMerkleRoot returns the merkle root of the trie. + // Returns ids.Empty if the trie is empty. GetMerkleRoot(ctx context.Context) (ids.ID, error) } type ProofGetter interface { // GetProof generates a proof of the value associated with a particular key, // or a proof of its absence from the trie + // Returns ErrEmptyProof if the trie is empty. GetProof(ctx context.Context, keyBytes []byte) (*Proof, error) } @@ -38,6 +40,11 @@ type ReadOnlyTrie interface { // database.ErrNotFound if the key is not present getValue(key Key) ([]byte, error) + // If this trie is non-empty, returns the root node. + // Must be copied before modification. + // Otherwise returns Nothing. + getRoot() maybe.Maybe[*node] + // get an editable copy of the node with the given key path // hasValue indicates which db to look in (value or intermediate) getEditableNode(key Key, hasValue bool) (*node, error) @@ -46,6 +53,7 @@ type ReadOnlyTrie interface { // keys in range [start, end]. // If [start] is Nothing, there's no lower bound on the range. // If [end] is Nothing, there's no upper bound on the range. + // Returns ErrEmptyProof if the trie is empty. GetRangeProof(ctx context.Context, start maybe.Maybe[[]byte], end maybe.Maybe[[]byte], maxLength int) (*RangeProof, error) database.Iteratee diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index a431dd6b254d..fec8a435d60a 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -124,9 +124,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { return nil })) - // Just the root - require.Len(nodePath, 1) - require.Equal(trie.sentinelNode, nodePath[0]) + require.Empty(nodePath) // Insert a key key1 := []byte{0} @@ -143,17 +141,15 @@ func TestTrieViewVisitPathToKey(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) - nodePath = make([]*node, 0, 2) + nodePath = make([]*node, 0, 1) require.NoError(trie.visitPathToKey(ToKey(key1), func(n *node) error { nodePath = append(nodePath, n) return nil })) - // Root and 1 value - require.Len(nodePath, 2) - - require.Equal(trie.sentinelNode, nodePath[0]) - require.Equal(ToKey(key1), nodePath[1].key) + // 1 value + require.Len(nodePath, 1) + require.Equal(ToKey(key1), nodePath[0].key) // Insert another key which is a child of the first key2 := []byte{0, 1} @@ -170,17 +166,20 @@ func TestTrieViewVisitPathToKey(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) - nodePath = make([]*node, 0, 3) + nodePath = make([]*node, 0, 2) require.NoError(trie.visitPathToKey(ToKey(key2), func(n *node) error { nodePath = append(nodePath, n) return nil })) - require.Len(nodePath, 3) - - require.Equal(trie.sentinelNode, nodePath[0]) - require.Equal(ToKey(key1), nodePath[1].key) - require.Equal(ToKey(key2), nodePath[2].key) - + require.Len(nodePath, 2) + require.Equal(trie.root.Value(), nodePath[0]) + require.Equal(ToKey(key1), nodePath[0].key) + require.Equal(ToKey(key2), nodePath[1].key) + + // Trie is: + // [0] + // | + // [0,1] // Insert a key which shares no prefix with the others key3 := []byte{255} trieIntf, err = trie.NewView( @@ -196,6 +195,12 @@ func TestTrieViewVisitPathToKey(t *testing.T) { trie = trieIntf.(*trieView) require.NoError(trie.calculateNodeIDs(context.Background())) + // Trie is: + // [] + // / \ + // [0] [255] + // | + // [0,1] nodePath = make([]*node, 0, 2) require.NoError(trie.visitPathToKey(ToKey(key3), func(n *node) error { nodePath = append(nodePath, n) @@ -203,8 +208,8 @@ func TestTrieViewVisitPathToKey(t *testing.T) { })) require.Len(nodePath, 2) - - require.Equal(trie.sentinelNode, nodePath[0]) + require.Equal(trie.root.Value(), nodePath[0]) + require.Zero(trie.root.Value().key.length) require.Equal(ToKey(key3), nodePath[1].key) // Other key path not affected @@ -214,8 +219,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { return nil })) require.Len(nodePath, 3) - - require.Equal(trie.sentinelNode, nodePath[0]) + require.Equal(trie.root.Value(), nodePath[0]) require.Equal(ToKey(key1), nodePath[1].key) require.Equal(ToKey(key2), nodePath[2].key) @@ -228,7 +232,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { })) require.Len(nodePath, 3) - require.Equal(trie.sentinelNode, nodePath[0]) + require.Equal(trie.root.Value(), nodePath[0]) require.Equal(ToKey(key1), nodePath[1].key) require.Equal(ToKey(key2), nodePath[2].key) @@ -240,7 +244,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { return nil })) require.Len(nodePath, 1) - require.Equal(trie.sentinelNode, nodePath[0]) + require.Equal(trie.root.Value(), nodePath[0]) } func Test_Trie_ViewOnCommitedView(t *testing.T) { @@ -490,7 +494,7 @@ func Test_Trie_ExpandOnKeyPath(t *testing.T) { require.Equal([]byte("value12"), value) } -func Test_Trie_compressedKeys(t *testing.T) { +func Test_Trie_CompressedKeys(t *testing.T) { require := require.New(t) dbTrie, err := getBasicDB() @@ -589,9 +593,9 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { require.NoError(err) require.NotNil(dbTrie) - key1, key2, keyPrefix := []byte("key12"), []byte("key1F"), []byte("key1") + key1, key2, keyPrefix := []byte("12"), []byte("1F"), []byte("1") - trieIntf, err := dbTrie.NewView( + view1, err := dbTrie.NewView( context.Background(), ViewChanges{ BatchOps: []database.BatchOp{ @@ -599,11 +603,13 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { }, }) require.NoError(err) - trie := trieIntf.(*trieView) + + // trie is: + // [1] // create new node with common prefix whose children // are key1, key2 - view2, err := trie.NewView( + view2, err := view1.NewView( context.Background(), ViewChanges{ BatchOps: []database.BatchOp{ @@ -612,20 +618,27 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { }) require.NoError(err) + // trie is: + // [1] + // / \ + // [12] [1F] + // clear the hash count to ignore setup dbTrie.metrics.(*mockMetrics).hashCount = 0 - // force the new root to calculate + // calculate the root _, err = view2.GetMerkleRoot(context.Background()) require.NoError(err) - // Make sure the branch node with the common prefix was created. + // Make sure the root is an intermediate node with the expected common prefix. // Note it's only created on call to GetMerkleRoot, not in NewView. - _, err = view2.getEditableNode(ToKey(keyPrefix), false) + prefixNode, err := view2.getEditableNode(ToKey(keyPrefix), false) require.NoError(err) + root := view2.getRoot().Value() + require.Equal(root, prefixNode) + require.Len(root.children, 2) - // only hashes the new branch node, the new child node, and root - // shouldn't hash the existing node + // Had to hash each of the new nodes ("12" and "1F") and the new root require.Equal(int64(3), dbTrie.metrics.(*mockMetrics).hashCount) } @@ -667,7 +680,16 @@ func Test_Trie_HashCountOnDelete(t *testing.T) { require.NoError(err) require.NoError(view.CommitToDB(context.Background())) - // the root is the only updated node so only one new hash + // trie is: + // [key0] (first 28 bits) + // / \ + // [key1] [key2] + root := view.getRoot().Value() + expectedRootKey := ToKey([]byte("key0")).Take(28) + require.Equal(expectedRootKey, root.key) + require.Len(root.children, 2) + + // Had to hash the new root but not [key1] or [key2] nodes require.Equal(oldCount+1, dbTrie.metrics.(*mockMetrics).hashCount) } @@ -762,9 +784,11 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err := newTrie.getEditableNode(Key{}, false) + maybeRoot := newTrie.getRoot() require.NoError(err) - require.Len(root.children, 1) + require.True(maybeRoot.HasValue()) + require.Equal([]byte("value0"), maybeRoot.Value().value.Value()) + require.Len(maybeRoot.Value().children, 1) newTrie, err = newTrie.NewView( context.Background(), @@ -779,10 +803,10 @@ func Test_Trie_ChainDeletion(t *testing.T) { ) require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err = newTrie.getEditableNode(Key{}, false) - require.NoError(err) - // since all values have been deleted, the nodes should have been cleaned up - require.Empty(root.children) + + // trie should be empty + root := newTrie.getRoot() + require.False(root.HasValue()) } func Test_Trie_Invalidate_Siblings_On_Commit(t *testing.T) { @@ -829,54 +853,63 @@ func Test_Trie_NodeCollapse(t *testing.T) { require.NoError(err) require.NotNil(dbTrie) + kvs := []database.BatchOp{ + {Key: []byte("k"), Value: []byte("value0")}, + {Key: []byte("ke"), Value: []byte("value1")}, + {Key: []byte("key"), Value: []byte("value2")}, + {Key: []byte("key1"), Value: []byte("value3")}, + {Key: []byte("key2"), Value: []byte("value4")}, + } + trie, err := dbTrie.NewView( context.Background(), ViewChanges{ - BatchOps: []database.BatchOp{ - {Key: []byte("k"), Value: []byte("value0")}, - {Key: []byte("ke"), Value: []byte("value1")}, - {Key: []byte("key"), Value: []byte("value2")}, - {Key: []byte("key1"), Value: []byte("value3")}, - {Key: []byte("key2"), Value: []byte("value4")}, - }, + BatchOps: kvs, }, ) require.NoError(err) require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err := trie.getEditableNode(Key{}, false) - require.NoError(err) - require.Len(root.children, 1) - root, err = trie.getEditableNode(Key{}, false) - require.NoError(err) - require.Len(root.children, 1) + for _, kv := range kvs { + node, err := trie.getEditableNode(ToKey(kv.Key), true) + require.NoError(err) - firstNode, err := trie.getEditableNode(getSingleChildKey(root, dbTrie.tokenSize), true) - require.NoError(err) - require.Len(firstNode.children, 1) + require.Equal(kv.Value, node.value.Value()) + } + + // delete some values + deletedKVs, remainingKVs := kvs[:3], kvs[3:] + deleteOps := make([]database.BatchOp, len(deletedKVs)) + for i, kv := range deletedKVs { + deleteOps[i] = database.BatchOp{ + Key: kv.Key, + Delete: true, + } + } - // delete the middle values trie, err = trie.NewView( context.Background(), ViewChanges{ - BatchOps: []database.BatchOp{ - {Key: []byte("k"), Delete: true}, - {Key: []byte("ke"), Delete: true}, - {Key: []byte("key"), Delete: true}, - }, + BatchOps: deleteOps, }, ) require.NoError(err) + require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err = trie.getEditableNode(Key{}, false) - require.NoError(err) - require.Len(root.children, 1) + for _, kv := range deletedKVs { + _, err := trie.getEditableNode(ToKey(kv.Key), true) + require.ErrorIs(err, database.ErrNotFound) + } - firstNode, err = trie.getEditableNode(getSingleChildKey(root, dbTrie.tokenSize), true) - require.NoError(err) - require.Len(firstNode.children, 2) + // make sure the other values are still there + for _, kv := range remainingKVs { + node, err := trie.getEditableNode(ToKey(kv.Key), true) + require.NoError(err) + + require.Equal(kv.Value, node.value.Value()) + } } func Test_Trie_MultipleStates(t *testing.T) { diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 730d4b0187d0..35be62f8a5f9 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -37,7 +37,7 @@ var ( ) ErrVisitPathToKey = errors.New("failed to visit expected node during insertion") ErrStartAfterEnd = errors.New("start key > end key") - ErrNoValidRoot = errors.New("a valid root was not provided to the trieView constructor") + ErrNoChanges = errors.New("no changes provided") ErrParentNotDatabase = errors.New("parent trie is not database") ErrNodesAlreadyCalculated = errors.New("cannot modify the trie after the node changes have been calculated") ) @@ -96,9 +96,8 @@ type trieView struct { db *merkleDB - // The nil key node - // It is either the root of the trie or the root of the trie is its single child node - sentinelNode *node + // The root of the trie represented by this view. + root maybe.Maybe[*node] tokenSize int } @@ -142,26 +141,17 @@ func (t *trieView) NewView( } // Creates a new view with the given [parentTrie]. -// Assumes [parentTrie] isn't locked. func newTrieView( db *merkleDB, parentTrie TrieView, changes ViewChanges, ) (*trieView, error) { - sentinelNode, err := parentTrie.getEditableNode(Key{}, false /* hasValue */) - if err != nil { - if errors.Is(err, database.ErrNotFound) { - return nil, ErrNoValidRoot - } - return nil, err - } - newView := &trieView{ - sentinelNode: sentinelNode, - db: db, - parentTrie: parentTrie, - changes: newChangeSummary(len(changes.BatchOps) + len(changes.MapOps)), - tokenSize: db.tokenSize, + root: maybe.Bind(parentTrie.getRoot(), (*node).clone), + db: db, + parentTrie: parentTrie, + changes: newChangeSummary(len(changes.BatchOps) + len(changes.MapOps)), + tokenSize: db.tokenSize, } for _, op := range changes.BatchOps { @@ -192,26 +182,22 @@ func newTrieView( return newView, nil } -// Creates a view of the db at a historical root using the provided changes +// Creates a view of the db at a historical root using the provided [changes]. +// Returns ErrNoChanges if [changes] is empty. func newHistoricalTrieView( db *merkleDB, changes *changeSummary, ) (*trieView, error) { if changes == nil { - return nil, ErrNoValidRoot - } - - passedSentinelChange, ok := changes.nodes[Key{}] - if !ok { - return nil, ErrNoValidRoot + return nil, ErrNoChanges } newView := &trieView{ - sentinelNode: passedSentinelChange.after, - db: db, - parentTrie: db, - changes: changes, - tokenSize: db.tokenSize, + root: changes.rootChange.after, + db: db, + parentTrie: db, + changes: changes, + tokenSize: db.tokenSize, } // since this is a set of historical changes, all nodes have already been calculated // since no new changes have occurred, no new calculations need to be done @@ -220,6 +206,10 @@ func newHistoricalTrieView( return newView, nil } +func (t *trieView) getRoot() maybe.Maybe[*node] { + return t.root +} + // Recalculates the node IDs for all changed nodes in the trie. // Cancelling [ctx] doesn't cancel calculation. It's used only for tracing. func (t *trieView) calculateNodeIDs(ctx context.Context) error { @@ -231,6 +221,8 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { } defer t.nodesAlreadyCalculated.Set(true) + oldRoot := maybe.Bind(t.root, (*node).clone) + // We wait to create the span until after checking that we need to actually // calculateNodeIDs to make traces more useful (otherwise there may be a span // per key modified even though IDs are not re-calculated). @@ -250,15 +242,17 @@ func (t *trieView) calculateNodeIDs(ctx context.Context) error { } } - _ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1) - t.changes.rootID = t.calculateNodeIDsHelper(t.sentinelNode) - t.db.calculateNodeIDsSema.Release(1) + if !t.root.IsNothing() { + _ = t.db.calculateNodeIDsSema.Acquire(context.Background(), 1) + t.changes.rootID = t.calculateNodeIDsHelper(t.root.Value()) + t.db.calculateNodeIDsSema.Release(1) + } else { + t.changes.rootID = ids.Empty + } - // If the sentinel node is not the root, the trie's root is the sentinel node's only child - if !isSentinelNodeTheRoot(t.sentinelNode) { - for _, childEntry := range t.sentinelNode.children { - t.changes.rootID = childEntry.id - } + t.changes.rootChange = change[maybe.Maybe[*node]]{ + before: oldRoot, + after: t.root, } // ensure no ancestor changes occurred during execution @@ -320,11 +314,15 @@ func (t *trieView) GetProof(ctx context.Context, key []byte) (*Proof, error) { return t.getProof(ctx, key) } -// Returns a proof that [bytesPath] is in or not in trie [t]. +// Returns a proof that [key] is in or not in [t]. func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { _, span := t.db.infoTracer.Start(ctx, "MerkleDB.trieview.getProof") defer span.End() + if t.root.IsNothing() { + return nil, ErrEmptyProof + } + proof := &Proof{ Key: ToKey(key), } @@ -332,26 +330,18 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { var closestNode *node if err := t.visitPathToKey(proof.Key, func(n *node) error { closestNode = n + // From root --> node from left --> right. proof.Path = append(proof.Path, n.asProofNode()) return nil }); err != nil { return nil, err } - root, err := t.getRoot() - if err != nil { - return nil, err - } - // The sentinel node is always the first node in the path. - // If the sentinel node is not the root, remove it from the proofPath. - if root != t.sentinelNode { - proof.Path = proof.Path[1:] - - // if there are no nodes in the proof path, add the root to serve as an exclusion proof - if len(proof.Path) == 0 { - proof.Path = []ProofNode{root.asProofNode()} - return proof, nil - } + if len(proof.Path) == 0 { + // No key in [t] is a prefix of [key]. + // The root alone proves that [key] isn't in [t]. + proof.Path = append(proof.Path, t.root.Value().asProofNode()) + return proof, nil } if closestNode.key == proof.Key { @@ -407,6 +397,10 @@ func (t *trieView) GetRangeProof( return nil, err } + if t.root.IsNothing() { + return nil, ErrEmptyProof + } + var result RangeProof result.KeyValues = make([]KeyValue, 0, initKeyValuesSize) @@ -462,19 +456,6 @@ func (t *trieView) GetRangeProof( result.StartProof = result.StartProof[i:] } - if len(result.StartProof) == 0 && len(result.EndProof) == 0 && len(result.KeyValues) == 0 { - // If the range is empty, return the root proof. - root, err := t.getRoot() - if err != nil { - return nil, err - } - rootProof, err := t.getProof(ctx, root.key.Bytes()) - if err != nil { - return nil, err - } - result.EndProof = rootProof.Path - } - if t.isInvalid() { return nil, ErrInvalid } @@ -630,14 +611,14 @@ func (t *trieView) remove(key Key) error { keyNode, err := t.getNode(key, true) if err != nil { if errors.Is(err, database.ErrNotFound) { - // key didn't exist + // [key] isn't in the trie. return nil } return err } - // node doesn't contain a value if !keyNode.hasValue() { + // [key] doesn't have a value. return nil } @@ -655,43 +636,50 @@ func (t *trieView) remove(key Key) error { } nodeToDelete.setValue(maybe.Nothing[[]byte]()) - if len(nodeToDelete.children) != 0 { - // merge this node and its child into a single node if possible - return t.compressNodePath(parent, nodeToDelete) - } // if the removed node has no children, the node can be removed from the trie - if err := t.recordNodeDeleted(nodeToDelete); err != nil { - return err - } - if parent != nil { + if len(nodeToDelete.children) == 0 { + if err := t.recordNodeDeleted(nodeToDelete); err != nil { + return err + } + + if nodeToDelete.key == t.root.Value().key { + // We deleted the root. The trie is empty now. + t.root = maybe.Nothing[*node]() + return nil + } + + // Note [parent] != nil since [nodeToDelete.key] != [t.root.key]. + // i.e. There's the root and at least one more node. parent.removeChild(nodeToDelete, t.tokenSize) // merge the parent node and its child into a single node if possible return t.compressNodePath(grandParent, parent) } - return nil + + // merge this node and its descendants into a single node if possible + return t.compressNodePath(parent, nodeToDelete) } -// Merges together nodes in the inclusive descendants of [node] that +// Merges together nodes in the inclusive descendants of [n] that // have no value and a single child into one node with a compressed // path until a node that doesn't meet those criteria is reached. -// [parent] is [node]'s parent. +// [parent] is [n]'s parent. If [parent] is nil, [n] is the root +// node and [t.root] is updated to [n]. // Assumes at least one of the following is true: -// * [node] has a value. -// * [node] has children. +// * [n] has a value. +// * [n] has children. // Must not be called after [calculateNodeIDs] has returned. -func (t *trieView) compressNodePath(parent, node *node) error { +func (t *trieView) compressNodePath(parent, n *node) error { if t.nodesAlreadyCalculated.Get() { return ErrNodesAlreadyCalculated } - // don't collapse into this node if it's the root, doesn't have 1 child, or has a value - if parent == nil || len(node.children) != 1 || node.hasValue() { + if len(n.children) != 1 || n.hasValue() { return nil } - if err := t.recordNodeDeleted(node); err != nil { + if err := t.recordNodeDeleted(n); err != nil { return err } @@ -702,13 +690,20 @@ func (t *trieView) compressNodePath(parent, node *node) error { // There is only one child, but we don't know the index. // "Cycle" over the key/values to find the only child. // Note this iteration once because len(node.children) == 1. - for index, entry := range node.children { - childKey = node.key.Extend(ToToken(index, t.tokenSize), entry.compressedKey) + for index, entry := range n.children { + childKey = n.key.Extend(ToToken(index, t.tokenSize), entry.compressedKey) childEntry = entry } - // [node] is the first node with multiple children. - // combine it with the [node] passed in. + if parent == nil { + root, err := t.getNode(childKey, childEntry.hasValue) + if err != nil { + return err + } + t.root = maybe.Some(root) + return nil + } + parent.setChildEntry(childKey.Token(parent.key.length, t.tokenSize), &child{ compressedKey: childKey.Skip(parent.key.length + t.tokenSize), @@ -718,15 +713,21 @@ func (t *trieView) compressNodePath(parent, node *node) error { return t.recordNodeChange(parent) } -// Returns the nodes along the path to [key]. -// The first node is the root, and the last node is either the node with the -// given [key], if it's in the trie, or the node with the largest prefix of -// the [key] if it isn't in the trie. -// Always returns at least the root node. +// Calls [visitNode] on each node along the path to [key]. +// The first node (if any) is the root, and the last node is either the +// node with the given [key], if it's in [t], or the node with the +// largest prefix of [key] otherwise. func (t *trieView) visitPathToKey(key Key, visitNode func(*node) error) error { + if t.root.IsNothing() { + return nil + } + root := t.root.Value() + if !key.HasPrefix(root.key) { + return nil + } var ( - // all node paths start at the sentinelNode since its nil key is a prefix of all keys - currentNode = t.sentinelNode + // all node paths start at the root + currentNode = root err error ) if err := visitNode(currentNode); err != nil { @@ -785,14 +786,49 @@ func (t *trieView) insert( return nil, ErrNodesAlreadyCalculated } + if t.root.IsNothing() { + // the trie is empty, so create a new root node. + root := newNode(key) + root.setValue(value) + t.root = maybe.Some(root) + return root, t.recordNewNode(root) + } + + // Find the node that most closely matches [key]. var closestNode *node if err := t.visitPathToKey(key, func(n *node) error { closestNode = n + // Need to recalculate ID for all nodes on path to [key]. return t.recordNodeChange(n) }); err != nil { return nil, err } + if closestNode == nil { + // [t.root.key] isn't a prefix of [key]. + var ( + oldRoot = t.root.Value() + commonPrefixLength = getLengthOfCommonPrefix(oldRoot.key, key, 0 /*offset*/, t.tokenSize) + commonPrefix = oldRoot.key.Take(commonPrefixLength) + newRoot = newNode(commonPrefix) + oldRootID = oldRoot.calculateID(t.db.metrics) + ) + + // Call addChildWithID instead of addChild so the old root is added + // to the new root with the correct ID. + // TODO: + // [oldRootID] shouldn't need to be calculated here. + // Either oldRootID should already be calculated or will be calculated at the end with the other nodes + // Initialize the t.changes.rootID during newTrieView and then use that here instead of oldRootID + newRoot.addChildWithID(oldRoot, t.tokenSize, oldRootID) + if err := t.recordNewNode(newRoot); err != nil { + return nil, err + } + t.root = maybe.Some(newRoot) + + closestNode = newRoot + } + // a node with that exact key already exists so update its value if closestNode.key == key { closestNode.setValue(value) @@ -890,26 +926,9 @@ func (t *trieView) recordNodeChange(after *node) error { // Records that the node associated with the given key has been deleted. // Must not be called after [calculateNodeIDs] has returned. func (t *trieView) recordNodeDeleted(after *node) error { - // don't delete the root. - if after.key.length == 0 { - return t.recordKeyChange(after.key, after, after.hasValue(), false /* newNode */) - } return t.recordKeyChange(after.key, nil, after.hasValue(), false /* newNode */) } -func (t *trieView) getRoot() (*node, error) { - if !isSentinelNodeTheRoot(t.sentinelNode) { - // sentinelNode has one child, which is the root - for index, childEntry := range t.sentinelNode.children { - return t.getNode( - t.sentinelNode.key.Extend(ToToken(index, t.tokenSize), childEntry.compressedKey), - childEntry.hasValue) - } - } - - return t.sentinelNode, nil -} - // Records that the node associated with the given key has been changed. // If it is an existing node, record what its value was before it was changed. // Must not be called after [calculateNodeIDs] has returned. diff --git a/x/sync/client_test.go b/x/sync/client_test.go index f6c67debe5ee..0c71ccb52e75 100644 --- a/x/sync/client_test.go +++ b/x/sync/client_test.go @@ -339,11 +339,11 @@ func TestGetRangeProof(t *testing.T) { BytesLimit: defaultRequestByteSizeLimit, }, modifyResponse: func(response *merkledb.RangeProof) { - response.KeyValues = nil response.StartProof = nil response.EndProof = nil + response.KeyValues = nil }, - expectedErr: merkledb.ErrNoMerkleProof, + expectedErr: merkledb.ErrEmptyProof, }, } @@ -524,7 +524,7 @@ func TestGetChangeProof(t *testing.T) { newDefaultDBConfig(), ) require.NoError(t, err) - startRoot, err := serverDB.GetMerkleRoot(context.Background()) // TODO uncomment + startRoot, err := serverDB.GetMerkleRoot(context.Background()) require.NoError(t, err) // create changes @@ -566,6 +566,8 @@ func TestGetChangeProof(t *testing.T) { endRoot, err := serverDB.GetMerkleRoot(context.Background()) require.NoError(t, err) + fakeRootID := ids.GenerateTestID() + tests := map[string]struct { db DB request *pb.SyncGetChangeProofRequest @@ -662,24 +664,11 @@ func TestGetChangeProof(t *testing.T) { }, expectedErr: merkledb.ErrInvalidProof, }, - "range proof response happy path": { - request: &pb.SyncGetChangeProofRequest{ - // Server doesn't have the (non-existent) start root - // so should respond with range proof. - StartRootHash: ids.Empty[:], - EndRootHash: endRoot[:], - KeyLimit: defaultRequestKeyLimit, - BytesLimit: defaultRequestByteSizeLimit, - }, - modifyChangeProofResponse: nil, - expectedErr: nil, - expectRangeProof: true, - }, "range proof response; remove first key": { request: &pb.SyncGetChangeProofRequest{ // Server doesn't have the (non-existent) start root // so should respond with range proof. - StartRootHash: ids.Empty[:], + StartRootHash: fakeRootID[:], EndRootHash: endRoot[:], KeyLimit: defaultRequestKeyLimit, BytesLimit: defaultRequestByteSizeLimit, diff --git a/x/sync/g_db/db_client.go b/x/sync/g_db/db_client.go index af1ce1c9080b..376bff6aeab9 100644 --- a/x/sync/g_db/db_client.go +++ b/x/sync/g_db/db_client.go @@ -45,6 +45,10 @@ func (c *DBClient) GetChangeProof( endKey maybe.Maybe[[]byte], keyLimit int, ) (*merkledb.ChangeProof, error) { + if endRootID == ids.Empty { + return nil, merkledb.ErrEmptyProof + } + resp, err := c.client.GetChangeProof(ctx, &pb.GetChangeProofRequest{ StartRootHash: startRootID[:], EndRootHash: endRootID[:], @@ -136,6 +140,10 @@ func (c *DBClient) GetRangeProofAtRoot( endKey maybe.Maybe[[]byte], keyLimit int, ) (*merkledb.RangeProof, error) { + if rootID == ids.Empty { + return nil, merkledb.ErrEmptyProof + } + resp, err := c.client.GetRangeProof(ctx, &pb.GetRangeProofRequest{ RootHash: rootID[:], StartKey: &pb.MaybeBytes{ diff --git a/x/sync/manager.go b/x/sync/manager.go index a7a6858d5122..d094f33165cb 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -266,6 +266,18 @@ func (m *Manager) getAndApplyChangeProof(ctx context.Context, work *workItem) { return } + if targetRootID == ids.Empty { + // The trie is empty after this change. + // Delete all the key-value pairs in the range. + if err := m.config.DB.Clear(); err != nil { + m.setError(err) + return + } + work.start = maybe.Nothing[[]byte]() + m.completeWorkItem(ctx, work, maybe.Nothing[[]byte](), targetRootID, nil) + return + } + changeOrRangeProof, err := m.config.Client.GetChangeProof( ctx, &pb.SyncGetChangeProofRequest{ @@ -332,6 +344,17 @@ func (m *Manager) getAndApplyChangeProof(ctx context.Context, work *workItem) { // Assumes [m.workLock] is not held. func (m *Manager) getAndApplyRangeProof(ctx context.Context, work *workItem) { targetRootID := m.getTargetRoot() + + if targetRootID == ids.Empty { + if err := m.config.DB.Clear(); err != nil { + m.setError(err) + return + } + work.start = maybe.Nothing[[]byte]() + m.completeWorkItem(ctx, work, maybe.Nothing[[]byte](), targetRootID, nil) + return + } + proof, err := m.config.Client.GetRangeProof(ctx, &pb.SyncGetRangeProofRequest{ RootHash: targetRootID[:], diff --git a/x/sync/network_server.go b/x/sync/network_server.go index 8027a05042f1..e31e6d8f3ace 100644 --- a/x/sync/network_server.go +++ b/x/sync/network_server.go @@ -397,6 +397,8 @@ func validateChangeProofRequest(req *pb.SyncGetChangeProofRequest) error { return errInvalidStartRootHash case len(req.EndRootHash) != hashing.HashLen: return errInvalidEndRootHash + case bytes.Equal(req.EndRootHash, ids.Empty[:]): + return merkledb.ErrEmptyProof case req.StartKey != nil && req.StartKey.IsNothing && len(req.StartKey.Value) > 0: return errInvalidStartKey case req.EndKey != nil && req.EndKey.IsNothing && len(req.EndKey.Value) > 0: @@ -418,6 +420,8 @@ func validateRangeProofRequest(req *pb.SyncGetRangeProofRequest) error { return errInvalidKeyLimit case len(req.RootHash) != ids.IDLen: return errInvalidRootHash + case bytes.Equal(req.RootHash, ids.Empty[:]): + return merkledb.ErrEmptyProof case req.StartKey != nil && req.StartKey.IsNothing && len(req.StartKey.Value) > 0: return errInvalidStartKey case req.EndKey != nil && req.EndKey.IsNothing && len(req.EndKey.Value) > 0: diff --git a/x/sync/network_server_test.go b/x/sync/network_server_test.go index 7b83afc3e7d5..a73aa3736979 100644 --- a/x/sync/network_server_test.go +++ b/x/sync/network_server_test.go @@ -93,6 +93,14 @@ func Test_Server_GetRangeProof(t *testing.T) { }, expectedMaxResponseBytes: defaultRequestByteSizeLimit, }, + "empty proof": { + request: &pb.SyncGetRangeProofRequest{ + RootHash: ids.Empty[:], + KeyLimit: defaultRequestKeyLimit, + BytesLimit: defaultRequestByteSizeLimit, + }, + proofNil: true, + }, } for name, test := range tests { @@ -252,7 +260,7 @@ func Test_Server_GetChangeProof(t *testing.T) { request: &pb.SyncGetChangeProofRequest{ // This root doesn't exist so server has insufficient history // to serve a change proof - StartRootHash: ids.Empty[:], + StartRootHash: fakeRootID[:], EndRootHash: endRoot[:], KeyLimit: defaultRequestKeyLimit, BytesLimit: defaultRequestByteSizeLimit, @@ -272,6 +280,16 @@ func Test_Server_GetChangeProof(t *testing.T) { expectedMaxResponseBytes: defaultRequestByteSizeLimit, proofNil: true, }, + "empt proof": { + request: &pb.SyncGetChangeProofRequest{ + StartRootHash: fakeRootID[:], + EndRootHash: ids.Empty[:], + KeyLimit: defaultRequestKeyLimit, + BytesLimit: defaultRequestByteSizeLimit, + }, + expectedMaxResponseBytes: defaultRequestByteSizeLimit, + proofNil: true, + }, } for name, test := range tests { diff --git a/x/sync/sync_test.go b/x/sync/sync_test.go index 71871e95db56..f970229aa2f5 100644 --- a/x/sync/sync_test.go +++ b/x/sync/sync_test.go @@ -102,14 +102,17 @@ func Test_Completion(t *testing.T) { newDefaultDBConfig(), ) require.NoError(err) + emptyRoot, err := emptyDB.GetMerkleRoot(context.Background()) require.NoError(err) + db, err := merkledb.New( context.Background(), memdb.New(), newDefaultDBConfig(), ) require.NoError(err) + syncer, err := NewManager(ManagerConfig{ DB: db, Client: newCallthroughSyncClient(ctrl, emptyDB), @@ -120,8 +123,10 @@ func Test_Completion(t *testing.T) { }) require.NoError(err) require.NotNil(syncer) + require.NoError(syncer.Start(context.Background())) require.NoError(syncer.Wait(context.Background())) + syncer.workLock.Lock() require.Zero(syncer.unprocessedWork.Len()) require.Equal(1, syncer.processedWork.Len()) @@ -332,25 +337,26 @@ func Test_Sync_FindNextKey_BranchInLocal(t *testing.T) { require.NoError(db.Put([]byte{0x11}, []byte{1})) require.NoError(db.Put([]byte{0x11, 0x11}, []byte{2})) - syncRoot, err := db.GetMerkleRoot(context.Background()) + targetRoot, err := db.GetMerkleRoot(context.Background()) require.NoError(err) + proof, err := db.GetProof(context.Background(), []byte{0x11, 0x11}) require.NoError(err) syncer, err := NewManager(ManagerConfig{ DB: db, Client: NewMockClient(ctrl), - TargetRoot: syncRoot, + TargetRoot: targetRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, BranchFactor: merkledb.BranchFactor16, }) require.NoError(err) - require.NoError(db.Put([]byte{0x12}, []byte{4})) + require.NoError(db.Put([]byte{0x11, 0x15}, []byte{4})) nextKey, err := syncer.findNextKey(context.Background(), []byte{0x11, 0x11}, maybe.Some([]byte{0x20}), proof.Path) require.NoError(err) - require.Equal(maybe.Some([]byte{0x12}), nextKey) + require.Equal(maybe.Some([]byte{0x11, 0x15}), nextKey) } func Test_Sync_FindNextKey_BranchInReceived(t *testing.T) { @@ -365,27 +371,28 @@ func Test_Sync_FindNextKey_BranchInReceived(t *testing.T) { require.NoError(err) require.NoError(db.Put([]byte{0x11}, []byte{1})) require.NoError(db.Put([]byte{0x12}, []byte{2})) - require.NoError(db.Put([]byte{0x11, 0x11}, []byte{3})) + require.NoError(db.Put([]byte{0x12, 0xA0}, []byte{4})) - syncRoot, err := db.GetMerkleRoot(context.Background()) + targetRoot, err := db.GetMerkleRoot(context.Background()) require.NoError(err) - proof, err := db.GetProof(context.Background(), []byte{0x11, 0x11}) + + proof, err := db.GetProof(context.Background(), []byte{0x12}) require.NoError(err) syncer, err := NewManager(ManagerConfig{ DB: db, Client: NewMockClient(ctrl), - TargetRoot: syncRoot, + TargetRoot: targetRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, BranchFactor: merkledb.BranchFactor16, }) require.NoError(err) - require.NoError(db.Delete([]byte{0x12})) + require.NoError(db.Delete([]byte{0x12, 0xA0})) - nextKey, err := syncer.findNextKey(context.Background(), []byte{0x11, 0x11}, maybe.Some([]byte{0x20}), proof.Path) + nextKey, err := syncer.findNextKey(context.Background(), []byte{0x12}, maybe.Some([]byte{0x20}), proof.Path) require.NoError(err) - require.Equal(maybe.Some([]byte{0x12}), nextKey) + require.Equal(maybe.Some([]byte{0x12, 0xA0}), nextKey) } func Test_Sync_FindNextKey_ExtraValues(t *testing.T) {