Skip to content

Commit

Permalink
chore: wrap leaf index in struct (#5179)
Browse files Browse the repository at this point in the history
Description
---
Wrap leaf index in a struct

Motivation and Context
---
To prevent accidents that leaf index is used as a node index and vice versa.
  • Loading branch information
Cifko authored Feb 14, 2023
1 parent d2717a1 commit fe49d6e
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 59 deletions.
25 changes: 14 additions & 11 deletions base_layer/mmr/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ use crate::{error::MerkleMountainRangeError, Hash};

const ALL_ONES: usize = std::usize::MAX;

#[derive(Copy, Clone)]
pub struct LeafIndex(pub usize);

/// Returns the MMR node index derived from the leaf index.
pub fn node_index(leaf_index: usize) -> usize {
if leaf_index == 0 {
pub fn node_index(leaf_index: LeafIndex) -> usize {
if leaf_index.0 == 0 {
return 0;
}
2 * leaf_index - leaf_index.count_ones() as usize
2 * leaf_index.0 - leaf_index.0.count_ones() as usize
}

/// Returns the leaf index derived from the MMR node index.
Expand Down Expand Up @@ -214,14 +217,14 @@ mod test {

#[test]
fn leaf_to_node_indices() {
assert_eq!(node_index(0), 0);
assert_eq!(node_index(1), 1);
assert_eq!(node_index(2), 3);
assert_eq!(node_index(3), 4);
assert_eq!(node_index(5), 8);
assert_eq!(node_index(6), 10);
assert_eq!(node_index(7), 11);
assert_eq!(node_index(8), 15);
assert_eq!(node_index(LeafIndex(0)), 0);
assert_eq!(node_index(LeafIndex(1)), 1);
assert_eq!(node_index(LeafIndex(2)), 3);
assert_eq!(node_index(LeafIndex(3)), 4);
assert_eq!(node_index(LeafIndex(5)), 8);
assert_eq!(node_index(LeafIndex(6)), 10);
assert_eq!(node_index(LeafIndex(7)), 11);
assert_eq!(node_index(LeafIndex(8)), 15);
}

#[test]
Expand Down
15 changes: 8 additions & 7 deletions base_layer/mmr/src/merkle_mountain_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use crate::{
leaf_index,
node_index,
peak_map_height,
LeafIndex,
},
error::MerkleMountainRangeError,
pruned_hashset::PrunedHashSet,
Expand Down Expand Up @@ -111,21 +112,21 @@ where
}

/// This function returns the hash of the leaf index provided, indexed from 0
pub fn get_leaf_hash(&self, leaf_index: usize) -> Result<Option<Hash>, MerkleMountainRangeError> {
pub fn get_leaf_hash(&self, leaf_index: LeafIndex) -> Result<Option<Hash>, MerkleMountainRangeError> {
self.get_node_hash(node_index(leaf_index))
}

/// Returns a set of leaf hashes from the MMR.
pub fn get_leaf_hashes(&self, leaf_index: usize, count: usize) -> Result<Vec<Hash>, MerkleMountainRangeError> {
pub fn get_leaf_hashes(&self, leaf_index: LeafIndex, count: usize) -> Result<Vec<Hash>, MerkleMountainRangeError> {
let leaf_count = self.get_leaf_count()?;
if leaf_index >= leaf_count {
if leaf_index.0 >= leaf_count {
return Ok(Vec::new());
}
let count = max(1, count);
let last_leaf_index = min(leaf_index + count - 1, leaf_count);
let mut leaf_hashes = Vec::with_capacity(last_leaf_index - leaf_index + 1);
for leaf_index in leaf_index..=last_leaf_index {
if let Some(hash) = self.get_leaf_hash(leaf_index)? {
let last_leaf_index = min(leaf_index.0 + count - 1, leaf_count);
let mut leaf_hashes = Vec::with_capacity(last_leaf_index - leaf_index.0 + 1);
for leaf_index in leaf_index.0..=last_leaf_index {
if let Some(hash) = self.get_leaf_hash(LeafIndex(leaf_index))? {
leaf_hashes.push(hash);
}
}
Expand Down
6 changes: 3 additions & 3 deletions base_layer/mmr/src/merkle_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use thiserror::Error;

use crate::{
backend::ArrayLike,
common::{family, family_branch, find_peaks, hash_together, is_leaf, is_left_sibling, node_index},
common::{family, family_branch, find_peaks, hash_together, is_leaf, is_left_sibling, node_index, LeafIndex},
error::MerkleMountainRangeError,
serde_support,
Hash,
Expand Down Expand Up @@ -78,7 +78,7 @@ impl MerkleProof {
/// See [MerkleProof::for_node] for more details on how the proof is constructed.
pub fn for_leaf_node<D, B>(
mmr: &MerkleMountainRange<D, B>,
leaf_index: usize,
leaf_index: LeafIndex,
) -> Result<MerkleProof, MerkleProofError>
where
D: Digest + DomainDigest,
Expand Down Expand Up @@ -159,7 +159,7 @@ impl MerkleProof {
&self,
root: &HashSlice,
hash: &HashSlice,
leaf_index: usize,
leaf_index: LeafIndex,
) -> Result<(), MerkleProofError> {
let pos = node_index(leaf_index);
self.verify::<D>(root, hash, pos)
Expand Down
3 changes: 2 additions & 1 deletion base_layer/mmr/src/mmr_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use tari_common::DomainDigest;

use crate::{
backend::ArrayLike,
common::LeafIndex,
error::MerkleMountainRangeError,
functions::{prune_mutable_mmr, PrunedMutableMmr},
merkle_checkpoint::MerkleCheckPoint,
Expand Down Expand Up @@ -209,7 +210,7 @@ where

/// Returns the hash of the leaf index provided, as well as its deletion status. The node has been marked for
/// deletion if the boolean value is true.
pub fn fetch_mmr_node(&self, leaf_index: u32) -> Result<(Option<Hash>, bool), MerkleMountainRangeError> {
pub fn fetch_mmr_node(&self, leaf_index: LeafIndex) -> Result<(Option<Hash>, bool), MerkleMountainRangeError> {
let (base_hash, base_deleted) = self.base_mmr.get_leaf_status(leaf_index)?;
let (curr_hash, curr_deleted) = self.curr_mmr.get_leaf_status(leaf_index)?;
if let Some(base_hash) = base_hash {
Expand Down
24 changes: 12 additions & 12 deletions base_layer/mmr/src/mutable_mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tari_common::DomainDigest;

use crate::{
backend::ArrayLike,
common::node_index,
common::{node_index, LeafIndex},
error::MerkleMountainRangeError,
mutable_mmr_leaf_nodes::MutableMmrLeafNodes,
Hash,
Expand Down Expand Up @@ -96,18 +96,18 @@ where

/// This function returns the hash of the leaf index provided, indexed from 0. If the hash does not exist, or if it
/// has been marked for deletion, `None` is returned.
pub fn get_leaf_hash(&self, leaf_index: u32) -> Result<Option<Hash>, MerkleMountainRangeError> {
if self.deleted.contains(leaf_index) {
pub fn get_leaf_hash(&self, leaf_index: LeafIndex) -> Result<Option<Hash>, MerkleMountainRangeError> {
if self.deleted.contains(leaf_index.0 as u32) {
return Ok(None);
}
self.mmr.get_node_hash(node_index(leaf_index as usize))
self.mmr.get_node_hash(node_index(leaf_index))
}

/// Returns the hash of the leaf index provided, as well as its deletion status. The node has been marked for
/// deletion if the boolean value is true.
pub fn get_leaf_status(&self, leaf_index: u32) -> Result<(Option<Hash>, bool), MerkleMountainRangeError> {
let hash = self.mmr.get_node_hash(node_index(leaf_index as usize))?;
let deleted = self.deleted.contains(leaf_index);
pub fn get_leaf_status(&self, leaf_index: LeafIndex) -> Result<(Option<Hash>, bool), MerkleMountainRangeError> {
let hash = self.mmr.get_node_hash(node_index(leaf_index))?;
let deleted = self.deleted.contains(leaf_index.0 as u32);
Ok((hash, deleted))
}

Expand Down Expand Up @@ -205,14 +205,14 @@ where
}

// Returns a bitmap with only the deleted nodes for the specified region in the MMR.
fn get_sub_bitmap(&self, leaf_index: usize, count: usize) -> Result<Bitmap, MerkleMountainRangeError> {
fn get_sub_bitmap(&self, leaf_index: LeafIndex, count: usize) -> Result<Bitmap, MerkleMountainRangeError> {
let mut deleted = self.deleted.clone();
if leaf_index > 0 {
deleted.remove_range_closed(0..u32::try_from(leaf_index - 1).unwrap())
if leaf_index.0 > 0 {
deleted.remove_range_closed(0..u32::try_from(leaf_index.0 - 1).unwrap())
}
let leaf_count = self.mmr.get_leaf_count()?;
if leaf_count > 1 {
let last_index = leaf_index + count - 1;
let last_index = leaf_index.0 + count - 1;
if last_index < leaf_count - 1 {
deleted.remove_range_closed(u32::try_from(last_index + 1).unwrap()..u32::try_from(leaf_count).unwrap());
}
Expand All @@ -223,7 +223,7 @@ where
/// Returns the state of the MMR that consists of the leaf hashes and the deleted nodes.
pub fn to_leaf_nodes(
&self,
leaf_index: usize,
leaf_index: LeafIndex,
count: usize,
) -> Result<MutableMmrLeafNodes, MerkleMountainRangeError> {
Ok(MutableMmrLeafNodes {
Expand Down
17 changes: 9 additions & 8 deletions base_layer/mmr/tests/merkle_mountain_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
mod support;

use support::{combine_hashes, create_mmr, int_to_hash};
use tari_mmr::common::LeafIndex;

use crate::support::{MmrTestHasherBlake256, TestMmr};

Expand Down Expand Up @@ -119,7 +120,7 @@ fn validate() {
#[test]
fn restore_from_leaf_hashes() {
let mut mmr = TestMmr::new(Vec::default());
let leaf_hashes = mmr.get_leaf_hashes(0, 1).unwrap();
let leaf_hashes = mmr.get_leaf_hashes(LeafIndex(0), 1).unwrap();
assert_eq!(leaf_hashes.len(), 0);

let h0 = int_to_hash(0);
Expand All @@ -134,8 +135,8 @@ fn restore_from_leaf_hashes() {

// Construct MMR state from multiple leaf hash queries.
let leaf_count = mmr.get_leaf_count().unwrap();
let mut leaf_hashes = mmr.get_leaf_hashes(0, 2).unwrap();
leaf_hashes.append(&mut mmr.get_leaf_hashes(2, leaf_count - 2).unwrap());
let mut leaf_hashes = mmr.get_leaf_hashes(LeafIndex(0), 2).unwrap();
leaf_hashes.append(&mut mmr.get_leaf_hashes(LeafIndex(2), leaf_count - 2).unwrap());
assert_eq!(leaf_hashes.len(), 4);
assert_eq!(leaf_hashes[0], h0);
assert_eq!(leaf_hashes[1], h1);
Expand All @@ -148,11 +149,11 @@ fn restore_from_leaf_hashes() {

assert!(mmr.assign(leaf_hashes).is_ok());
assert_eq!(mmr.len(), Ok(7));
assert_eq!(mmr.get_leaf_hash(0), Ok(Some(h0)));
assert_eq!(mmr.get_leaf_hash(1), Ok(Some(h1)));
assert_eq!(mmr.get_leaf_hash(2), Ok(Some(h2)));
assert_eq!(mmr.get_leaf_hash(3), Ok(Some(h3)));
assert_eq!(mmr.get_leaf_hash(4), Ok(None));
assert_eq!(mmr.get_leaf_hash(LeafIndex(0)), Ok(Some(h0)));
assert_eq!(mmr.get_leaf_hash(LeafIndex(1)), Ok(Some(h1)));
assert_eq!(mmr.get_leaf_hash(LeafIndex(2)), Ok(Some(h2)));
assert_eq!(mmr.get_leaf_hash(LeafIndex(3)), Ok(Some(h3)));
assert_eq!(mmr.get_leaf_hash(LeafIndex(4)), Ok(None));
}

#[test]
Expand Down
18 changes: 9 additions & 9 deletions base_layer/mmr/tests/merkle_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod support;

use support::int_to_hash;
use tari_mmr::{
common::{is_leaf, node_index},
common::{is_leaf, node_index, LeafIndex},
MerkleProof,
MerkleProofError,
};
Expand Down Expand Up @@ -69,7 +69,7 @@ fn med_mmr() {
let mmr = create_mmr(size);
let root = mmr.get_merkle_root().unwrap();
let i = 499;
let pos = node_index(i);
let pos = node_index(LeafIndex(i));
let hash = int_to_hash(i);
let proof = MerkleProof::for_node(&mmr, pos).unwrap();
assert!(proof.verify::<MmrTestHasherBlake256>(&root, &hash, pos).is_ok());
Expand All @@ -78,10 +78,10 @@ fn med_mmr() {
#[test]
fn a_big_proof() {
let mmr = create_mmr(100_000);
let leaf_pos = 28_543;
let leaf_pos = LeafIndex(28_543);
let mmr_index = node_index(leaf_pos);
let root = mmr.get_merkle_root().unwrap();
let hash = int_to_hash(leaf_pos);
let hash = int_to_hash(leaf_pos.0);
let proof = MerkleProof::for_node(&mmr, mmr_index).unwrap();
assert!(proof.verify::<MmrTestHasherBlake256>(&root, &hash, mmr_index).is_ok())
}
Expand All @@ -90,8 +90,8 @@ fn a_big_proof() {
fn for_leaf_node() {
let mmr = create_mmr(100);
let root = mmr.get_merkle_root().unwrap();
let leaf_pos = 28;
let hash = int_to_hash(leaf_pos);
let leaf_pos = LeafIndex(28);
let hash = int_to_hash(leaf_pos.0);
let proof = MerkleProof::for_leaf_node(&mmr, leaf_pos).unwrap();
assert!(proof
.verify_leaf::<MmrTestHasherBlake256>(&root, &hash, leaf_pos)
Expand All @@ -104,7 +104,7 @@ const BINCODE_PROOF: &str = "080000000000000002000000000000002000000000000000834
#[test]
fn serialisation() {
let mmr = create_mmr(5);
let proof = MerkleProof::for_leaf_node(&mmr, 3).unwrap();
let proof = MerkleProof::for_leaf_node(&mmr, LeafIndex(3)).unwrap();
let json_proof = serde_json::to_string(&proof).unwrap();
assert_eq!(&json_proof, JSON_PROOF);

Expand All @@ -123,14 +123,14 @@ fn deserialization() {
let proof: MerkleProof = serde_json::from_str(JSON_PROOF).unwrap();
println!("{}", proof);
assert!(proof
.verify_leaf::<MmrTestHasherBlake256>(&root, &int_to_hash(3), 3)
.verify_leaf::<MmrTestHasherBlake256>(&root, &int_to_hash(3), LeafIndex(3))
.is_ok());

// Verify bincode-derived proof
let bin_proof = hex::from_hex(BINCODE_PROOF).unwrap();
let proof: MerkleProof = bincode::deserialize(&bin_proof).unwrap();
println!("{}", proof);
assert!(proof
.verify_leaf::<MmrTestHasherBlake256>(&root, &int_to_hash(3), 3)
.verify_leaf::<MmrTestHasherBlake256>(&root, &int_to_hash(3), LeafIndex(3))
.is_ok());
}
12 changes: 6 additions & 6 deletions base_layer/mmr/tests/mutable_mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
mod support;

use croaring::Bitmap;
use tari_mmr::{Hash, HashSlice};
use tari_mmr::{common::LeafIndex, Hash, HashSlice};
use tari_utilities::hex::Hex;

use crate::support::{create_mmr, int_to_hash, MmrTestHasherBlake256, MutableTestMmr};
Expand Down Expand Up @@ -165,12 +165,12 @@ fn restore_from_leaf_nodes() {

// Request state of MMR with single call
let leaf_count = mmr.get_leaf_count();
let mmr_state1 = mmr.to_leaf_nodes(0, leaf_count).unwrap();
let mmr_state1 = mmr.to_leaf_nodes(LeafIndex(0), leaf_count).unwrap();

// Request state of MMR with multiple calls
let mut mmr_state2 = mmr.to_leaf_nodes(0, 3).unwrap();
mmr_state2.combine(mmr.to_leaf_nodes(3, 3).unwrap());
mmr_state2.combine(mmr.to_leaf_nodes(6, leaf_count - 6).unwrap());
let mut mmr_state2 = mmr.to_leaf_nodes(LeafIndex(0), 3).unwrap();
mmr_state2.combine(mmr.to_leaf_nodes(LeafIndex(3), 3).unwrap());
mmr_state2.combine(mmr.to_leaf_nodes(LeafIndex(6), leaf_count - 6).unwrap());
assert_eq!(mmr_state1, mmr_state2);

// Change the state more before the restore
Expand All @@ -182,6 +182,6 @@ fn restore_from_leaf_nodes() {
// Restore from compact state
assert!(mmr.assign(mmr_state1).is_ok());
assert_eq!(mmr.get_merkle_root(), mmr_root);
let restored_mmr_state = mmr.to_leaf_nodes(0, mmr.get_leaf_count()).unwrap();
let restored_mmr_state = mmr.to_leaf_nodes(LeafIndex(0), mmr.get_leaf_count()).unwrap();
assert_eq!(restored_mmr_state, mmr_state2);
}
5 changes: 3 additions & 2 deletions base_layer/mmr/tests/pruned_mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use rand::{
};
use support::{create_mmr, create_mutable_mmr, int_to_hash};
use tari_mmr::{
common::LeafIndex,
functions::{calculate_mmr_root, calculate_pruned_mmr_root, prune_mmr},
Hash,
};
Expand Down Expand Up @@ -60,8 +61,8 @@ fn pruned_mmrs() {
assert!(pruned.push(int_to_hash(*size + 1)).is_ok());
assert_eq!(pruned.get_merkle_root(), mmr2.get_merkle_root());
// But you can only get recent hashes
assert_eq!(pruned.get_leaf_hash(*size / 2), Ok(None));
assert_eq!(pruned.get_leaf_hash(*size), Ok(Some(new_hash)))
assert_eq!(pruned.get_leaf_hash(LeafIndex(*size / 2)), Ok(None));
assert_eq!(pruned.get_leaf_hash(LeafIndex(*size)), Ok(Some(new_hash)))
}
}

Expand Down

0 comments on commit fe49d6e

Please sign in to comment.