diff --git a/src/neo/Cryptography/MPT/BranchNode.cs b/src/neo/Cryptography/MPT/BranchNode.cs new file mode 100644 index 0000000000..e69f2130ab --- /dev/null +++ b/src/neo/Cryptography/MPT/BranchNode.cs @@ -0,0 +1,35 @@ +using System.IO; + +namespace Neo.Cryptography.MPT +{ + public class BranchNode : MPTNode + { + public const int ChildCount = 17; + public readonly MPTNode[] Children = new MPTNode[ChildCount]; + + protected override NodeType Type => NodeType.BranchNode; + + public BranchNode() + { + for (int i = 0; i < ChildCount; i++) + { + Children[i] = HashNode.EmptyNode; + } + } + + internal override void EncodeSpecific(BinaryWriter writer) + { + for (int i = 0; i < ChildCount; i++) + WriteHash(writer, Children[i].Hash); + } + + internal override void DecodeSpecific(BinaryReader reader) + { + for (int i = 0; i < ChildCount; i++) + { + Children[i] = new HashNode(); + Children[i].DecodeSpecific(reader); + } + } + } +} diff --git a/src/neo/Cryptography/MPT/ExtensionNode.cs b/src/neo/Cryptography/MPT/ExtensionNode.cs new file mode 100644 index 0000000000..9c575915cf --- /dev/null +++ b/src/neo/Cryptography/MPT/ExtensionNode.cs @@ -0,0 +1,30 @@ +using Neo.IO; +using Neo.SmartContract; +using System.IO; + +namespace Neo.Cryptography.MPT +{ + public class ExtensionNode : MPTNode + { + //max lenght when store StorageKey + public const int MaxKeyLength = (ApplicationEngine.MaxStorageValueSize + sizeof(int)) * 2; + + public byte[] Key; + public MPTNode Next; + + protected override NodeType Type => NodeType.ExtensionNode; + + internal override void EncodeSpecific(BinaryWriter writer) + { + writer.WriteVarBytes(Key); + WriteHash(writer, Next.Hash); + } + + internal override void DecodeSpecific(BinaryReader reader) + { + Key = reader.ReadVarBytes(MaxKeyLength); + Next = new HashNode(); + Next.DecodeSpecific(reader); + } + } +} diff --git a/src/neo/Cryptography/MPT/HashNode.cs b/src/neo/Cryptography/MPT/HashNode.cs new file mode 100644 index 0000000000..b304833d66 --- /dev/null +++ b/src/neo/Cryptography/MPT/HashNode.cs @@ -0,0 +1,41 @@ +using Neo.IO; +using System; +using System.IO; + +namespace Neo.Cryptography.MPT +{ + public class HashNode : MPTNode + { + private UInt256 hash; + + public override UInt256 Hash => hash; + protected override NodeType Type => NodeType.HashNode; + public bool IsEmpty => Hash is null; + public static HashNode EmptyNode { get; } = new HashNode(); + + public HashNode() + { + } + + public HashNode(UInt256 hash) + { + this.hash = hash; + } + + internal override void EncodeSpecific(BinaryWriter writer) + { + WriteHash(writer, hash); + } + + internal override void DecodeSpecific(BinaryReader reader) + { + byte[] buffer = reader.ReadVarBytes(UInt256.Length); + hash = buffer.Length switch + { + 0 => null, + UInt256.Length => new UInt256(buffer), + _ => throw new FormatException() + }; + } + } +} diff --git a/src/neo/Cryptography/MPT/LeafNode.cs b/src/neo/Cryptography/MPT/LeafNode.cs new file mode 100644 index 0000000000..cf95a5bc63 --- /dev/null +++ b/src/neo/Cryptography/MPT/LeafNode.cs @@ -0,0 +1,36 @@ +using Neo.IO; +using Neo.SmartContract; +using System; +using System.IO; + +namespace Neo.Cryptography.MPT +{ + public class LeafNode : MPTNode + { + //the max size when store StorageItem + public const int MaxValueLength = 3 + ApplicationEngine.MaxStorageValueSize + sizeof(bool); + + public byte[] Value; + + protected override NodeType Type => NodeType.LeafNode; + + public LeafNode() + { + } + + public LeafNode(ReadOnlySpan value) + { + Value = value.ToArray(); + } + + internal override void EncodeSpecific(BinaryWriter writer) + { + writer.WriteVarBytes(Value); + } + + internal override void DecodeSpecific(BinaryReader reader) + { + Value = reader.ReadVarBytes(MaxValueLength); + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTNode.cs b/src/neo/Cryptography/MPT/MPTNode.cs new file mode 100644 index 0000000000..85955c625f --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTNode.cs @@ -0,0 +1,61 @@ +using Neo.IO; +using Neo.IO.Caching; +using System; +using System.IO; + +namespace Neo.Cryptography.MPT +{ + public abstract class MPTNode + { + private UInt256 hash; + + public virtual UInt256 Hash => hash ??= new UInt256(Crypto.Hash256(Encode())); + protected abstract NodeType Type { get; } + + public void SetDirty() + { + hash = null; + } + + public byte[] Encode() + { + using MemoryStream ms = new MemoryStream(); + using BinaryWriter writer = new BinaryWriter(ms); + + writer.Write((byte)Type); + EncodeSpecific(writer); + writer.Flush(); + + return ms.ToArray(); + } + + internal abstract void EncodeSpecific(BinaryWriter writer); + + public static unsafe MPTNode Decode(ReadOnlySpan data) + { + if (data.IsEmpty) return null; + + fixed (byte* pointer = data) + { + using UnmanagedMemoryStream stream = new UnmanagedMemoryStream(pointer, data.Length); + using BinaryReader reader = new BinaryReader(stream); + + MPTNode n = (MPTNode)ReflectionCache.CreateInstance((NodeType)reader.ReadByte()); + if (n is null) throw new InvalidOperationException(); + + n.DecodeSpecific(reader); + return n; + } + } + + internal abstract void DecodeSpecific(BinaryReader reader); + + protected void WriteHash(BinaryWriter writer, UInt256 hash) + { + if (hash is null) + writer.Write((byte)0); + else + writer.WriteVarBytes(hash.ToArray()); + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTNodeType.cs b/src/neo/Cryptography/MPT/MPTNodeType.cs new file mode 100644 index 0000000000..0194fa9664 --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTNodeType.cs @@ -0,0 +1,16 @@ +using Neo.IO.Caching; + +namespace Neo.Cryptography.MPT +{ + public enum NodeType : byte + { + [ReflectionCache(typeof(BranchNode))] + BranchNode = 0x00, + [ReflectionCache(typeof(ExtensionNode))] + ExtensionNode = 0x01, + [ReflectionCache(typeof(HashNode))] + HashNode = 0x02, + [ReflectionCache(typeof(LeafNode))] + LeafNode = 0x03, + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.Delete.cs b/src/neo/Cryptography/MPT/MPTTrie.Delete.cs new file mode 100644 index 0000000000..4996738b92 --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.Delete.cs @@ -0,0 +1,120 @@ +using Neo.IO; +using System; +using System.Collections.Generic; +using static Neo.Helper; + +namespace Neo.Cryptography.MPT +{ + partial class MPTTrie + { + public bool Delete(TKey key) + { + var path = ToNibbles(key.ToArray()); + if (path.Length == 0) return false; + return TryDelete(ref root, path); + } + + private bool TryDelete(ref MPTNode node, ReadOnlySpan path) + { + switch (node) + { + case LeafNode _: + { + if (path.IsEmpty) + { + node = HashNode.EmptyNode; + return true; + } + return false; + } + case ExtensionNode extensionNode: + { + if (path.StartsWith(extensionNode.Key)) + { + var result = TryDelete(ref extensionNode.Next, path[extensionNode.Key.Length..]); + if (!result) return false; + if (extensionNode.Next is HashNode hashNode && hashNode.IsEmpty) + { + node = extensionNode.Next; + return true; + } + if (extensionNode.Next is ExtensionNode sn) + { + extensionNode.Key = Concat(extensionNode.Key, sn.Key); + extensionNode.Next = sn.Next; + } + extensionNode.SetDirty(); + PutToStore(extensionNode); + return true; + } + return false; + } + case BranchNode branchNode: + { + bool result; + if (path.IsEmpty) + { + result = TryDelete(ref branchNode.Children[BranchNode.ChildCount - 1], path); + } + else + { + result = TryDelete(ref branchNode.Children[path[0]], path[1..]); + } + if (!result) return false; + List childrenIndexes = new List(BranchNode.ChildCount); + for (int i = 0; i < BranchNode.ChildCount; i++) + { + if (branchNode.Children[i] is HashNode hn && hn.IsEmpty) continue; + childrenIndexes.Add((byte)i); + } + if (childrenIndexes.Count > 1) + { + branchNode.SetDirty(); + PutToStore(branchNode); + return true; + } + var lastChildIndex = childrenIndexes[0]; + var lastChild = branchNode.Children[lastChildIndex]; + if (lastChildIndex == BranchNode.ChildCount - 1) + { + node = lastChild; + return true; + } + if (lastChild is HashNode hashNode) + { + lastChild = Resolve(hashNode); + if (lastChild is null) return false; + } + if (lastChild is ExtensionNode exNode) + { + exNode.Key = Concat(childrenIndexes.ToArray(), exNode.Key); + exNode.SetDirty(); + PutToStore(exNode); + node = exNode; + return true; + } + node = new ExtensionNode() + { + Key = childrenIndexes.ToArray(), + Next = lastChild, + }; + PutToStore(node); + return true; + } + case HashNode hashNode: + { + if (hashNode.IsEmpty) + { + return true; + } + var newNode = Resolve(hashNode); + if (newNode is null) return false; + node = newNode; + return TryDelete(ref node, path); + } + default: + return false; + } + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.Find.cs b/src/neo/Cryptography/MPT/MPTTrie.Find.cs new file mode 100644 index 0000000000..5f4153758b --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.Find.cs @@ -0,0 +1,110 @@ +using Neo.IO; +using System; +using System.Collections.Generic; +using System.Linq; +using static Neo.Helper; + +namespace Neo.Cryptography.MPT +{ + partial class MPTTrie + { + private ReadOnlySpan Seek(ref MPTNode node, ReadOnlySpan path, out MPTNode start) + { + switch (node) + { + case LeafNode leafNode: + { + if (path.IsEmpty) + { + start = leafNode; + return ReadOnlySpan.Empty; + } + break; + } + case HashNode hashNode: + { + if (hashNode.IsEmpty) break; + var newNode = Resolve(hashNode); + if (newNode is null) break; + node = newNode; + return Seek(ref node, path, out start); + } + case BranchNode branchNode: + { + if (path.IsEmpty) + { + start = branchNode; + return ReadOnlySpan.Empty; + } + return Concat(path[..1], Seek(ref branchNode.Children[path[0]], path[1..], out start)); + } + case ExtensionNode extensionNode: + { + if (path.IsEmpty) + { + start = extensionNode.Next; + return extensionNode.Key; + } + if (path.StartsWith(extensionNode.Key)) + { + return Concat(extensionNode.Key, Seek(ref extensionNode.Next, path[extensionNode.Key.Length..], out start)); + } + if (extensionNode.Key.AsSpan().StartsWith(path)) + { + start = extensionNode.Next; + return extensionNode.Key; + } + break; + } + } + start = null; + return ReadOnlySpan.Empty; + } + + public IEnumerable<(TKey Key, TValue Value)> Find(ReadOnlySpan prefix) + { + var path = ToNibbles(prefix); + path = Seek(ref root, path, out MPTNode start).ToArray(); + return Travers(start, path) + .Select(p => (FromNibbles(p.Key).AsSerializable(), p.Value.AsSerializable())); + } + + private IEnumerable<(byte[] Key, byte[] Value)> Travers(MPTNode node, byte[] path) + { + if (node is null) yield break; + switch (node) + { + case LeafNode leafNode: + { + yield return (path, (byte[])leafNode.Value.Clone()); + break; + } + case HashNode hashNode: + { + if (hashNode.IsEmpty) break; + var newNode = Resolve(hashNode); + if (newNode is null) break; + node = newNode; + foreach (var item in Travers(node, path)) + yield return item; + break; + } + case BranchNode branchNode: + { + for (int i = 0; i < BranchNode.ChildCount; i++) + { + foreach (var item in Travers(branchNode.Children[i], i == BranchNode.ChildCount - 1 ? path : Concat(path, new byte[] { (byte)i }))) + yield return item; + } + break; + } + case ExtensionNode extensionNode: + { + foreach (var item in Travers(extensionNode.Next, Concat(path, extensionNode.Key))) + yield return item; + break; + } + } + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.Get.cs b/src/neo/Cryptography/MPT/MPTTrie.Get.cs new file mode 100644 index 0000000000..367f2a61ce --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.Get.cs @@ -0,0 +1,61 @@ +using Neo.IO; +using System; + +namespace Neo.Cryptography.MPT +{ + partial class MPTTrie + { + public TValue this[TKey key] + { + get + { + var path = ToNibbles(key.ToArray()); + if (path.Length == 0) return null; + var result = TryGet(ref root, path, out var value); + return result ? value.AsSerializable() : null; + } + } + + private bool TryGet(ref MPTNode node, ReadOnlySpan path, out ReadOnlySpan value) + { + switch (node) + { + case LeafNode leafNode: + { + if (path.IsEmpty) + { + value = leafNode.Value; + return true; + } + break; + } + case HashNode hashNode: + { + if (hashNode.IsEmpty) break; + var newNode = Resolve(hashNode); + if (newNode is null) break; + node = newNode; + return TryGet(ref node, path, out value); + } + case BranchNode branchNode: + { + if (path.IsEmpty) + { + return TryGet(ref branchNode.Children[BranchNode.ChildCount - 1], path, out value); + } + return TryGet(ref branchNode.Children[path[0]], path[1..], out value); + } + case ExtensionNode extensionNode: + { + if (path.StartsWith(extensionNode.Key)) + { + return TryGet(ref extensionNode.Next, path[extensionNode.Key.Length..], out value); + } + break; + } + } + value = default; + return false; + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.Proof.cs b/src/neo/Cryptography/MPT/MPTTrie.Proof.cs new file mode 100644 index 0000000000..7b4be120ab --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.Proof.cs @@ -0,0 +1,72 @@ +using Neo.IO; +using Neo.Persistence; +using System; +using System.Collections.Generic; + +namespace Neo.Cryptography.MPT +{ + partial class MPTTrie + { + public HashSet GetProof(TKey key) + { + var path = ToNibbles(key.ToArray()); + if (path.Length == 0) return null; + HashSet set = new HashSet(ByteArrayEqualityComparer.Default); + if (!GetProof(ref root, path, set)) return null; + return set; + } + + private bool GetProof(ref MPTNode node, ReadOnlySpan path, HashSet set) + { + switch (node) + { + case LeafNode leafNode: + { + if (path.IsEmpty) + { + set.Add(leafNode.Encode()); + return true; + } + break; + } + case HashNode hashNode: + { + if (hashNode.IsEmpty) break; + var newNode = Resolve(hashNode); + if (newNode is null) break; + node = newNode; + return GetProof(ref node, path, set); + } + case BranchNode branchNode: + { + set.Add(branchNode.Encode()); + if (path.IsEmpty) + { + return GetProof(ref branchNode.Children[BranchNode.ChildCount - 1], path, set); + } + return GetProof(ref branchNode.Children[path[0]], path[1..], set); + } + case ExtensionNode extensionNode: + { + if (path.StartsWith(extensionNode.Key)) + { + set.Add(extensionNode.Encode()); + return GetProof(ref extensionNode.Next, path[extensionNode.Key.Length..], set); + } + break; + } + } + return false; + } + + public static TValue VerifyProof(UInt256 root, TKey key, HashSet proof) + { + using var memoryStore = new MemoryStore(); + foreach (byte[] data in proof) + memoryStore.Put(Prefix, Crypto.Hash256(data), data); + using ISnapshot snapshot = memoryStore.GetSnapshot(); + var trie = new MPTTrie(snapshot, root); + return trie[key]; + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.Put.cs b/src/neo/Cryptography/MPT/MPTTrie.Put.cs new file mode 100644 index 0000000000..491213916f --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.Put.cs @@ -0,0 +1,158 @@ +using Neo.IO; +using System; + +namespace Neo.Cryptography.MPT +{ + partial class MPTTrie + { + private static ReadOnlySpan CommonPrefix(ReadOnlySpan a, ReadOnlySpan b) + { + var minLen = a.Length <= b.Length ? a.Length : b.Length; + int i = 0; + if (a.Length != 0 && b.Length != 0) + { + for (i = 0; i < minLen; i++) + { + if (a[i] != b[i]) break; + } + } + return a[..i]; + } + + public bool Put(TKey key, TValue value) + { + var path = ToNibbles(key.ToArray()); + var val = value.ToArray(); + if (path.Length == 0 || path.Length > ExtensionNode.MaxKeyLength) + return false; + if (val.Length > LeafNode.MaxValueLength) + return false; + if (val.Length == 0) + return TryDelete(ref root, path); + var n = new LeafNode(val); + return Put(ref root, path, n); + } + + private bool Put(ref MPTNode node, ReadOnlySpan path, MPTNode val) + { + switch (node) + { + case LeafNode leafNode: + { + if (val is LeafNode v) + { + if (path.IsEmpty) + { + node = v; + PutToStore(node); + return true; + } + var branch = new BranchNode(); + branch.Children[BranchNode.ChildCount - 1] = leafNode; + Put(ref branch.Children[path[0]], path[1..], v); + PutToStore(branch); + node = branch; + return true; + } + return false; + } + case ExtensionNode extensionNode: + { + if (path.StartsWith(extensionNode.Key)) + { + var result = Put(ref extensionNode.Next, path[extensionNode.Key.Length..], val); + if (result) + { + extensionNode.SetDirty(); + PutToStore(extensionNode); + } + return result; + } + var prefix = CommonPrefix(extensionNode.Key, path); + var pathRemain = path[prefix.Length..]; + var keyRemain = extensionNode.Key.AsSpan(prefix.Length); + var son = new BranchNode(); + MPTNode grandSon1 = HashNode.EmptyNode; + MPTNode grandSon2 = HashNode.EmptyNode; + + Put(ref grandSon1, keyRemain[1..], extensionNode.Next); + son.Children[keyRemain[0]] = grandSon1; + + if (pathRemain.IsEmpty) + { + Put(ref grandSon2, pathRemain, val); + son.Children[BranchNode.ChildCount - 1] = grandSon2; + } + else + { + Put(ref grandSon2, pathRemain[1..], val); + son.Children[pathRemain[0]] = grandSon2; + } + PutToStore(son); + if (prefix.Length > 0) + { + var exNode = new ExtensionNode() + { + Key = prefix.ToArray(), + Next = son, + }; + PutToStore(exNode); + node = exNode; + } + else + { + node = son; + } + return true; + } + case BranchNode branchNode: + { + bool result; + if (path.IsEmpty) + { + result = Put(ref branchNode.Children[BranchNode.ChildCount - 1], path, val); + } + else + { + result = Put(ref branchNode.Children[path[0]], path[1..], val); + } + if (result) + { + branchNode.SetDirty(); + PutToStore(branchNode); + } + return result; + } + case HashNode hashNode: + { + MPTNode newNode; + if (hashNode.IsEmpty) + { + if (path.IsEmpty) + { + newNode = val; + } + else + { + newNode = new ExtensionNode() + { + Key = path.ToArray(), + Next = val, + }; + PutToStore(newNode); + } + node = newNode; + if (val is LeafNode) PutToStore(val); + return true; + } + newNode = Resolve(hashNode); + if (newNode is null) return false; + node = newNode; + return Put(ref node, path, val); + } + default: + return false; + } + } + } +} diff --git a/src/neo/Cryptography/MPT/MPTTrie.cs b/src/neo/Cryptography/MPT/MPTTrie.cs new file mode 100644 index 0000000000..df97b54f9c --- /dev/null +++ b/src/neo/Cryptography/MPT/MPTTrie.cs @@ -0,0 +1,58 @@ +using Neo.IO; +using Neo.Persistence; +using System; + +namespace Neo.Cryptography.MPT +{ + public partial class MPTTrie + where TKey : notnull, ISerializable, new() + where TValue : class, ISerializable, new() + { + private const byte Prefix = 0xf0; + + private readonly ISnapshot store; + private MPTNode root; + + public MPTNode Root => root; + + public MPTTrie(ISnapshot store, UInt256 root) + { + this.store = store ?? throw new ArgumentNullException(); + this.root = root is null ? HashNode.EmptyNode : new HashNode(root); + } + + private MPTNode Resolve(HashNode n) + { + var data = store.TryGet(Prefix, n.Hash.ToArray()); + return MPTNode.Decode(data); + } + + private static byte[] ToNibbles(ReadOnlySpan path) + { + var result = new byte[path.Length * 2]; + for (int i = 0; i < path.Length; i++) + { + result[i * 2] = (byte)(path[i] >> 4); + result[i * 2 + 1] = (byte)(path[i] & 0x0F); + } + return result; + } + + private static byte[] FromNibbles(ReadOnlySpan path) + { + if (path.Length % 2 != 0) throw new FormatException($"MPTTrie.FromNibbles invalid path."); + var key = new byte[path.Length / 2]; + for (int i = 0; i < key.Length; i++) + { + key[i] = (byte)(path[i * 2] << 4); + key[i] |= path[i * 2 + 1]; + } + return key; + } + + private void PutToStore(MPTNode node) + { + store.Put(Prefix, node.Hash.ToArray(), node.Encode()); + } + } +} diff --git a/src/neo/Helper.cs b/src/neo/Helper.cs index b8f185e2c9..dafeba7643 100644 --- a/src/neo/Helper.cs +++ b/src/neo/Helper.cs @@ -50,6 +50,14 @@ public static byte[] Concat(params byte[][] buffers) return dst; } + public static byte[] Concat(ReadOnlySpan a, ReadOnlySpan b) + { + byte[] buffer = new byte[a.Length + b.Length]; + a.CopyTo(buffer); + b.CopyTo(buffer.AsSpan(a.Length)); + return buffer; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static int GetBitLength(this BigInteger i) { diff --git a/tests/neo.UnitTests/Cryptography/MPT/UT_MPTNode.cs b/tests/neo.UnitTests/Cryptography/MPT/UT_MPTNode.cs new file mode 100644 index 0000000000..6ddc00f4be --- /dev/null +++ b/tests/neo.UnitTests/Cryptography/MPT/UT_MPTNode.cs @@ -0,0 +1,30 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Neo.Cryptography.MPT; +using System.Text; + +namespace Neo.UnitTests.Cryptography.MPT +{ + [TestClass] + public class UT_MPTNode + { + [TestMethod] + public void TestDecode() + { + var n = new LeafNode + { + Value = Encoding.ASCII.GetBytes("hello") + }; + var code = n.Encode(); + var m = MPTNode.Decode(code); + Assert.IsInstanceOfType(m, n.GetType()); + } + + [TestMethod] + public void TestHashNode() + { + var hn = new HashNode(null); + var data = hn.Encode(); + Assert.AreEqual("0200", data.ToHexString()); + } + } +} diff --git a/tests/neo.UnitTests/Cryptography/MPT/UT_MPTTrie.cs b/tests/neo.UnitTests/Cryptography/MPT/UT_MPTTrie.cs new file mode 100644 index 0000000000..014a10e6d2 --- /dev/null +++ b/tests/neo.UnitTests/Cryptography/MPT/UT_MPTTrie.cs @@ -0,0 +1,317 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Neo.Cryptography.MPT; +using Neo.IO; +using Neo.Persistence; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; + +namespace Neo.UnitTests.Cryptography.MPT +{ + public class TestKey : ISerializable + { + private byte[] key; + + public int Size => key.Length; + + public TestKey() + { + this.key = Array.Empty(); + } + + public TestKey(byte[] key) + { + this.key = key; + } + public void Serialize(BinaryWriter writer) + { + writer.Write(key); + } + + public void Deserialize(BinaryReader reader) + { + key = reader.ReadBytes((int)(reader.BaseStream.Length - reader.BaseStream.Position)); + } + + public override string ToString() + { + return key.ToHexString(); + } + + public static implicit operator TestKey(byte[] key) + { + return new TestKey(key); + } + } + + public class TestValue : ISerializable + { + private byte[] value; + + public int Size => value.Length; + + public TestValue() + { + this.value = Array.Empty(); + } + + public TestValue(byte[] value) + { + this.value = value; + } + + public void Serialize(BinaryWriter writer) + { + writer.Write(value); + } + + public void Deserialize(BinaryReader reader) + { + value = reader.ReadBytes((int)(reader.BaseStream.Length - reader.BaseStream.Position)); + } + + public override string ToString() + { + return value.ToHexString(); + } + + public static implicit operator TestValue(byte[] value) + { + return new TestValue(value); + } + } + + [TestClass] + public class UT_MPTTrie + { + private MPTNode root; + private IStore mptdb; + + private void PutToStore(MPTNode node) + { + mptdb.Put(0xf0, node.Hash.ToArray(), node.Encode()); + } + + [TestInitialize] + public void TestInit() + { + var b = new BranchNode(); + var r = new ExtensionNode { Key = "0a0c".HexToBytes(), Next = b }; + var v1 = new LeafNode { Value = "abcd".HexToBytes() }; + var v2 = new LeafNode { Value = "2222".HexToBytes() }; + var v3 = new LeafNode { Value = Encoding.ASCII.GetBytes("hello") }; + var h1 = new HashNode(v3.Hash); + var l1 = new ExtensionNode { Key = new byte[] { 0x01 }, Next = v1 }; + var l2 = new ExtensionNode { Key = new byte[] { 0x09 }, Next = v2 }; + var l3 = new ExtensionNode { Key = "0e".HexToBytes(), Next = h1 }; + b.Children[0] = l1; + b.Children[9] = l2; + b.Children[10] = l3; + this.root = r; + this.mptdb = new MemoryStore(); + PutToStore(r); + PutToStore(b); + PutToStore(l1); + PutToStore(l2); + PutToStore(l3); + PutToStore(v1); + PutToStore(v2); + PutToStore(v3); + } + + [TestMethod] + public void TestTryGet() + { + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + Assert.AreEqual("abcd", mpt["ac01".HexToBytes()].ToString()); + Assert.AreEqual("2222", mpt["ac99".HexToBytes()].ToString()); + Assert.IsNull(mpt["ab99".HexToBytes()]); + Assert.IsNull(mpt["ac39".HexToBytes()]); + Assert.IsNull(mpt["ac02".HexToBytes()]); + Assert.IsNull(mpt["ac9910".HexToBytes()]); + } + + [TestMethod] + public void TestTryGetResolve() + { + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + Assert.AreEqual(Encoding.ASCII.GetBytes("hello").ToHexString(), mpt["acae".HexToBytes()].ToString()); + } + + [TestMethod] + public void TestTryPut() + { + var store = new MemoryStore(); + var mpt = new MPTTrie(store.GetSnapshot(), null); + var result = mpt.Put("ac01".HexToBytes(), "abcd".HexToBytes()); + Assert.IsTrue(result); + result = mpt.Put("ac99".HexToBytes(), "2222".HexToBytes()); + Assert.IsTrue(result); + result = mpt.Put("acae".HexToBytes(), Encoding.ASCII.GetBytes("hello")); + Assert.IsTrue(result); + Assert.AreEqual(root.Hash.ToString(), mpt.Root.Hash.ToString()); + } + + [TestMethod] + public void TestTryDelete() + { + var b = new BranchNode(); + var r = new ExtensionNode { Key = "0a0c".HexToBytes(), Next = b }; + var v1 = new LeafNode { Value = "abcd".HexToBytes() }; + var v2 = new LeafNode { Value = "2222".HexToBytes() }; + var r1 = new ExtensionNode { Key = "0a0c0001".HexToBytes(), Next = v1 }; + var l1 = new ExtensionNode { Key = new byte[] { 0x01 }, Next = v1 }; + var l2 = new ExtensionNode { Key = new byte[] { 0x09 }, Next = v2 }; + b.Children[0] = l1; + b.Children[9] = l2; + + Assert.AreEqual("0xdea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", r1.Hash.ToString()); + Assert.AreEqual("0x93e8e1ffe2f83dd92fca67330e273bcc811bf64b8f8d9d1b25d5e7366b47d60d", r.Hash.ToString()); + + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + Assert.IsNotNull(mpt["ac99".HexToBytes()]); + bool result = mpt.Delete("ac99".HexToBytes()); + Assert.IsTrue(result); + result = mpt.Delete("acae".HexToBytes()); + Assert.IsTrue(result); + Assert.AreEqual("0xdea3ab46e9461e885ed7091c1e533e0a8030b248d39cbc638962394eaca0fbb3", mpt.Root.Hash.ToString()); + } + + [TestMethod] + public void TestDeleteSameValue() + { + var store = new MemoryStore(); + var snapshot = store.GetSnapshot(); + var mpt = new MPTTrie(snapshot, null); + Assert.IsTrue(mpt.Put("ac01".HexToBytes(), "abcd".HexToBytes())); + Assert.IsTrue(mpt.Put("ac02".HexToBytes(), "abcd".HexToBytes())); + Assert.IsNotNull(mpt["ac01".HexToBytes()]); + Assert.IsNotNull(mpt["ac02".HexToBytes()]); + mpt.Delete("ac01".HexToBytes()); + Assert.IsNotNull(mpt["ac02".HexToBytes()]); + snapshot.Commit(); + + var mpt0 = new MPTTrie(store.GetSnapshot(), mpt.Root.Hash); + Assert.IsNotNull(mpt0["ac02".HexToBytes()]); + } + + [TestMethod] + public void TestBranchNodeRemainValue() + { + var store = new MemoryStore(); + var mpt = new MPTTrie(store.GetSnapshot(), null); + Assert.IsTrue(mpt.Put("ac11".HexToBytes(), "ac11".HexToBytes())); + Assert.IsTrue(mpt.Put("ac22".HexToBytes(), "ac22".HexToBytes())); + Assert.IsTrue(mpt.Put("ac".HexToBytes(), "ac".HexToBytes())); + Assert.IsTrue(mpt.Delete("ac11".HexToBytes())); + mpt.Delete("ac22".HexToBytes()); + Assert.IsNotNull(mpt["ac".HexToBytes()]); + } + + [TestMethod] + public void TestGetProof() + { + var b = new BranchNode(); + var r = new ExtensionNode { Key = "0a0c".HexToBytes(), Next = b }; + var v1 = new LeafNode { Value = "abcd".HexToBytes() }; + var v2 = new LeafNode { Value = "2222".HexToBytes() }; + var v3 = new LeafNode { Value = Encoding.ASCII.GetBytes("hello") }; + var h1 = new HashNode(v3.Hash); + var l1 = new ExtensionNode { Key = new byte[] { 0x01 }, Next = v1 }; + var l2 = new ExtensionNode { Key = new byte[] { 0x09 }, Next = v2 }; + var l3 = new ExtensionNode { Key = "0e".HexToBytes(), Next = h1 }; + b.Children[0] = l1; + b.Children[9] = l2; + b.Children[10] = l3; + + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + Assert.AreEqual(r.Hash.ToString(), mpt.Root.Hash.ToString()); + HashSet proof = mpt.GetProof("ac01".HexToBytes()); + Assert.AreEqual(4, proof.Count); + Assert.IsTrue(proof.Contains(b.Encode())); + Assert.IsTrue(proof.Contains(r.Encode())); + Assert.IsTrue(proof.Contains(l1.Encode())); + Assert.IsTrue(proof.Contains(v1.Encode())); + } + + [TestMethod] + public void TestVerifyProof() + { + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + HashSet proof = mpt.GetProof("ac01".HexToBytes()); + TestValue value = MPTTrie.VerifyProof(root.Hash, "ac01".HexToBytes(), proof); + Assert.IsNotNull(value); + Assert.AreEqual(value.ToString(), "abcd"); + } + + [TestMethod] + public void TestAddLongerKey() + { + var store = new MemoryStore(); + var snapshot = store.GetSnapshot(); + var mpt = new MPTTrie(snapshot, null); + var result = mpt.Put(new byte[] { 0xab }, new byte[] { 0x01 }); + Assert.IsTrue(result); + result = mpt.Put(new byte[] { 0xab, 0xcd }, new byte[] { 0x02 }); + Assert.IsTrue(result); + } + + [TestMethod] + public void TestSplitKey() + { + var store = new MemoryStore(); + var snapshot = store.GetSnapshot(); + var mpt1 = new MPTTrie(snapshot, null); + Assert.IsTrue(mpt1.Put(new byte[] { 0xab, 0xcd }, new byte[] { 0x01 })); + Assert.IsTrue(mpt1.Put(new byte[] { 0xab }, new byte[] { 0x02 })); + HashSet set1 = mpt1.GetProof(new byte[] { 0xab, 0xcd }); + Assert.AreEqual(4, set1.Count); + var mpt2 = new MPTTrie(snapshot, null); + Assert.IsTrue(mpt2.Put(new byte[] { 0xab }, new byte[] { 0x02 })); + Assert.IsTrue(mpt2.Put(new byte[] { 0xab, 0xcd }, new byte[] { 0x01 })); + HashSet set2 = mpt2.GetProof(new byte[] { 0xab, 0xcd }); + Assert.AreEqual(4, set2.Count); + Assert.AreEqual(mpt1.Root.Hash, mpt2.Root.Hash); + } + + [TestMethod] + public void TestFind() + { + var store = new MemoryStore(); + var snapshot = store.GetSnapshot(); + var mpt1 = new MPTTrie(snapshot, null); + var results = mpt1.Find(ReadOnlySpan.Empty).ToArray(); + Assert.AreEqual(0, results.Count()); + var mpt2 = new MPTTrie(snapshot, null); + Assert.IsTrue(mpt2.Put(new byte[] { 0xab, 0xcd, 0xef }, new byte[] { 0x01 })); + Assert.IsTrue(mpt2.Put(new byte[] { 0xab, 0xcd, 0xe1 }, new byte[] { 0x02 })); + Assert.IsTrue(mpt2.Put(new byte[] { 0xab }, new byte[] { 0x03 })); + results = mpt2.Find(ReadOnlySpan.Empty).ToArray(); + Assert.AreEqual(3, results.Count()); + results = mpt2.Find(new byte[] { 0xab }).ToArray(); + Assert.AreEqual(3, results.Count()); + results = mpt2.Find(new byte[] { 0xab, 0xcd }).ToArray(); + Assert.AreEqual(2, results.Count()); + results = mpt2.Find(new byte[] { 0xac }).ToArray(); + Assert.AreEqual(0, results.Count()); + } + + [TestMethod] + public void TestFindLeadNode() + { + // r.Key = 0x0a0c + // b.Key = 0x00 + // l1.Key = 0x01 + var mpt = new MPTTrie(mptdb.GetSnapshot(), root.Hash); + var prefix = new byte[] { 0xac, 0x01 }; // = FromNibbles(path = { 0x0a, 0x0c, 0x00, 0x01 }); + var results = mpt.Find(prefix).ToArray(); + Assert.AreEqual(1, results.Count()); + + prefix = new byte[] { 0xac }; // = FromNibbles(path = { 0x0a, 0x0c }); + results = mpt.Find(prefix).ToArray(); + Assert.AreEqual(3, results.Count()); + } + } +}