Skip to content

Commit

Permalink
feat: return error codes in binary merkle tree
Browse files Browse the repository at this point in the history
  • Loading branch information
rach-id committed Jan 12, 2024
1 parent 20b4d30 commit 82d67c3
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/Blobstream.sol
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ contract Blobstream is IDAOracle, Initializable, UUPSUpgradeable, OwnableUpgrade
bytes32 root = state_dataRootTupleRoots[_tupleRootNonce];

// Verify the proof.
bool isProofValid = BinaryMerkleTree.verify(root, _proof, abi.encode(_tuple));
(bool isProofValid,) = BinaryMerkleTree.verify(root, _proof, abi.encode(_tuple));

return isProofValid;
}
Expand Down
80 changes: 59 additions & 21 deletions src/lib/tree/binary/BinaryMerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,50 @@ import "./BinaryMerkleProof.sol";

/// @title Binary Merkle Tree.
library BinaryMerkleTree {
/////////////////
// Error codes //
/////////////////

enum ErrorCodes {
NoError,
/// @notice The provided side nodes count is invalid for the proof.
InvalidNumberOfSideNodes,
/// @notice The provided proof key is not part of the tree.
KeyNotInTree,
/// @notice The number of leaves in the binary merkle proof is not divisible by 4.
InvalidNumberOfLeavesInProof,
/// @notice The proof contains unexpected side nodes.
UnexpectedInnerHashes,
/// @notice The proof verification expected at least one inner hash.
ExpectedAtLeastOneInnerHash
}

///////////////
// Functions //
///////////////

/// @notice Verify if element exists in Merkle tree, given data, proof, and root.
/// @param root The root of the tree in which verify the given leaf.
/// @param proof Binary Merkle proof for the leaf.
/// @param data The data of the leaf to verify.
/// @return `true` is proof is valid, `false` otherwise.
/// @dev proof.numLeaves is necessary to determine height of subtree containing the data to prove.
function verify(bytes32 root, BinaryMerkleProof memory proof, bytes memory data) internal pure returns (bool) {
function verify(bytes32 root, BinaryMerkleProof memory proof, bytes memory data)
internal
pure
returns (bool, ErrorCodes)
{
// Check proof is correct length for the key it is proving
if (proof.numLeaves <= 1) {
if (proof.sideNodes.length != 0) {
return false;
}
} else if (proof.sideNodes.length != pathLengthFromKey(proof.key, proof.numLeaves)) {
return false;
if (
proof.numLeaves <= 1 && proof.sideNodes.length != 0
|| proof.sideNodes.length != pathLengthFromKey(proof.key, proof.numLeaves)
) {
return (false, ErrorCodes.InvalidNumberOfSideNodes);
}

// Check key is in tree
if (proof.key >= proof.numLeaves) {
return false;
return (false, ErrorCodes.KeyNotInTree);
}

// A sibling at height 1 is created by getting the hash of the data to prove.
Expand All @@ -36,15 +61,19 @@ library BinaryMerkleTree {
// If so, just verify hash(data) is root
if (proof.sideNodes.length == 0) {
if (proof.numLeaves == 1) {
return (root == digest);
return (root == digest, ErrorCodes.NoError);
} else {
return false;
return (false, ErrorCodes.NoError);
}
}

bytes32 computedHash = computeRootHash(proof.key, proof.numLeaves, digest, proof.sideNodes);
(bytes32 computedHash, ErrorCodes error) = computeRootHash(proof.key, proof.numLeaves, digest, proof.sideNodes);

if (error != ErrorCodes.NoError) {
return (false, error);
}

return (computedHash == root);
return (computedHash == root, ErrorCodes.NoError);
}

/// @notice Use the leafHash and innerHashes to get the root merkle hash.
Expand All @@ -53,28 +82,37 @@ library BinaryMerkleTree {
function computeRootHash(uint256 key, uint256 numLeaves, bytes32 leafHash, bytes32[] memory sideNodes)
private
pure
returns (bytes32)
returns (bytes32, ErrorCodes)
{
if (numLeaves == 0) {
revert("cannot call computeRootHash with 0 number of leaves");
return (leafHash, ErrorCodes.InvalidNumberOfLeavesInProof);
}
if (numLeaves == 1) {
if (sideNodes.length != 0) {
revert("unexpected inner hashes");
return (leafHash, ErrorCodes.UnexpectedInnerHashes);
}
return leafHash;
return (leafHash, ErrorCodes.NoError);
}
if (sideNodes.length == 0) {
revert("expected at least one inner hash");
return (leafHash, ErrorCodes.ExpectedAtLeastOneInnerHash);
}
uint256 numLeft = _getSplitPoint(numLeaves);
bytes32[] memory sideNodesLeft = slice(sideNodes, 0, sideNodes.length - 1);
ErrorCodes error;
if (key < numLeft) {
bytes32 leftHash = computeRootHash(key, numLeft, leafHash, sideNodesLeft);
return nodeDigest(leftHash, sideNodes[sideNodes.length - 1]);
bytes32 leftHash;
(leftHash, error) = computeRootHash(key, numLeft, leafHash, sideNodesLeft);
if (error != ErrorCodes.NoError) {
return (leafHash, error);
}
return (nodeDigest(leftHash, sideNodes[sideNodes.length - 1]), ErrorCodes.NoError);
}
bytes32 rightHash;
(rightHash, error) = computeRootHash(key - numLeft, numLeaves - numLeft, leafHash, sideNodesLeft);
if (error != ErrorCodes.NoError) {
return (leafHash, error);
}
bytes32 rightHash = computeRootHash(key - numLeft, numLeaves - numLeft, leafHash, sideNodesLeft);
return nodeDigest(sideNodes[sideNodes.length - 1], rightHash);
return (nodeDigest(sideNodes[sideNodes.length - 1], rightHash), ErrorCodes.NoError);
}

/// @notice creates a slice of bytes32 from the data slice of bytes32 containing the elements
Expand Down
51 changes: 32 additions & 19 deletions src/lib/tree/binary/test/BinaryMerkleTree.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 0;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data;
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid,) = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

Expand All @@ -62,7 +62,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 1;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data;
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -73,7 +74,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 1;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"deadbeef";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -84,7 +86,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 1;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"01";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -99,7 +102,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"01";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -114,7 +118,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"02";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -129,7 +134,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"03";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -144,7 +150,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"07";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -159,7 +166,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"08";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -180,7 +188,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(isValid);
}

Expand All @@ -196,7 +205,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(!isValid);
}

Expand All @@ -212,7 +222,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(!isValid);
}

Expand All @@ -228,7 +239,7 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 200;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid,) = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

Expand All @@ -244,7 +255,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(!isValid);
}

Expand All @@ -260,7 +272,8 @@ contract BinaryMerkleProofTest is DSTest {
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
// correct data: 01
bytes memory data = bytes(hex"012345");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
(bool isValid, BinaryMerkleTree.ErrorCodes error) = BinaryMerkleTree.verify(root, proof, data);
assertEq(uint256(BinaryMerkleTree.ErrorCodes.NoError), uint256(error));
assertTrue(!isValid);
}

Expand All @@ -284,8 +297,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 3;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assert(!isValid);
(bool isValid,) = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testConsecutiveKeyAndNumberOfLeaves() external {
Expand All @@ -295,8 +308,8 @@ contract BinaryMerkleProofTest is DSTest {
uint256 numLeaves = 7;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assert(!isValid);
(bool isValid,) = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testInvalidSliceBeginEnd() public {
Expand Down
6 changes: 4 additions & 2 deletions src/lib/verifier/DAVerifier.sol
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ library DAVerifier {
}

bytes memory rowRoot = abi.encodePacked(_rowRoot.min.toBytes(), _rowRoot.max.toBytes(), _rowRoot.digest);
if (!BinaryMerkleTree.verify(_root, _rowProof, rowRoot)) {
(bool valid,) = BinaryMerkleTree.verify(_root, _rowProof, rowRoot);
if (!valid) {
return (false, ErrorCodes.InvalidRowToDataRootProof);
}

Expand Down Expand Up @@ -190,7 +191,8 @@ library DAVerifier {
for (uint256 i = 0; i < _rowProofs.length; i++) {
bytes memory rowRoot =
abi.encodePacked(_rowRoots[i].min.toBytes(), _rowRoots[i].max.toBytes(), _rowRoots[i].digest);
if (!BinaryMerkleTree.verify(_root, _rowProofs[i], rowRoot)) {
(bool valid,) = BinaryMerkleTree.verify(_root, _rowProofs[i], rowRoot);
if (!valid) {
return (false, ErrorCodes.InvalidRowsToDataRootProof);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/lib/verifier/test/DAVerifier.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ contract DAVerifierTest is DSTest {
}

function testComputeSquareSizeFromRowProof() public {
bool validMerkleProof =
(bool validMerkleProof, BinaryMerkleTree.ErrorCodes error) =
BinaryMerkleTree.verify(fixture.dataRoot(), fixture.getRowRootToDataRootProof(), fixture.firstRowRoot());
assertEq(uint256(error), uint256(BinaryMerkleTree.ErrorCodes.NoError));
assertTrue(validMerkleProof);

// check that the computed square size is correct
Expand Down

0 comments on commit 82d67c3

Please sign in to comment.