Commit

restore support for proofs of exclusion in MerkleSet and add comments describing how it works (#485)
arvidn authored Apr 23, 2024
1 parent adc835d commit cb8ad69
Showing 4 changed files with 109 additions and 43 deletions.
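The central API change is to MerkleSet::generate_proof: instead of Result<Option<Vec<u8>>, SetError> (Some only for items in the set), it now returns Result<(bool, Vec<u8>), SetError>, where true marks a proof-of-inclusion and false a proof-of-exclusion. Below is a minimal usage sketch based only on the signatures visible in this diff; the module path and the helper function name are assumptions, not taken from the crate's documentation:

    use chia_consensus::merkle_tree::MerkleSet; // path assumed from the file layout

    // `leafs` holds already-hashed 32-byte items (assumed non-empty);
    // `absent` is a hash known not to be in the set.
    fn prove(leafs: &mut [[u8; 32]], absent: &[u8; 32]) {
        let tree = MerkleSet::from_leafs(leafs);

        // proof-of-inclusion: the bool is true for an item in the set
        let (included, _proof) = tree.generate_proof(&leafs[0]).expect("generate proof");
        assert!(included);

        // proof-of-exclusion: the same call now also produces a proof for an
        // item that is *not* in the set, flagged with false
        let (included, _proof) = tree.generate_proof(absent).expect("generate proof");
        assert!(!included);
    }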
2 changes: 1 addition & 1 deletion crates/chia-consensus/benches/merkle-set.rs
@@ -41,7 +41,7 @@ fn run(c: &mut Criterion) {
proofs.push(
tree.generate_proof(item)
.expect("failed to generate proof")
.expect("item not found"),
.1,
);
}

17 changes: 9 additions & 8 deletions crates/chia-consensus/fuzz/fuzz_targets/merkle-set.rs
@@ -14,15 +14,16 @@ fuzz_target!(|data: &[u8]| {
let tree = MerkleSet::from_leafs(&mut leafs);

for item in &leafs {
let proof = tree
.generate_proof(item)
.expect("failed to generate proof")
.expect("item not found");
let (true, proof) = tree.generate_proof(item).expect("failed to generate proof") else {
panic!("item is expected to exist");
};
let rebuilt = MerkleSet::from_proof(&proof).expect("failed to parse proof");
assert!(rebuilt
.generate_proof(item)
.expect("failed to validate proof")
.is_some());
assert!(
rebuilt
.generate_proof(item)
.expect("failed to validate proof")
.0
);
assert_eq!(rebuilt.get_root(), tree.get_root());
}
});
98 changes: 71 additions & 27 deletions crates/chia-consensus/src/merkle_tree.rs
@@ -34,6 +34,12 @@ pub enum ArrayTypes {
#[cfg_attr(feature = "py-bindings", pyclass(frozen, name = "MerkleSet"))]
pub struct MerkleSet {
nodes_vec: Vec<(ArrayTypes, [u8; 32])>,
// This is true if the tree was built from a proof. This means the tree may
// include truncated sub-trees and we can't (necessarily) produce new proofs
// as they don't round-trip. The original python implementation had some
// additional complexity to support round-tripping proofs, but we don't use
// it or need it anywhere.
from_proof: bool,
}

const EMPTY: u8 = 0;
@@ -51,7 +57,10 @@ pub struct SetError;

impl MerkleSet {
pub fn from_proof(proof: &[u8]) -> Result<MerkleSet, SetError> {
let mut merkle_tree = MerkleSet::default();
let mut merkle_tree = MerkleSet {
from_proof: true,
..Default::default()
};
merkle_tree.deserialize_proof_impl(proof)?;
Ok(merkle_tree)
}
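As the comment on the from_proof field above explains, a tree built by from_proof may contain truncated sub-trees, so it cannot re-serialize proofs. A sketch of the resulting behaviour, inferred from this diff rather than from separate documentation: rebuilding a tree from a proof reproduces the root hash and still answers the membership query, but generate_proof on the rebuilt tree hands back an empty proof.

    // Hypothetical helper; `item` is a leaf hash present in the original tree,
    // and `MerkleSet` is the type defined in this file.
    fn rebuild_from_proof(leafs: &mut [[u8; 32]], item: &[u8; 32]) {
        let tree = MerkleSet::from_leafs(leafs);
        let (included, proof) = tree.generate_proof(item).expect("generate proof");
        assert!(included);

        // the rebuilt tree has from_proof == true
        let rebuilt = MerkleSet::from_proof(&proof).expect("parse proof");
        assert_eq!(rebuilt.get_root(), tree.get_root());

        // it still reports inclusion correctly, but returns an empty proof
        // instead of re-serializing one
        let (included, proof) = rebuilt.generate_proof(item).expect("generate proof");
        assert!(included);
        assert!(proof.is_empty());
    }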
@@ -196,17 +205,20 @@ impl MerkleSet {
}
}

// searches for "leaf" in the tree, returns the proof if found, otherwise None
pub fn generate_proof(&self, leaf: &[u8; 32]) -> Result<Option<Vec<u8>>, SetError> {
// produces a proof that leaf exists or does not exist in the merkle set.
// returns a bool where true means it's a proof-of-inclusion and false means
// it's a proof-of-exclusion.
pub fn generate_proof(&self, leaf: &[u8; 32]) -> Result<(bool, Vec<u8>), SetError> {
let mut proof = Vec::new();
if self.is_included(self.nodes_vec.len() - 1, leaf, &mut proof, 0)? {
Ok(Some(proof))
let included = self.generate_proof_impl(self.nodes_vec.len() - 1, leaf, &mut proof, 0)?;
if self.from_proof {
Ok((included, vec![]))
} else {
Ok(None)
Ok((included, proof))
}
}

fn is_included(
fn generate_proof_impl(
&self,
current_node_index: usize,
leaf: &[u8; 32],
@@ -237,9 +249,8 @@ impl MerkleSet {
&self.nodes_vec[right as usize].1,
depth,
);

// TODO: It's a bit odd to check for set membership and
// generating a proof in the same function
// if the leaf matches, it's a proof-of-inclusion, otherwise
// it's a proof-of-exclusion
return Ok(&self.nodes_vec[left as usize].1 == leaf
|| &self.nodes_vec[right as usize].1 == leaf);
}
@@ -248,10 +259,10 @@
if get_bit(leaf, depth) {
// bit is 1 so truncate left branch and search right branch
self.other_included(left as usize, proof)?;
self.is_included(right as usize, leaf, proof, depth + 1)
self.generate_proof_impl(right as usize, leaf, proof, depth + 1)
} else {
// bit is 0 so search left and then truncate right branch
let r = self.is_included(left as usize, leaf, proof, depth + 1)?;
let r = self.generate_proof_impl(left as usize, leaf, proof, depth + 1)?;
self.other_included(right as usize, proof)?;
Ok(r)
}
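The traversal above consumes one bit of the 32-byte leaf hash per tree level: the bit at position depth selects the left (0) or right (1) branch. get_bit is defined elsewhere in this crate; a sketch of the presumed semantics (most-significant bit first, which is consistent with the example trees in the tests below) is:

    // assumed implementation, for illustration only
    fn get_bit(hash: &[u8; 32], depth: u8) -> bool {
        (hash[(depth / 8) as usize] & (0x80u8 >> (depth % 8))) != 0
    }

Under this reading, the leafs c0… and c8… used in the tests share bits 0-3 (1, 1, 0, 0) and first differ at bit 4, which is where they end up as sibling leafs.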
@@ -261,6 +272,11 @@
}
}

// this function builds the proof of the subtree we are not traversing.
// Even though this sub-tree does not hold any proof-value, we need it to
// compute and validate the root hash. When computing hashes, we collapse
// tree levels that terminate in a double-leaf node. So, when validating the
// proof, we'll need the full sub-tree in that case, to enable correctly
// computing the root hash.
fn other_included(
&self,
current_node_index: usize,
@@ -290,6 +306,14 @@ impl MerkleSet {
// item's hash. However, when we compute node hashes (and the root hash) we *do*
// collapse sequences of empty nodes. This function re-introduces them into the
// proof.
// When producing proofs-of-exclusion, it's not technically necessary to
// expand these nodes all the way down to the leafs. We just need to hit an
// empty node where the excluded item would have been. However, when computing
// the root hash from a proof, we absolutely need to know whether a truncated
// tree is a "double-mid" or a normal mid node. That affects how the hashes are
// computed. So the current proof format does not support early truncation of
// these kinds of trees. We would need a new code, say "4", to mean truncated
// double node.
fn pad_middles_for_proof_gen(proof: &mut Vec<u8>, left: &[u8; 32], right: &[u8; 32], depth: u8) {
let left_bit = get_bit(left, depth);
let right_bit = get_bit(right, depth);
@@ -337,8 +361,7 @@
included_leaf: [u8; 32],
) -> PyResult<(bool, PyObject)> {
match self.generate_proof(&included_leaf) {
Ok(Some(proof)) => Ok((true, PyBytes::new(py, &proof).into())),
Ok(None) => Ok((false, PyBytes::new(py, &[]).into())),
Ok((included, proof)) => Ok((included, PyBytes::new(py, &proof).into())),
Err(_) => Err(PyValueError::new_err("invalid proof")),
}
}
@@ -366,7 +389,10 @@ impl MerkleSet {
// this is an expanded version of the radix sort function which builds the merkle tree and its hash cache as it goes
pub fn from_leafs(leafs: &mut [[u8; 32]]) -> MerkleSet {
// Leafs are already hashed
let mut merkle_tree = MerkleSet::default();
let mut merkle_tree = MerkleSet {
from_proof: false,
..Default::default()
};

// There's a special case for empty sets
if leafs.is_empty() {
@@ -570,27 +596,34 @@ mod tests {
let root = tree.get_root();
assert_eq!(root, tree.get_merkle_root_old());
assert_eq!(compute_merkle_set_root(leafs), root);

// === proofs-of-inclusion ===
for item in leafs {
let Ok(Some(proof)) = tree.generate_proof(item) else {
let Ok((included, proof)) = tree.generate_proof(item) else {
panic!("failed to generate proof");
};
assert!(included);
let rebuilt = MerkleSet::from_proof(&proof).expect("failed to parse proof");
assert_eq!(rebuilt.get_root(), root);
assert_eq!(rebuilt.generate_proof(item).unwrap(), Some(proof));
let (included, new_proof) = rebuilt.generate_proof(&item).unwrap();
assert!(included);
assert_eq!(new_proof, vec![]);
assert_eq!(rebuilt.get_root(), root);
}

// === proofs-of-exclusion ===
let mut rng = SmallRng::seed_from_u64(42);
// make sure that random hashes are never considered part of the tree
for _ in 0..1000 {
let mut item = [0_u8; 32];
rng.fill(&mut item);
match tree.generate_proof(&item) {
Err(_) => {}
Ok(None) => {}
Ok(Some(_)) => {
panic!("proof for random value");
}
}
let (included, proof) = tree.generate_proof(&item).unwrap();
assert!(!included);
let rebuilt = MerkleSet::from_proof(&proof).expect("failed to parse proof");
let (included, new_proof) = rebuilt.generate_proof(&item).unwrap();
assert!(!included);
assert_eq!(new_proof, vec![]);
assert_eq!(rebuilt.get_root(), root);
}
}

@@ -638,7 +671,7 @@ mod tests {
let root = tree.get_root();
assert_eq!(root, compute_merkle_set_root(&mut random_data));
let index = rng.gen_range(0..random_data.len());
let Ok(Some(proof)) = tree.generate_proof(&random_data[index]) else {
let Ok((true, proof)) = tree.generate_proof(&random_data[index]) else {
panic!("failed to generate proof");
};
let rebuilt = MerkleSet::from_proof(&proof[0..proof.len() - 2]);
@@ -687,6 +720,7 @@
// no collapsing
let a = hex!("c000000000000000000000000000000000000000000000000000000000000000");
let b = hex!("c800000000000000000000000000000000000000000000000000000000000000");
let c = hex!("7000000000000000000000000000000000000000000000000000000000000000");
// these leafs form a tree that looks like this:
// o
// / \
@@ -705,11 +739,21 @@
// / \
// a b
let tree = MerkleSet::from_leafs(&mut [a, b]);
let proof = tree.generate_proof(&b).unwrap().unwrap();
let (true, proof) = tree.generate_proof(&b).unwrap() else {
panic!("failed to generate proof");
};
assert_eq!(hex::encode(proof), "0200020002020201c00000000000000000000000000000000000000000000000000000000000000001c8000000000000000000000000000000000000000000000000000000000000000000");

// in fact, the proof for a looks the same, since a and b are siblings
let proof = tree.generate_proof(&b).unwrap().unwrap();
let (true, proof) = tree.generate_proof(&b).unwrap() else {
panic!("failed to generate proof");
};
assert_eq!(hex::encode(proof), "0200020002020201c00000000000000000000000000000000000000000000000000000000000000001c8000000000000000000000000000000000000000000000000000000000000000000");

// proofs of exclusion must also be complete
let (false, proof) = tree.generate_proof(&c).unwrap() else {
panic!("failed to generate proof");
};
assert_eq!(hex::encode(proof), "0200020002020201c00000000000000000000000000000000000000000000000000000000000000001c8000000000000000000000000000000000000000000000000000000000000000000");
}
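For reference, here is one way to read the serialized proof asserted above. It assumes the node codes implied elsewhere in this diff (0 = empty, 1 = terminal followed by its 32-byte hash, 2 = middle node) and pre-order serialization with the left child first; the breakdown is an interpretation of the bytes, not taken from the crate's documentation:

    02              middle at depth 0 (the root); bit 0 of both a and b is 1
    00                left child: empty
    02                right child: middle at depth 1 (bit 1 is 1 for both)
    00                  left: empty
    02                  right: middle at depth 2 (bit 2 is 0 for both)
    02                    left: middle at depth 3 (bit 3 is 0 for both)
    02                      left: middle at depth 4 (bit 4 differs: a = 0, b = 1)
    01 c0…00                  left: terminal, leaf a
    01 c8…00                  right: terminal, leaf b
    00                      right child of the depth-3 middle: empty
    00                    right child of the depth-2 middle: empty

Since a and b are siblings, the same byte string serves as a's inclusion proof, b's inclusion proof, and c's exclusion proof: c = 70… has bit 0 = 0, so the lookup immediately hits the empty left child of the root, but the full right sub-tree is still required to recompute the root hash.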

35 changes: 28 additions & 7 deletions tests/test_merkle_set.py
@@ -22,12 +22,18 @@ def check_proof(
*,
root: bytes32,
item: bytes32,
expect_included: bool = True,
) -> None:
proof_tree = deserialize(proof)
assert proof_tree.get_root() == root
included, proof2 = proof_tree.is_included_already_hashed(item)
assert included
assert proof == proof2
included, junk = proof_tree.is_included_already_hashed(item)
assert included == expect_included

# the rust implementation does not round-trip proofs of exclusion.
# doing so requires additional complexity (and cost).
# rust deliberately generates an empty proof from a tree that was itself
# built from a proof
assert junk == b"" or junk == proof


def check_tree(leafs: List[bytes32]) -> None:
@@ -43,11 +49,26 @@ def check_tree(leafs: List[bytes32]) -> None:
ru_included, ru_proof = ru_tree.is_included_already_hashed(item)
assert ru_included
assert py_proof == ru_proof
proof = ru_proof

check_proof(proof, py_deserialize_proof, root=root, item=item)
check_proof(proof, ru_deserialize_proof, root=root, item=item)

check_proof(py_proof, py_deserialize_proof, root=root, item=item)
check_proof(ru_proof, py_deserialize_proof, root=root, item=item)
check_proof(py_proof, ru_deserialize_proof, root=root, item=item)
check_proof(ru_proof, ru_deserialize_proof, root=root, item=item)
for i in range(256):
item = bytes32([i] + [2] * 31)
py_included, py_proof = py_tree.is_included_already_hashed(item)
assert not py_included
ru_included, ru_proof = ru_tree.is_included_already_hashed(item)
assert not ru_included
assert py_proof == ru_proof
proof = ru_proof

check_proof(
proof, py_deserialize_proof, root=root, item=item, expect_included=False
)
check_proof(
proof, ru_deserialize_proof, root=root, item=item, expect_included=False
)


def h(b: str) -> bytes32:
