From d213136be9bbb9998279d8338c92eafc905e94bb Mon Sep 17 00:00:00 2001 From: Luchuan Date: Tue, 21 Jul 2020 14:12:52 +0800 Subject: [PATCH] Fix db.Seek (#291) * add db.seek ut * fix db.seek * remove ByteArrayCompare Co-authored-by: Luchuan Co-authored-by: Shargon --- src/LevelDBStore/IO/Data/LevelDB/Helper.cs | 30 ++++++++++--- src/RocksDBStore/Plugins/Storage/Snapshot.cs | 14 +++--- src/RocksDBStore/Plugins/Storage/Store.cs | 17 +++----- tests/Neo.Plugins.Storage.Tests/StoreTest.cs | 45 +++++++++++++++++++- 4 files changed, 80 insertions(+), 26 deletions(-) diff --git a/src/LevelDBStore/IO/Data/LevelDB/Helper.cs b/src/LevelDBStore/IO/Data/LevelDB/Helper.cs index 0ed2f9543..bd28be11c 100644 --- a/src/LevelDBStore/IO/Data/LevelDB/Helper.cs +++ b/src/LevelDBStore/IO/Data/LevelDB/Helper.cs @@ -10,16 +10,32 @@ public static class Helper public static IEnumerable Seek(this DB db, ReadOptions options, byte table, byte[] prefix, SeekDirection direction, Func resultSelector) { using Iterator it = db.NewIterator(options); - for (it.Seek(CreateKey(table, prefix)); it.Valid();) + byte[] target = CreateKey(table, prefix); + if (direction == SeekDirection.Forward) { - var key = it.Key(); - if (key.Length < 1 || key[0] != table) break; - yield return resultSelector(it.Key(), it.Value()); + for (it.Seek(target); it.Valid(); it.Next()) + { + var key = it.Key(); + if (key.Length < 1 || key[0] != table) break; + yield return resultSelector(it.Key(), it.Value()); + } + } + else + { + // SeekForPrev - if (direction == SeekDirection.Forward) - it.Next(); - else + it.Seek(target); + if (!it.Valid()) + it.SeekToLast(); + else if (it.Key().AsSpan().SequenceCompareTo(target) > 0) it.Prev(); + + for (; it.Valid(); it.Prev()) + { + var key = it.Key(); + if (key.Length < 1 || key[0] != table) break; + yield return resultSelector(it.Key(), it.Value()); + } } } diff --git a/src/RocksDBStore/Plugins/Storage/Snapshot.cs b/src/RocksDBStore/Plugins/Storage/Snapshot.cs index fba384598..834dbc0a3 100644 --- a/src/RocksDBStore/Plugins/Storage/Snapshot.cs +++ b/src/RocksDBStore/Plugins/Storage/Snapshot.cs @@ -43,15 +43,13 @@ public void Put(byte table, byte[] key, byte[] value) public IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] keyOrPrefix, SeekDirection direction) { using var it = db.NewIterator(store.GetFamily(table), options); - for (it.Seek(keyOrPrefix); it.Valid();) - { - yield return (it.Key(), it.Value()); - if (direction == SeekDirection.Forward) - it.Next(); - else - it.Prev(); - } + if (direction == SeekDirection.Forward) + for (it.Seek(keyOrPrefix); it.Valid(); it.Next()) + yield return (it.Key(), it.Value()); + else + for (it.SeekForPrev(keyOrPrefix); it.Valid(); it.Prev()) + yield return (it.Key(), it.Value()); } public byte[] TryGet(byte table, byte[] key) diff --git a/src/RocksDBStore/Plugins/Storage/Store.cs b/src/RocksDBStore/Plugins/Storage/Store.cs index 3cad88249..5fb255c11 100644 --- a/src/RocksDBStore/Plugins/Storage/Store.cs +++ b/src/RocksDBStore/Plugins/Storage/Store.cs @@ -82,18 +82,15 @@ public ISnapshot GetSnapshot() return new Snapshot(this, db); } - public IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] prefix, SeekDirection direction = SeekDirection.Forward) + public IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] keyOrPrefix, SeekDirection direction = SeekDirection.Forward) { using var it = db.NewIterator(GetFamily(table), Options.ReadDefault); - for (it.Seek(prefix); it.Valid();) - { - yield return (it.Key(), it.Value()); - - if (direction == SeekDirection.Forward) - it.Next(); - else - it.Prev(); - } + if (direction == SeekDirection.Forward) + for (it.Seek(keyOrPrefix); it.Valid(); it.Next()) + yield return (it.Key(), it.Value()); + else + for (it.SeekForPrev(keyOrPrefix); it.Valid(); it.Prev()) + yield return (it.Key(), it.Value()); } public byte[] TryGet(byte table, byte[] key) diff --git a/tests/Neo.Plugins.Storage.Tests/StoreTest.cs b/tests/Neo.Plugins.Storage.Tests/StoreTest.cs index ee83c8a42..d19290dc6 100644 --- a/tests/Neo.Plugins.Storage.Tests/StoreTest.cs +++ b/tests/Neo.Plugins.Storage.Tests/StoreTest.cs @@ -1,5 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Neo.Persistence; +using System; namespace Neo.Plugins.Storage.Tests { @@ -43,7 +44,7 @@ public void TestRocksDb() } /// - /// Test Put/Delete/TryGet + /// Test Put/Delete/TryGet/Seek /// /// Store private void TestStorage(IStore store) @@ -64,6 +65,48 @@ private void TestStorage(IStore store) ret = store.TryGet(0, new byte[] { 0x01, 0x02 }); Assert.IsNull(ret); + + // Test seek + + store.Put(1, new byte[] { 0x00, 0x00, 0x00 }, new byte[] { 0x00 }); + store.Put(1, new byte[] { 0x00, 0x00, 0x01 }, new byte[] { 0x01 }); + store.Put(1, new byte[] { 0x00, 0x00, 0x02 }, new byte[] { 0x02 }); + store.Put(1, new byte[] { 0x00, 0x00, 0x03 }, new byte[] { 0x03 }); + store.Put(1, new byte[] { 0x00, 0x00, 0x04 }, new byte[] { 0x04 }); + + // Seek Forward + + var enumerator = store.Seek(1, new byte[] { 0x00, 0x00, 0x02 }, IO.Caching.SeekDirection.Forward).GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x02 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x02 }, enumerator.Current.Value); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x03 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x03 }, enumerator.Current.Value); + + // Seek Backward + + enumerator = store.Seek(1, new byte[] { 0x00, 0x00, 0x02 }, IO.Caching.SeekDirection.Backward).GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x02 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x02 }, enumerator.Current.Value); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x01 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x01 }, enumerator.Current.Value); + + // Seek Backward + + store.Put(2, new byte[] { 0x00, 0x00, 0x00 }, new byte[] { 0x00 }); + store.Put(2, new byte[] { 0x00, 0x00, 0x01 }, new byte[] { 0x01 }); + store.Put(2, new byte[] { 0x00, 0x01, 0x02 }, new byte[] { 0x02 }); + + enumerator = store.Seek(2, new byte[] { 0x00, 0x00, 0x03 }, IO.Caching.SeekDirection.Backward).GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x01 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x01 }, enumerator.Current.Value); + Assert.IsTrue(enumerator.MoveNext()); + CollectionAssert.AreEqual(new byte[] { 0x00, 0x00, 0x00 }, enumerator.Current.Key); + CollectionAssert.AreEqual(new byte[] { 0x00 }, enumerator.Current.Value); } }