diff --git a/go/storage/mkvs/urkel/db/badger/badger.go b/go/storage/mkvs/urkel/db/badger/badger.go index f8182f4eed1..8f1004e76d3 100644 --- a/go/storage/mkvs/urkel/db/badger/badger.go +++ b/go/storage/mkvs/urkel/db/badger/badger.go @@ -34,7 +34,7 @@ var ( // Value is serialized node. nodeKeyFmt = keyformat.New('N', &common.Namespace{}, &hash.Hash{}) // writeLogKeyFmt is the key format for write logs (namespace, round, - // old root, new root). + // new root, old root). // // Value is CBOR-serialized write log. writeLogKeyFmt = keyformat.New('L', &common.Namespace{}, uint64(0), &hash.Hash{}, &hash.Hash{}) @@ -152,13 +152,13 @@ func (d *badgerNodeDB) GetWriteLog(ctx context.Context, startRoot node.Root, end tx := d.db.NewTransaction(false) defer tx.Discard() - key := writeLogKeyFmt.Encode(&endRoot.Namespace, endRoot.Round, &startRoot.Hash, &endRoot.Hash) + key := writeLogKeyFmt.Encode(&endRoot.Namespace, endRoot.Round, &endRoot.Hash, &startRoot.Hash) item, err := tx.Get(key) if err != nil { d.logger.Error("failed to Get write log from backing store", "err", err, - "start_root", startRoot, - "end_root", endRoot, + "old_root", startRoot, + "new_root", endRoot, ) return nil, errors.Wrap(err, "urkel/db/badger: failed to Get write log from backing store") } @@ -759,10 +759,48 @@ func (ba *badgerBatch) Commit(root node.Root) error { // Store write log. 
if ba.writeLog != nil && ba.annotations != nil { log := api.MakeHashedDBWriteLog(ba.writeLog, ba.annotations) - bytes := cbor.Marshal(log) - key := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &ba.oldRoot.Hash, &root.Hash) - if err := ba.bat.Set(key, bytes); err != nil { - return errors.Wrap(err, "urkel/db/badger: set returned error") + + prefix := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &ba.oldRoot.Hash) + it := tx.NewIterator(badger.IteratorOptions{Prefix: prefix}) + defer it.Close() + + foundOld := false + for it.Rewind(); it.Valid(); it.Next() { + var decNs common.Namespace + var decRound uint64 + var oldRootHash hash.Hash + var olderRootHash hash.Hash + + if !writeLogKeyFmt.Decode(it.Item().Key(), &decNs, &decRound, &oldRootHash, &olderRootHash) { + // This should not happen as the Badger iterator should take care of it. + panic("urkel/db/badger: bad iterator") + } + + // If an older write log exists, get it, merge it with this one and delete it from the db. + var oldWriteLog api.HashedDBWriteLog + err := it.Item().Value(func(data []byte) error { + return cbor.Unmarshal(data, &oldWriteLog) + }) + if err != nil { + return err + } + oldWriteLog = append(oldWriteLog, log...) 
+ bytes := cbor.Marshal(oldWriteLog) + if err := ba.bat.Set(writeLogKeyFmt.Encode(&root.Namespace, root.Round, &root.Hash, &olderRootHash), bytes); err != nil { + return errors.Wrap(err, "urkel/db/badger: set merged write log returned error") + } + if err := ba.bat.Delete(it.Item().KeyCopy(nil)); err != nil { + return errors.Wrap(err, "urkel/db/badger: delete partial write log returned error") + } + foundOld = true + } + + if !foundOld { + bytes := cbor.Marshal(log) + key := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &root.Hash, &ba.oldRoot.Hash) + if err := ba.bat.Set(key, bytes); err != nil { + return errors.Wrap(err, "urkel/db/badger: set new write log returned error") + } } } diff --git a/go/storage/mkvs/urkel/db/leveldb/leveldb.go b/go/storage/mkvs/urkel/db/leveldb/leveldb.go index b2abc4aa9fc..046c7a19e4a 100644 --- a/go/storage/mkvs/urkel/db/leveldb/leveldb.go +++ b/go/storage/mkvs/urkel/db/leveldb/leveldb.go @@ -28,7 +28,7 @@ var ( // Value is serialized node. nodeKeyFmt = keyformat.New('N', &common.Namespace{}, &hash.Hash{}) // writeLogKeyFmt is the key format for write logs (namespace, round, - // old root, new root). + // new root, old root). // // Value is CBOR-serialized write log. writeLogKeyFmt = keyformat.New('L', &common.Namespace{}, uint64(0), &hash.Hash{}, &hash.Hash{}) @@ -125,7 +125,7 @@ func (d *leveldbNodeDB) GetWriteLog(ctx context.Context, startRoot node.Root, en return nil, api.ErrRootMustFollowOld } - key := writeLogKeyFmt.Encode(&endRoot.Namespace, endRoot.Round, &startRoot.Hash, &endRoot.Hash) + key := writeLogKeyFmt.Encode(&endRoot.Namespace, endRoot.Round, &endRoot.Hash, &startRoot.Hash) bytes, err := d.db.Get(key, nil) if err != nil { return nil, err @@ -604,9 +604,39 @@ func (b *leveldbBatch) Commit(root node.Root) error { // Store write log. 
if b.writeLog != nil && b.annotations != nil { log := api.MakeHashedDBWriteLog(b.writeLog, b.annotations) - bytes := cbor.Marshal(log) - key := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &b.oldRoot.Hash, &root.Hash) - b.bat.Put(key, bytes) + + prefix := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &b.oldRoot.Hash) + it := snapshot.NewIterator(util.BytesPrefix(prefix), nil) + defer it.Release() + + foundOld := false + for it.Next() { + var decNs common.Namespace + var decRound uint64 + var oldRootHash hash.Hash + var olderRootHash hash.Hash + + if !writeLogKeyFmt.Decode(it.Key(), &decNs, &decRound, &oldRootHash, &olderRootHash) { + // This should not happen as the LevelDB iterator should take care of it. + panic("urkel/db/leveldb: bad iterator") + } + + // If an older write log exists, get it, merge it with this one and delete it from the db. + var oldWriteLog api.HashedDBWriteLog + if err := cbor.Unmarshal(it.Value(), &oldWriteLog); err != nil { + return err + } + oldWriteLog = append(oldWriteLog, log...) 
+ b.bat.Put(writeLogKeyFmt.Encode(&root.Namespace, root.Round, &root.Hash, &olderRootHash), cbor.Marshal(oldWriteLog)) + b.bat.Delete(it.Key()) + foundOld = true + } + + if !foundOld { + bytes := cbor.Marshal(log) + key := writeLogKeyFmt.Encode(&root.Namespace, root.Round, &root.Hash, &b.oldRoot.Hash) + b.bat.Put(key, bytes) + } } if err := b.db.db.Write(b.bat, &opt.WriteOptions{Sync: true}); err != nil { diff --git a/go/storage/mkvs/urkel/urkel_test.go b/go/storage/mkvs/urkel/urkel_test.go index 2149fa7c40e..3e2bed9cbe8 100644 --- a/go/storage/mkvs/urkel/urkel_test.go +++ b/go/storage/mkvs/urkel/urkel_test.go @@ -51,6 +51,23 @@ func writeLogToMap(wl writelog.WriteLog) map[string]string { return writeLogSet } +func foldWriteLogIterator(t *testing.T, w writelog.Iterator) writelog.WriteLog { + writeLog := writelog.WriteLog{} + + for { + more, err := w.Next() + require.NoError(t, err, "error iterating over WriteLogIterator") + if !more { + break + } + + val, err := w.Value() + require.NoError(t, err, "error iterating over WriteLogIterator") + writeLog = append(writeLog, val) + } + return writeLog +} + func (s *dummySerialSyncer) GetSubtree(ctx context.Context, root node.Root, id node.ID, maxDepth node.Depth) (*syncer.Subtree, error) { obj, err := s.backing.GetSubtree(ctx, root, id, maxDepth) if err != nil { @@ -1137,6 +1154,71 @@ func testHasRoot(t *testing.T, ndb db.NodeDB) { require.True(t, ndb.HasRoot(root), "HasRoot should return true for existing root") } +func testMergeWriteLog(t *testing.T, ndb db.NodeDB) { + ctx := context.Background() + + keyZero := []byte("foo") + valueZero := []byte("bar") + keyOne := []byte("baz") + valueOne := []byte("quux") + + emptyRoot := node.Root{ + Namespace: testNs, + Round: 0, + } + emptyRoot.Hash.Empty() + + // Put some stuff in the tree. 
+ tree := New(nil, ndb) + err := tree.Insert(ctx, keyZero, valueZero) + require.NoError(t, err, "Insert") + _, rootHash1, err := tree.Commit(ctx, testNs, 0) + require.NoError(t, err, "Commit") + + root1 := node.Root{ + Namespace: testNs, + Round: 0, + Hash: rootHash1, + } + + wli, err := ndb.GetWriteLog(ctx, emptyRoot, root1) + require.NoError(t, err, "GetWriteLog") + + wl := writeLogToMap(foldWriteLogIterator(t, wli)) + require.Equal(t, writeLogToMap(writelog.WriteLog{writelog.LogEntry{Key: keyZero, Value: valueZero}}), wl) + + // Continue adding to this same tree. + tree, err = NewWithRoot(ctx, nil, ndb, root1) + require.NoError(t, err, "NewWithRoot") + err = tree.Insert(ctx, keyOne, valueOne) + require.NoError(t, err, "Insert") + _, rootHash2, err := tree.Commit(ctx, testNs, 0) + require.NoError(t, err, "Commit") + + root2 := node.Root{ + Namespace: testNs, + Round: 0, + Hash: rootHash2, + } + + // Check that we can get a combined write log from the empty root to the second committed root (root2). + wli, err = ndb.GetWriteLog(ctx, emptyRoot, root2) + require.NoError(t, err, "GetWriteLog") + + wlDb := writeLogToMap(foldWriteLogIterator(t, wli)) + wlLiteral := writeLogToMap(writelog.WriteLog{ + writelog.LogEntry{Key: keyZero, Value: valueZero}, + writelog.LogEntry{Key: keyOne, Value: valueOne}, + }) + require.Equal(t, wlLiteral, wlDb) + + // Check that separate write logs to or from the intermediate root (root1) can no longer be retrieved. 
+ _, err = ndb.GetWriteLog(ctx, emptyRoot, root1) + require.Error(t, err, "GetWriteLog") + _, err = ndb.GetWriteLog(ctx, root1, root2) + require.Error(t, err, "GetWriteLog") +} + func testPruneBasic(t *testing.T, ndb db.NodeDB) { ctx := context.Background() tree := New(nil, ndb) @@ -1730,6 +1812,7 @@ func testBackend( {"DebugDump", testDebugDump}, {"DebugStats", testDebugStats}, {"OnCommitHooks", testOnCommitHooks}, + {"MergeWriteLog", testMergeWriteLog}, {"HasRoot", testHasRoot}, {"PruneBasic", testPruneBasic}, {"PruneManyRounds", testPruneManyRounds}, @@ -1862,6 +1945,7 @@ func TestUrkelLRUBackend(t *testing.T) { "PruneCheckpoints", "Errors", "HasRoot", + "MergeWriteLog", }, ) }