Skip to content

Commit

Permalink
fix: use standard serialized vector of frs in function tree cbinds (#314
Browse files Browse the repository at this point in the history
)

* use a plain old serialized vector of frs instead of doing manual serialization logic for leaf frs

* fix function tree TS logic to match C++ fix

* wasm call helpers accept `Buffer \| { toBuffer: () => Buffer }` instead of only the object with toBuffer.
  • Loading branch information
dbanks12 authored Apr 24, 2023
1 parent f3ad475 commit 6d75664
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 151 deletions.
137 changes: 56 additions & 81 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <aztec3/utils/array.hpp>
#include <barretenberg/stdlib/merkle_tree/membership.hpp>
#include <barretenberg/crypto/keccak/keccak.hpp>
#include <barretenberg/common/serialize.hpp>

namespace {

Expand All @@ -37,71 +36,29 @@ using aztec3::circuits::abis::TxRequest;
using NT = aztec3::utils::types::NativeTypes;

// Cbind helper functions

/**
* @brief Compute an imperfect merkle tree's root from leaves.
* @brief Fill in zero-leaves to get a full tree's bottom layer.
*
* @details given a `uint8_t const*` buffer representing a merkle tree's leaves,
* compute the corresponding tree's root and return the serialized results
* in the `output` buffer. "Partial left tree" here means that the tree's leaves
* are filled strictly from left to right, but there may be empty leaves on the right
* end of the tree.
* @details Given a vector of nonzero leaves starting at the left,
* append zeroleaves to that list until it represents a FULL set of leaves
* for a tree of the given height.
* **MODIFIES THE INPUT `leaves` REFERENCE!**
*
* @tparam TREE_HEIGHT height of the tree used to determine max leaves and used when computing root
* @param leaves_buf a buffer of bytes representing the leaves of the tree, where each leaf is
* assumed to be a field and is interpreted using `NT::fr::serialize_from_buffer(leaf_ptr)`
* @param num_leaves the number of leaves in leaves_buf
* @tparam TREE_HEIGHT height of the tree used to determine max leaves
* @param leaves the nonzero leaves of the tree starting at the left
* @param zero_leaf the leaf value to be used for any empty/unset leaves
* @returns a field (`NT::fr`) containing the computed merkle tree root
*/
template <size_t TREE_HEIGHT>
NT::fr compute_root_of_partial_left_tree(uint8_t const* leaves_buf, uint8_t num_leaves, NT::fr zero_leaf)
template <size_t TREE_HEIGHT> void rightfill_with_zeroleaves(std::vector<NT::fr>& leaves, NT::fr& zero_leaf)
{
const size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// cant exceed max leaves
ASSERT(num_leaves <= max_leaves);

// initialize the vector of leaves to a complete-tree-sized vector of zero-leaves
std::vector<NT::fr> leaves(max_leaves, zero_leaf);

// Iterate over the input buffer, extracting each leaf and serializing it from buffer to field
// Insert each leaf field into the vector
// If num_leaves < perfect tree, remaining leaves will be `zero_leaf`
for (size_t l = 0; l < num_leaves; l++) {
// each iteration skips over some number of `fr`s to get to the next leaf
uint8_t const* cur_leaf_ptr = leaves_buf + sizeof(NT::fr) * l;
NT::fr leaf = NT::fr::serialize_from_buffer(cur_leaf_ptr);
leaves[l] = leaf;
}

// compute the root of this complete tree, return
return plonk::stdlib::merkle_tree::compute_tree_root_native(leaves);
}

// TODO comment
// TODO code reuse possible with root func above
template <size_t TREE_HEIGHT>
std::vector<NT::fr> // array length is num nodes
compute_partial_left_tree(uint8_t const* leaves_buf, uint8_t num_leaves, NT::fr zero_leaf)
{
const size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// cant exceed max leaves
ASSERT(num_leaves <= max_leaves);

// initialize the vector of leaves to a complete-tree-sized vector of zero-leaves
std::vector<NT::fr> leaves(max_leaves, zero_leaf);

// Iterate over the input buffer, extracting each leaf and serializing it from buffer to field
// Insert each leaf field into the vector
// If num_leaves < perfect tree, remaining leaves will be `zero_leaf`
for (size_t l = 0; l < num_leaves; l++) {
// each iteration skips over some number of `fr`s to get to the next leaf
uint8_t const* cur_leaf_ptr = leaves_buf + sizeof(NT::fr) * l;
NT::fr leaf = NT::fr::serialize_from_buffer(cur_leaf_ptr);
leaves[l] = leaf;
}

// compute the root of this complete tree, return
return plonk::stdlib::merkle_tree::compute_tree_native(leaves);
constexpr size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// input cant exceed max leaves
// FIXME don't think asserts will show in wasm
ASSERT(leaves.size() <= max_leaves);

// fill in input vector with zero-leaves
// to get a full bottom layer of the tree
leaves.insert(leaves.end(), max_leaves - leaves.size(), zero_leaf);
}

} // namespace
Expand Down Expand Up @@ -238,39 +195,59 @@ WASM_EXPORT void abis__compute_function_leaf(uint8_t const* function_leaf_preima
}

/**
* @brief Compute a function tree root from its leaves.
* @brief Compute a function tree root from its nonzero leaves.
* This is a WASM-export that can be called from Typescript.
*
* @details given a `uint8_t const*` buffer representing a function tree's leaves,
* compute the corresponding tree's root and return the serialized results
* in the `output` buffer.
* @details given a serialized vector of nonzero function leaves,
* compute the corresponding tree's root and return the
* serialized results via `root_out` buffer.
*
* @param function_leaves_buf a buffer of bytes representing the leaves of the function tree,
* where each leaf is assumed to be a serialized field
* @param num_leaves the number of leaves in leaves_buf
* @param output buffer that will contain the output. The serialized function tree root.
* @param function_leaves_in input buffer representing a serialized vector of
* nonzero function leaves where each leaf is an `fr` starting at the left of the tree
* @param root_out buffer that will contain the serialized function tree root `fr`.
*/
WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output)
WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_in, uint8_t* root_out)
{
std::vector<NT::fr> leaves;
// fill in nonzero leaves to start
read(function_leaves_in, leaves);
// fill in zero leaves to complete tree
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
NT::fr root =
compute_root_of_partial_left_tree<aztec3::FUNCTION_TREE_HEIGHT>(function_leaves_buf, num_leaves, zero_leaf);
rightfill_with_zeroleaves<aztec3::FUNCTION_TREE_HEIGHT>(leaves, zero_leaf);

// compute the root of this complete tree, return
NT::fr root = plonk::stdlib::merkle_tree::compute_tree_root_native(leaves);

// serialize and return root
NT::fr::serialize_to_buffer(root, output);
NT::fr::serialize_to_buffer(root, root_out);
}

// TODO comment
WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_buf, uint8_t num_leaves, uint8_t* output)
/**
* @brief Compute all of a function tree's nodes from its nonzero leaves.
* This is a WASM-export that can be called from Typescript.
*
* @details given a serialized vector of nonzero function leaves,
* compute ALL of the corresponding tree's nodes (including root) and return
* the serialized results via `tree_nodes_out` buffer.
*
* @param function_leaves_in input buffer representing a serialized vector of
* nonzero function leaves where each leaf is an `fr` starting at the left of the tree.
* @param tree_nodes_out buffer that will contain the serialized function tree.
* The 0th node is the bottom leftmost leaf. The last entry is the root.
*/
WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_in, uint8_t* tree_nodes_out)
{
std::vector<NT::fr> leaves;
// fill in nonzero leaves to start
read(function_leaves_in, leaves);
// fill in zero leaves to complete tree
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
std::vector<NT::fr> tree =
compute_partial_left_tree<aztec3::FUNCTION_TREE_HEIGHT>(function_leaves_buf, num_leaves, zero_leaf);
rightfill_with_zeroleaves<aztec3::FUNCTION_TREE_HEIGHT>(leaves, zero_leaf);

std::vector<NT::fr> tree = plonk::stdlib::merkle_tree::compute_tree_native(leaves);

// serialize and return tree
write(output, tree);
write(tree_nodes_out, tree);
}

/**
Expand All @@ -295,7 +272,6 @@ WASM_EXPORT void abis__hash_constructor(uint8_t const* function_data_buf,
std::array<NT::fr, aztec3::ARGS_LENGTH> args;
NT::fr constructor_vk_hash;

using serialize::read;
read(function_data_buf, function_data);
read(args_buf, args);
read(constructor_vk_hash_buf, constructor_vk_hash);
Expand Down Expand Up @@ -331,7 +307,6 @@ WASM_EXPORT void abis__compute_contract_address(uint8_t const* deployer_address_
NT::fr function_tree_root;
NT::fr constructor_hash;

using serialize::read;
read(deployer_address_buf, deployer_address);
read(contract_address_salt_buf, contract_address_salt);
read(function_tree_root_buf, function_tree_root);
Expand Down
2 changes: 0 additions & 2 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ WASM_EXPORT void abis__compute_function_selector(char const* func_sig_cstr, uint
WASM_EXPORT void abis__compute_function_leaf(uint8_t const* function_leaf_preimage_buf, uint8_t* output);

WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output);

WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output);

WASM_EXPORT void abis__hash_vk(uint8_t const* vk_data_buf, uint8_t* output);
Expand Down
92 changes: 45 additions & 47 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ namespace {

using NT = aztec3::utils::types::NativeTypes;
using aztec3::circuits::abis::NewContractData;
// num_leaves = 2**h = 2<<(h-1)
// root layer does not count in height
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1);
// num_nodes = (2**(h+1))-1 = (2<<h)-1
// root layer does not count in height
// num nodes includes root
constexpr size_t FUNCTION_TREE_NUM_NODES = (2 << aztec3::FUNCTION_TREE_HEIGHT) - 1;

auto& engine = numeric::random::get_debug_engine();

Expand Down Expand Up @@ -159,78 +166,69 @@ TEST(abi_tests, compute_function_leaf)

TEST(abi_tests, compute_function_tree_root)
{
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1); // leaves = 2 ^ height

NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
// these frs will be used to compute the root directly (without cbind)
// all empty slots will have the zero-leaf to ensure full tree
std::vector<NT::fr> leaves_frs(FUNCTION_TREE_NUM_LEAVES, zero_leaf);

// randomize number of non-zero leaves such that `0 <= num_nonzero_leaves <= FUNCTION_TREE_NUM_LEAVES`
uint8_t num_nonzero_leaves = engine.get_random_uint8() % (FUNCTION_TREE_NUM_LEAVES + 1);
// create a vector whose vec.data() can be cast to a single mega-buffer containing all non-zero leaves
// initialize the vector with its size so that a leaf's data can be copied in (via `serialize_to_buffer`)
// (uint256_t here means nothing; it is just used because it is the right size (32 uint8_ts))
std::vector<uint256_t> leaves(num_nonzero_leaves);

// generate some random leaves
// insert them into the vector of leaf fields (for direct tree root computation)
// insert their serialized form into the vector of 32-bytes chunks/uint256_ts
// (to be cast to a single mega uint8_t* buffer and passed to cbind)
std::vector<NT::fr> leaves_frs;
for (size_t l = 0; l < num_nonzero_leaves; l++) {
NT::fr leaf = NT::fr::random_element();
leaves_frs[l] = leaf;
NT::fr::serialize_to_buffer(leaf, reinterpret_cast<uint8_t*>(&leaves[l]));
leaves_frs.push_back(NT::fr::random_element());
}
// serialize the leaves to a buffer to pass to cbind
std::vector<uint8_t> leaves_bytes_vec;
write(leaves_bytes_vec, leaves_frs);

// call cbind and get output (root)
std::array<uint8_t, sizeof(NT::fr)> output = { 0 };
abis__compute_function_tree_root(reinterpret_cast<uint8_t*>(leaves.data()), num_nonzero_leaves, output.data());
abis__compute_function_tree_root(leaves_bytes_vec.data(), output.data());
NT::fr got_root = NT::fr::serialize_from_buffer(output.data());

// compare cbind results with direct computation
NT::fr got_root = NT::fr::serialize_from_buffer(output.data());

// add the zero leaves to the vector of fields and pass to barretenberg helper
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
for (size_t l = num_nonzero_leaves; l < FUNCTION_TREE_NUM_LEAVES; l++) {
leaves_frs.push_back(zero_leaf);
}
// compare results
EXPECT_EQ(got_root, plonk::stdlib::merkle_tree::compute_tree_root_native(leaves_frs));
}

TEST(abi_tests, compute_function_tree)
{
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1); // leaves = 2 ^ height

NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
// these frs will be used to compute the tree directly (without cbind)
// all empty slots will have the zero-leaf to ensure full tree
std::vector<NT::fr> leaves_frs(FUNCTION_TREE_NUM_LEAVES, zero_leaf);

// randomize number of non-zero leaves such that `0 <= num_nonzero_leaves <= FUNCTION_TREE_NUM_LEAVES`
uint8_t num_nonzero_leaves = engine.get_random_uint8() % (FUNCTION_TREE_NUM_LEAVES + 1);
// create a vector whose vec.data() can be cast to a single mega-buffer containing all non-zero leaves
// initialize the vector with its size so that a leaf's data can be copied in (via `serialize_to_buffer`)
// (uint256_t here means nothing; it is just used because it is the right size (32 uint8_ts))
std::vector<uint256_t> leaves(num_nonzero_leaves);

// generate some random leaves
// insert them into the vector of leaf fields (for direct tree computation)
// insert their serialized form into the vector of 32-bytes chunks/uint256_ts
// (to be cast to a single mega uint8_t* buffer and passed to cbind)
std::vector<NT::fr> leaves_frs;
for (size_t l = 0; l < num_nonzero_leaves; l++) {
NT::fr leaf = NT::fr::random_element();
leaves_frs[l] = leaf;
NT::fr::serialize_to_buffer(leaf, reinterpret_cast<uint8_t*>(&leaves[l]));
leaves_frs.push_back(NT::fr::random_element());
}
// serialize the leaves to a buffer to pass to cbind
std::vector<uint8_t> leaves_bytes_vec;
write(leaves_bytes_vec, leaves_frs);

// setup output buffer
// it must fit a uint32_t (for the vector length)
// plus all of the nodes `frs` in the tree
constexpr auto size_output_buf = sizeof(uint32_t) + (sizeof(NT::fr) * FUNCTION_TREE_NUM_NODES);
std::array<uint8_t, size_output_buf> output = { 0 };

// call cbind and get output (full tree)
abis__compute_function_tree(leaves_bytes_vec.data(), output.data());
// deserialize output to vector of frs representing all nodes in tree
std::vector<NT::fr> got_tree;
uint8_t const* output_ptr = output.data();
read(output_ptr, got_tree);

// (2**(h+1)) - 1
constexpr size_t num_nodes = (2 << aztec3::FUNCTION_TREE_HEIGHT) - 1;

// call cbind and get output (root)
uint8_t* output = (uint8_t*)malloc(sizeof(NT::fr) * num_nodes);
abis__compute_function_tree(reinterpret_cast<uint8_t*>(leaves.data()), num_nonzero_leaves, output);

using serialize::read;
// compare cbind results with direct computation
std::vector<NT::fr> got_tree;
uint8_t const* output_copy = output;
read(output_copy, got_tree);

// add the zero leaves to the vector of fields and pass to barretenberg helper
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
for (size_t l = num_nonzero_leaves; l < FUNCTION_TREE_NUM_LEAVES; l++) {
leaves_frs.push_back(zero_leaf);
}
// compare results
EXPECT_EQ(got_tree, plonk::stdlib::merkle_tree::compute_tree_native(leaves_frs));
}

Expand Down
Loading

0 comments on commit 6d75664

Please sign in to comment.