Skip to content

Commit

Permalink
Add Indexed Merkle Tree (#281)
Browse files Browse the repository at this point in the history
* add basic indexed merkle tree.

* membership check for nullifier tree and cleanup.

* test works for a simple case.

* fix test.

* fix.

* Add comment.
  • Loading branch information
suyash67 committed Mar 29, 2023
1 parent 81f484c commit 8590e15
Show file tree
Hide file tree
Showing 18 changed files with 827 additions and 26 deletions.
15 changes: 12 additions & 3 deletions cpp/src/barretenberg/stdlib/merkle_tree/hash.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@ namespace plonk {
namespace stdlib {
namespace merkle_tree {

inline barretenberg::fr compress_native(barretenberg::fr const& lhs, barretenberg::fr const& rhs)
inline barretenberg::fr hash_pair_native(barretenberg::fr const& lhs, barretenberg::fr const& rhs)
{
if (plonk::SYSTEM_COMPOSER == plonk::PLOOKUP) {
return crypto::pedersen_hash::lookup::hash_multiple({ lhs, rhs });
return crypto::pedersen_hash::lookup::hash_multiple({ lhs, rhs }); // uses lookup tables
} else {
return crypto::pedersen_hash::hash_multiple({ lhs, rhs });
return crypto::pedersen_hash::hash_multiple({ lhs, rhs }); // uses fixed-base multiplication gate
}
}

inline barretenberg::fr hash_multiple_native(std::vector<barretenberg::fr> const& inputs)
{
if (plonk::SYSTEM_COMPOSER == plonk::PLOOKUP) {
return crypto::pedersen_hash::lookup::hash_multiple(inputs); // uses lookup tables
} else {
return crypto::pedersen_hash::hash_multiple(inputs); // uses fixed-base multiplication gate
}
}

Expand Down
5 changes: 1 addition & 4 deletions cpp/src/barretenberg/stdlib/merkle_tree/hash.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ TEST(stdlib_merkle_tree_hash, compress_native_vs_circuit)
Composer composer = Composer();
witness_ct y = witness_ct(&composer, x);
field_ct z = plonk::stdlib::pedersen_hash<Composer>::hash_multiple({ y, y });
auto zz = crypto::pedersen_hash::hash_multiple({ x, x }); // uses fixed-base multiplication gate
if constexpr (Composer::type == ComposerType::PLOOKUP) {
zz = crypto::pedersen_hash::lookup::hash_multiple({ x, x }); // uses lookup tables
}
auto zz = stdlib::merkle_tree::hash_pair_native(x, x);

EXPECT_EQ(z.get_value(), zz);
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/barretenberg/stdlib/merkle_tree/hash_path.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ inline fr_hash_path get_new_hash_path(fr_hash_path const& old_path, uint128_t in
} else {
path[i].first = current;
}
current = compress_native(path[i].first, path[i].second);
current = hash_pair_native(path[i].first, path[i].second);
index /= 2;
}
return path;
Expand All @@ -50,14 +50,14 @@ template <typename Ctx> inline hash_path<Ctx> create_witness_hash_path(Ctx& ctx,

inline fr get_hash_path_root(fr_hash_path const& input)
{
return compress_native(input[input.size() - 1].first, input[input.size() - 1].second);
return hash_pair_native(input[input.size() - 1].first, input[input.size() - 1].second);
}

inline fr zero_hash_at_height(size_t height)
{
auto current = fr(0);
for (size_t i = 0; i < height; ++i) {
current = compress_native(current, current);
current = hash_pair_native(current, current);
}
return current;
}
Expand Down
29 changes: 29 additions & 0 deletions cpp/src/barretenberg/stdlib/merkle_tree/membership.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,35 @@ void update_membership(field_t<Composer> const& new_root,
assert_check_membership(new_root, old_hashes, new_value, index, true, msg + "_new_value");
}

/**
* Asserts if the state transitions on updating multiple existing leaves with new values.
*
* @param old_root: The root of the merkle tree before it was updated,
* @param new_roots: New roots after each existing leaf was updated,
* @param new_values: The new values that are inserted in the existing leaves,
* @param old_values: The values of the existing leaves that were updated,
* @param old_paths: The hash path from the given index right before a given existing leaf is updated,
* @param old_indicies: Indices of the existing leaves that need to be updated,
* @tparam Composer: type of composer.
*/
template <typename Composer>
field_t<Composer> update_memberships(field_t<Composer> old_root,
std::vector<field_t<Composer>> const& new_roots,
std::vector<field_t<Composer>> const& new_values,
std::vector<field_t<Composer>> const& old_values,
std::vector<hash_path<Composer>> const& old_paths,
std::vector<bit_vector<Composer>> const& old_indicies)
{
for (size_t i = 0; i < old_indicies.size(); i++) {
update_membership(
new_roots[i], new_values[i], old_root, old_paths[i], old_values[i], old_indicies[i], "update_memberships");

old_root = new_roots[i];
}

return old_root;
}

/**
* Asserts if old and new state of the tree is correct after a subtree-update.
*
Expand Down
70 changes: 70 additions & 0 deletions cpp/src/barretenberg/stdlib/merkle_tree/membership.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ using namespace barretenberg;
using namespace plonk::stdlib::types;
using namespace plonk::stdlib::merkle_tree;

namespace {
auto& engine = numeric::random::get_debug_engine();
}

TEST(stdlib_merkle_tree, test_check_membership)
{
MemoryStore store;
Expand Down Expand Up @@ -203,3 +207,69 @@ TEST(stdlib_merkle_tree, test_tree)
bool result = verifier.verify_proof(proof);
EXPECT_EQ(result, true);
}

TEST(stdlib_merkle_tree, test_update_memberships)
{
constexpr size_t depth = 4;
MemoryStore store;
MerkleTree tree(store, depth);

Composer composer = Composer();

constexpr size_t filled = (1UL << depth) / 2;
std::vector<fr> filled_values;
for (size_t i = 0; i < filled; i++) {
uint256_t val = fr::random_element();
tree.update_element(i, val);
filled_values.push_back(val);
}

// old state
fr old_root = tree.root();
std::vector<size_t> old_indices = { 0, 2, 5, 7 };

std::vector<fr> old_values;
std::vector<fr_hash_path> old_hash_paths;
for (size_t i = 0; i < old_indices.size(); i++) {
old_values.push_back(filled_values[old_indices[i]]);
}

// new state
std::vector<fr> new_values;
std::vector<fr> new_roots;
for (size_t i = 0; i < old_indices.size(); i++) {
uint256_t val = fr::random_element();
new_values.push_back(val);
old_hash_paths.push_back(tree.get_hash_path(old_indices[i]));
new_roots.push_back(tree.update_element(old_indices[i], new_values[i]));
}

// old state circuit types
field_ct old_root_ct = witness_ct(&composer, old_root);
std::vector<bit_vector<Composer>> old_indices_ct;
std::vector<field_ct> old_values_ct;
std::vector<hash_path<Composer>> old_hash_paths_ct;

// new state circuit types
std::vector<field_ct> new_values_ct;
std::vector<field_ct> new_roots_ct;

for (size_t i = 0; i < old_indices.size(); i++) {
auto idx_vec = field_ct(witness_ct(&composer, uint256_t(old_indices[i]))).decompose_into_bits(depth);
old_indices_ct.push_back(idx_vec);
old_values_ct.push_back(witness_ct(&composer, old_values[i]));
old_hash_paths_ct.push_back(create_witness_hash_path(composer, old_hash_paths[i]));

new_values_ct.push_back(witness_ct(&composer, new_values[i]));
new_roots_ct.push_back(witness_ct(&composer, new_roots[i]));
}

update_memberships(old_root_ct, new_roots_ct, new_values_ct, old_values_ct, old_hash_paths_ct, old_indices_ct);

auto prover = composer.create_prover();
printf("composer gates = %zu\n", composer.get_num_gates());
auto verifier = composer.create_verifier();
plonk::proof proof = prover.construct_proof();
bool result = verifier.verify_proof(proof);
EXPECT_EQ(result, true);
}
4 changes: 2 additions & 2 deletions cpp/src/barretenberg/stdlib/merkle_tree/memory_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ MemoryTree::MemoryTree(size_t depth)
for (size_t i = 0; i < layer_size; ++i) {
hashes_[offset + i] = current;
}
current = compress_native(current, current);
current = hash_pair_native(current, current);
}

root_ = current;
Expand Down Expand Up @@ -48,7 +48,7 @@ fr MemoryTree::update_element(size_t index, fr const& value)
for (size_t i = 0; i < depth_; ++i) {
hashes_[offset + index] = current;
index &= (~0ULL) - 1;
current = compress_native(hashes_[offset + index], hashes_[offset + index + 1]);
current = hash_pair_native(hashes_[offset + index], hashes_[offset + index + 1]);
offset += layer_size;
layer_size >>= 1;
index >>= 1;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/barretenberg/stdlib/merkle_tree/memory_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MemoryTree {

fr root() const { return root_; }

private:
public:
size_t depth_;
size_t total_size_;
barretenberg::fr root_;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/barretenberg/stdlib/merkle_tree/memory_tree.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ TEST(stdlib_merkle_tree, test_memory_store)
fr e01 = VALUES[1];
fr e02 = VALUES[2];
fr e03 = VALUES[3];
fr e10 = compress_native(e00, e01);
fr e11 = compress_native(e02, e03);
fr root = compress_native(e10, e11);
fr e10 = hash_pair_native(e00, e01);
fr e11 = hash_pair_native(e02, e03);
fr root = hash_pair_native(e10, e11);

MemoryTree db(2);
for (size_t i = 0; i < 4; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static std::vector<fr> VALUES = []() {
void hash(State& state) noexcept
{
for (auto _ : state) {
compress_native({ 0, 0, 0, 0 }, { 1, 1, 1, 1 });
hash_pair_native({ 0, 0, 0, 0 }, { 1, 1, 1, 1 });
}
}
BENCHMARK(hash)->MinTime(5);
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/barretenberg/stdlib/merkle_tree/merkle_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ MerkleTree<Store>::MerkleTree(Store& store, size_t depth, uint8_t tree_id)
auto current = fr(0);
for (size_t i = 0; i < depth; ++i) {
zero_hashes_[i] = current;
current = compress_native(current, current);
current = hash_pair_native(current, current);
}
}

Expand All @@ -51,7 +51,7 @@ template <typename Store> fr MerkleTree<Store>::root() const
std::vector<uint8_t> root;
std::vector<uint8_t> key = { tree_id_ };
bool status = store_.get(key, root);
return status ? from_buffer<fr>(root) : compress_native(zero_hashes_.back(), zero_hashes_.back());
return status ? from_buffer<fr>(root) : hash_pair_native(zero_hashes_.back(), zero_hashes_.back());
}

template <typename Store> typename MerkleTree<Store>::index_t MerkleTree<Store>::size() const
Expand Down Expand Up @@ -103,7 +103,7 @@ template <typename Store> fr_hash_path MerkleTree<Store>::get_hash_path(index_t
} else {
path[j] = std::make_pair(current, zero_hashes_[j]);
}
current = compress_native(path[j].first, path[j].second);
current = hash_pair_native(path[j].first, path[j].second);
}
} else {
// Requesting path to a different, indepenent element.
Expand All @@ -123,7 +123,7 @@ template <typename Store> fr_hash_path MerkleTree<Store>::get_hash_path(index_t
} else {
path[j] = std::make_pair(current, zero_hashes_[j]);
}
current = compress_native(path[j].first, path[j].second);
current = hash_pair_native(path[j].first, path[j].second);
}
}
break;
Expand Down Expand Up @@ -158,7 +158,7 @@ template <typename Store> fr MerkleTree<Store>::binary_put(index_t a_index, fr c
bool a_is_right = bit_set(a_index, height - 1);
auto left = a_is_right ? b : a;
auto right = a_is_right ? a : b;
auto key = compress_native(left, right);
auto key = hash_pair_native(left, right);
put(key, left, right);
return key;
}
Expand Down Expand Up @@ -236,7 +236,7 @@ fr MerkleTree<Store>::update_element(fr const& root, fr const& value, index_t in
} else {
left = subtree_root;
}
auto new_root = compress_native(left, right);
auto new_root = hash_pair_native(left, right);
put(new_root, left, right);

// Remove the old node only while rolling back in recursion.
Expand All @@ -260,7 +260,7 @@ template <typename Store> fr MerkleTree<Store>::compute_zero_path_hash(size_t he
right = zero_hashes_[i];
left = current;
}
current = compress_native(is_right ? zero_hashes_[i] : current, is_right ? current : zero_hashes_[i]);
current = hash_pair_native(left, right);
}
return current;
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/barretenberg/stdlib/merkle_tree/merkle_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ template <typename Store> class MerkleTree {

index_t size() const;

private:
protected:
void load_metadata();

/**
Expand Down Expand Up @@ -88,7 +88,7 @@ template <typename Store> class MerkleTree {

void remove(fr const& key);

private:
protected:
Store& store_;
std::vector<fr> zero_hashes_;
size_t depth_;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#pragma once
#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp"

namespace plonk {
namespace stdlib {
namespace merkle_tree {

using namespace barretenberg;
typedef uint256_t index_t;

struct nullifier_leaf {
fr value;
index_t nextIndex;
fr nextValue;

bool operator==(nullifier_leaf const&) const = default;

std::ostream& operator<<(std::ostream& os)
{
os << "value = " << value << "\nnextIdx = " << nextIndex << "\nnextVal = " << nextValue;
return os;
}

void read(uint8_t const*& it)
{
using serialize::read;
read(it, value);
read(it, nextIndex);
read(it, nextValue);
}

inline void write(std::vector<uint8_t>& buf)
{
using serialize::write;
write(buf, value);
write(buf, nextIndex);
write(buf, nextValue);
}

barretenberg::fr hash() const { return stdlib::merkle_tree::hash_multiple_native({ value, nextIndex, nextValue }); }
};

inline std::pair<size_t, bool> find_closest_leaf(std::vector<nullifier_leaf> const& leaves_, fr const& new_value)
{
std::vector<uint256_t> diff;
bool repeated = false;
for (size_t i = 0; i < leaves_.size(); i++) {
auto leaf_value_ = uint256_t(leaves_[i].value);
auto new_value_ = uint256_t(new_value);
if (leaf_value_ > new_value_) {
diff.push_back(leaf_value_);
} else if (leaf_value_ == new_value_) {
repeated = true;
return std::make_pair(i, repeated);
} else {
diff.push_back(new_value_ - leaf_value_);
}
}
auto it = std::min_element(diff.begin(), diff.end());
return std::make_pair(static_cast<size_t>(it - diff.begin()), repeated);
}

} // namespace merkle_tree
} // namespace stdlib
} // namespace plonk
Loading

0 comments on commit 8590e15

Please sign in to comment.