Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use standard serialized vector of frs in function tree cbinds #314

Merged
merged 11 commits into from
Apr 24, 2023
137 changes: 56 additions & 81 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <aztec3/utils/array.hpp>
#include <barretenberg/stdlib/merkle_tree/membership.hpp>
#include <barretenberg/crypto/keccak/keccak.hpp>
#include <barretenberg/common/serialize.hpp>

namespace {

Expand All @@ -37,71 +36,29 @@ using aztec3::circuits::abis::TxRequest;
using NT = aztec3::utils::types::NativeTypes;

// Cbind helper functions

/**
* @brief Compute an imperfect merkle tree's root from leaves.
* @brief Fill in zero-leaves to get a full tree's bottom layer.
*
* @details given a `uint8_t const*` buffer representing a merkle tree's leaves,
* compute the corresponding tree's root and return the serialized results
* in the `output` buffer. "Partial left tree" here means that the tree's leaves
* are filled strictly from left to right, but there may be empty leaves on the right
* end of the tree.
* @details Given the a vector of nonzero leaves starting at the left,
* append zeroleaves to that list until it represents a FULL set of leaves
* for a tree of the given height.
* **MODIFIES THE INPUT `leaves` REFERENCE!**
*
* @tparam TREE_HEIGHT height of the tree used to determine max leaves and used when computing root
* @param leaves_buf a buffer of bytes representing the leaves of the tree, where each leaf is
* assumed to be a field and is interpreted using `NT::fr::serialize_from_buffer(leaf_ptr)`
* @param num_leaves the number of leaves in leaves_buf
* @tparam TREE_HEIGHT height of the tree used to determine max leaves
* @param leaves the nonzero leaves of the tree starting at the left
* @param zero_leaf the leaf value to be used for any empty/unset leaves
* @returns a field (`NT::fr`) containing the computed merkle tree root
*/
template <size_t TREE_HEIGHT>
NT::fr compute_root_of_partial_left_tree(uint8_t const* leaves_buf, uint8_t num_leaves, NT::fr zero_leaf)
template <size_t TREE_HEIGHT> void rightfill_with_zeroleaves(std::vector<NT::fr>& leaves, NT::fr& zero_leaf)
{
const size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// cant exceed max leaves
ASSERT(num_leaves <= max_leaves);

// initialize the vector of leaves to a complete-tree-sized vector of zero-leaves
std::vector<NT::fr> leaves(max_leaves, zero_leaf);

// Iterate over the input buffer, extracting each leaf and serializing it from buffer to field
// Insert each leaf field into the vector
// If num_leaves < perfect tree, remaining leaves will be `zero_leaf`
for (size_t l = 0; l < num_leaves; l++) {
// each iteration skips to over some number of `fr`s to get to the // next leaf
uint8_t const* cur_leaf_ptr = leaves_buf + sizeof(NT::fr) * l;
NT::fr leaf = NT::fr::serialize_from_buffer(cur_leaf_ptr);
leaves[l] = leaf;
}

// compute the root of this complete tree, return
return plonk::stdlib::merkle_tree::compute_tree_root_native(leaves);
}

// TODO comment
// TODO code reuse possible with root func above
template <size_t TREE_HEIGHT>
std::vector<NT::fr> // array length is num nodes
compute_partial_left_tree(uint8_t const* leaves_buf, uint8_t num_leaves, NT::fr zero_leaf)
{
const size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// cant exceed max leaves
ASSERT(num_leaves <= max_leaves);

// initialize the vector of leaves to a complete-tree-sized vector of zero-leaves
std::vector<NT::fr> leaves(max_leaves, zero_leaf);

// Iterate over the input buffer, extracting each leaf and serializing it from buffer to field
// Insert each leaf field into the vector
// If num_leaves < perfect tree, remaining leaves will be `zero_leaf`
for (size_t l = 0; l < num_leaves; l++) {
// each iteration skips to over some number of `fr`s to get to the // next leaf
uint8_t const* cur_leaf_ptr = leaves_buf + sizeof(NT::fr) * l;
NT::fr leaf = NT::fr::serialize_from_buffer(cur_leaf_ptr);
leaves[l] = leaf;
}

// compute the root of this complete tree, return
return plonk::stdlib::merkle_tree::compute_tree_native(leaves);
constexpr size_t max_leaves = 2 << (TREE_HEIGHT - 1);
// input cant exceed max leaves
// FIXME don't think asserts will show in wasm
ASSERT(leaves.size() <= max_leaves);

// fill in input vector with zero-leaves
// to get a full bottom layer of the tree
leaves.insert(leaves.end(), max_leaves - leaves.size(), zero_leaf);
}

} // namespace
Expand Down Expand Up @@ -238,39 +195,59 @@ WASM_EXPORT void abis__compute_function_leaf(uint8_t const* function_leaf_preima
}

/**
* @brief Compute a function tree root from its leaves.
* @brief Compute a function tree root from its nonzero leaves.
* This is a WASM-export that can be called from Typescript.
*
* @details given a `uint8_t const*` buffer representing a function tree's leaves,
* compute the corresponding tree's root and return the serialized results
* in the `output` buffer.
* @details given a serialized vector of nonzero function leaves,
* compute the corresponding tree's root and return the
* serialized results via `root_out` buffer.
*
* @param function_leaves_buf a buffer of bytes representing the leaves of the function tree,
* where each leaf is assumed to be a serialized field
* @param num_leaves the number of leaves in leaves_buf
* @param output buffer that will contain the output. The serialized function tree root.
* @param function_leaves_in input buffer representing a serialized vector of
* nonzero function leaves where each leaf is an `fr` starting at the left of the tree
* @param root_out buffer that will contain the serialized function tree root `fr`.
*/
WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output)
WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_in, uint8_t* root_out)
{
std::vector<NT::fr> leaves;
// fill in nonzero leaves to start
read(function_leaves_in, leaves);
// fill in zero leaves to complete tree
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
Copy link
Member

@Maddiaa0 Maddiaa0 Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we eventually want to have this also be 0, as we did for the NullifierTree?

Copy link
Collaborator Author

@dbanks12 dbanks12 Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, what are advantages of each way?

NT::fr root =
compute_root_of_partial_left_tree<aztec3::FUNCTION_TREE_HEIGHT>(function_leaves_buf, num_leaves, zero_leaf);
rightfill_with_zeroleaves<aztec3::FUNCTION_TREE_HEIGHT>(leaves, zero_leaf);

// compute the root of this complete tree, return
NT::fr root = plonk::stdlib::merkle_tree::compute_tree_root_native(leaves);

// serialize and return root
NT::fr::serialize_to_buffer(root, output);
NT::fr::serialize_to_buffer(root, root_out);
}

// TODO comment
WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_buf, uint8_t num_leaves, uint8_t* output)
/**
* @brief Compute all of a function tree's nodes from its nonzero leaves.
* This is a WASM-export that can be called from Typescript.
*
* @details given a serialized vector of nonzero function leaves,
* compute ALL of the corresponding tree's nodes (including root) and return
* the serialized results via `tree_nodes_out` buffer.
*
* @param function_leaves_in input buffer representing a serialized vector of
* nonzero function leaves where each leaf is an `fr` starting at the left of the tree.
* @param tree_nodes_out buffer that will contain the serialized function tree.
* The 0th node is the bottom leftmost leaf. The last entry is the root.
*/
WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_in, uint8_t* tree_nodes_out)
{
std::vector<NT::fr> leaves;
// fill in nonzero leaves to start
read(function_leaves_in, leaves);
// fill in zero leaves to complete tree
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
std::vector<NT::fr> tree =
compute_partial_left_tree<aztec3::FUNCTION_TREE_HEIGHT>(function_leaves_buf, num_leaves, zero_leaf);
rightfill_with_zeroleaves<aztec3::FUNCTION_TREE_HEIGHT>(leaves, zero_leaf);

std::vector<NT::fr> tree = plonk::stdlib::merkle_tree::compute_tree_native(leaves);

// serialize and return tree
write(output, tree);
write(tree_nodes_out, tree);
}

/**
Expand All @@ -295,7 +272,6 @@ WASM_EXPORT void abis__hash_constructor(uint8_t const* function_data_buf,
std::array<NT::fr, aztec3::ARGS_LENGTH> args;
NT::fr constructor_vk_hash;

using serialize::read;
read(function_data_buf, function_data);
read(args_buf, args);
read(constructor_vk_hash_buf, constructor_vk_hash);
Expand Down Expand Up @@ -331,7 +307,6 @@ WASM_EXPORT void abis__compute_contract_address(uint8_t const* deployer_address_
NT::fr function_tree_root;
NT::fr constructor_hash;

using serialize::read;
read(deployer_address_buf, deployer_address);
read(contract_address_salt_buf, contract_address_salt);
read(function_tree_root_buf, function_tree_root);
Expand Down
2 changes: 0 additions & 2 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ WASM_EXPORT void abis__compute_function_selector(char const* func_sig_cstr, uint
WASM_EXPORT void abis__compute_function_leaf(uint8_t const* function_leaf_preimage_buf, uint8_t* output);

WASM_EXPORT void abis__compute_function_tree_root(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output);

WASM_EXPORT void abis__compute_function_tree(uint8_t const* function_leaves_buf,
uint8_t num_leaves,
uint8_t* output);

WASM_EXPORT void abis__hash_vk(uint8_t const* vk_data_buf, uint8_t* output);
Expand Down
92 changes: 45 additions & 47 deletions circuits/cpp/src/aztec3/circuits/abis/c_bind.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ namespace {

using NT = aztec3::utils::types::NativeTypes;
using aztec3::circuits::abis::NewContractData;
// num_leaves = 2**h = 2<<(h-1)
// root layer does not count in height
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1);
// num_nodes = (2**(h+1))-1 = (2<<h)
// root layer does not count in height
// num nodes includes root
constexpr size_t FUNCTION_TREE_NUM_NODES = (2 << aztec3::FUNCTION_TREE_HEIGHT) - 1;

auto& engine = numeric::random::get_debug_engine();

Expand Down Expand Up @@ -159,78 +166,69 @@ TEST(abi_tests, compute_function_leaf)

TEST(abi_tests, compute_function_tree_root)
{
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1); // leaves = 2 ^ height

NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
// these frs will be used to compute the root directly (without cbind)
// all empty slots will have the zero-leaf to ensure full tree
std::vector<NT::fr> leaves_frs(FUNCTION_TREE_NUM_LEAVES, zero_leaf);

// randomize number of non-zero leaves such that `0 < num_nonzero_leaves <= FUNCTION_TREE_NUM_LEAVES`
uint8_t num_nonzero_leaves = engine.get_random_uint8() % (FUNCTION_TREE_NUM_LEAVES + 1);
// create a vector whose vec.data() can be cast to a single mega-buffer containing all non-zero leaves
// initialize the vector with its size so that a leaf's data can be copied in (via `seralize_to_buffer`)
// (uint256_t here means nothing; it is just used because it is the right size (32 uint8_ts))
std::vector<uint256_t> leaves(num_nonzero_leaves);

// generate some random leaves
// insert them into the vector of leaf fields (for direct tree root computation)
// insert their serialized form into the vector of 32-bytes chunks/uint256_ts
// (to be cast to a single mega uint8_t* buffer and passed to cbind)
std::vector<NT::fr> leaves_frs;
for (size_t l = 0; l < num_nonzero_leaves; l++) {
NT::fr leaf = NT::fr::random_element();
leaves_frs[l] = leaf;
NT::fr::serialize_to_buffer(leaf, reinterpret_cast<uint8_t*>(&leaves[l]));
leaves_frs.push_back(NT::fr::random_element());
}
// serilalize the leaves to a buffer to pass to cbind
std::vector<uint8_t> leaves_bytes_vec;
write(leaves_bytes_vec, leaves_frs);

// call cbind and get output (root)
std::array<uint8_t, sizeof(NT::fr)> output = { 0 };
abis__compute_function_tree_root(reinterpret_cast<uint8_t*>(leaves.data()), num_nonzero_leaves, output.data());
abis__compute_function_tree_root(leaves_bytes_vec.data(), output.data());
NT::fr got_root = NT::fr::serialize_from_buffer(output.data());

// compare cbind results with direct computation
NT::fr got_root = NT::fr::serialize_from_buffer(output.data());

// add the zero leaves to the vector of fields and pass to barretenberg helper
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
for (size_t l = num_nonzero_leaves; l < FUNCTION_TREE_NUM_LEAVES; l++) {
leaves_frs.push_back(zero_leaf);
}
// compare results
EXPECT_EQ(got_root, plonk::stdlib::merkle_tree::compute_tree_root_native(leaves_frs));
}

TEST(abi_tests, compute_function_tree)
{
constexpr size_t FUNCTION_TREE_NUM_LEAVES = 2 << (aztec3::FUNCTION_TREE_HEIGHT - 1); // leaves = 2 ^ height

NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
// these frs will be used to compute the tree directly (without cbind)
// all empty slots will have the zero-leaf to ensure full tree
std::vector<NT::fr> leaves_frs(FUNCTION_TREE_NUM_LEAVES, zero_leaf);

// randomize number of non-zero leaves such that `0 < num_nonzero_leaves <= FUNCTION_TREE_NUM_LEAVES`
uint8_t num_nonzero_leaves = engine.get_random_uint8() % (FUNCTION_TREE_NUM_LEAVES + 1);
// create a vector whose vec.data() can be cast to a single mega-buffer containing all non-zero leaves
// initialize the vector with its size so that a leaf's data can be copied in (via `seralize_to_buffer`)
// (uint256_t here means nothing; it is just used because it is the right size (32 uint8_ts))
std::vector<uint256_t> leaves(num_nonzero_leaves);

// generate some random leaves
// insert them into the vector of leaf fields (for direct tree computation)
// insert their serialized form into the vector of 32-bytes chunks/uint256_ts
// (to be cast to a single mega uint8_t* buffer and passed to cbind)
std::vector<NT::fr> leaves_frs;
for (size_t l = 0; l < num_nonzero_leaves; l++) {
NT::fr leaf = NT::fr::random_element();
leaves_frs[l] = leaf;
NT::fr::serialize_to_buffer(leaf, reinterpret_cast<uint8_t*>(&leaves[l]));
leaves_frs.push_back(NT::fr::random_element());
}
// serilalize the leaves to a buffer to pass to cbind
std::vector<uint8_t> leaves_bytes_vec;
write(leaves_bytes_vec, leaves_frs);

// setup output buffer
// it must fit a uint32_t (for the vector length)
// plus all of the nodes `frs` in the tree
constexpr auto size_output_buf = sizeof(uint32_t) + (sizeof(NT::fr) * FUNCTION_TREE_NUM_NODES);
std::array<uint8_t, size_output_buf> output = { 0 };

// call cbind and get output (full tree root)
abis__compute_function_tree(leaves_bytes_vec.data(), output.data());
// deserialize output to vector of frs representing all nodes in tree
std::vector<NT::fr> got_tree;
uint8_t const* output_ptr = output.data();
read(output_ptr, got_tree);

// (2**h) - 1
constexpr size_t num_nodes = (2 << aztec3::FUNCTION_TREE_HEIGHT) - 1;

// call cbind and get output (root)
uint8_t* output = (uint8_t*)malloc(sizeof(NT::fr) * num_nodes);
abis__compute_function_tree(reinterpret_cast<uint8_t*>(leaves.data()), num_nonzero_leaves, output);

using serialize::read;
// compare cbind results with direct computation
std::vector<NT::fr> got_tree;
uint8_t const* output_copy = output;
read(output_copy, got_tree);

// add the zero leaves to the vector of fields and pass to barretenberg helper
NT::fr zero_leaf = FunctionLeafPreimage<NT>().hash(); // hash of empty/0 preimage
for (size_t l = num_nonzero_leaves; l < FUNCTION_TREE_NUM_LEAVES; l++) {
leaves_frs.push_back(zero_leaf);
}
// compare results
EXPECT_EQ(got_tree, plonk::stdlib::merkle_tree::compute_tree_native(leaves_frs));
}

Expand Down
11 changes: 4 additions & 7 deletions yarn-project/circuits.js/src/abis/abis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
NewContractData,
FunctionLeafPreimage,
} from '../index.js';
import { serializeToBuffer } from '../utils/serialize.js';
import { serializeToBuffer, serializeBufferArrayToVector } from '../utils/serialize.js';
import { AsyncWasmWrapper, WasmWrapper } from '@aztec/foundation/wasm';

export function wasmSyncCall(
Expand Down Expand Up @@ -94,13 +94,10 @@ export async function computeFunctionLeaf(wasm: CircuitsWasm, fnLeaf: FunctionLe
}

export async function computeFunctionTreeRoot(wasm: CircuitsWasm, fnLeafs: Fr[]) {
const inputBuf = serializeToBuffer(fnLeafs);
const inputVector = serializeBufferArrayToVector(fnLeafs.map(fr => fr.toBuffer()));
wasm.call('pedersen__init');
const outputBuf = wasm.call('bbmalloc', 32);
const inputBufPtr = wasm.call('bbmalloc', inputBuf.length);
wasm.writeMemory(inputBufPtr, inputBuf);
await wasm.asyncCall('abis__compute_function_tree_root', inputBufPtr, fnLeafs.length, outputBuf);
return Fr.fromBuffer(Buffer.from(wasm.getMemorySlice(outputBuf, outputBuf + 32)));
const result = await wasmAsyncCall(wasm, 'abis__compute_function_tree_root', { toBuffer: () => inputVector }, 32);
dbanks12 marked this conversation as resolved.
Show resolved Hide resolved
return Fr.fromBuffer(result);
}

export async function hashConstructor(
Expand Down
10 changes: 5 additions & 5 deletions yarn-project/circuits.js/src/kernel/kernel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
KernelCircuitPublicInputs,
SignedTxRequest,
} from '../index.js';
import { boolToBuffer, serializeToBuffer, uint8ArrayToNum } from '../utils/serialize.js';
import { boolToBuffer, serializeBufferArrayToVector, uint8ArrayToNum } from '../utils/serialize.js';
import { CircuitsWasm } from '../wasm/index.js';

export async function getDummyPreviousKernelData(wasm: CircuitsWasm) {
Expand All @@ -30,13 +30,13 @@ export async function computeFunctionTree(wasm: CircuitsWasm, leaves: Fr[]): Pro
const outputBufSize = 2 ** (FUNCTION_TREE_HEIGHT + 1) * Fr.SIZE_IN_BYTES + 4;

// Allocate memory for the input and output buffers, and populate input buffer
const inputBuf = serializeToBuffer(leaves);
const inputBufPtr = wasm.call('bbmalloc', inputBuf.length);
const inputVector = serializeBufferArrayToVector(leaves.map(fr => fr.toBuffer()));
const inputBufPtr = wasm.call('bbmalloc', inputVector.length);
const outputBufPtr = wasm.call('bbmalloc', outputBufSize * 100);
wasm.writeMemory(inputBufPtr, inputBuf);
wasm.writeMemory(inputBufPtr, inputVector);

// Run and read outputs
await wasm.asyncCall('abis__compute_function_tree', inputBufPtr, leaves.length, outputBufPtr);
await wasm.asyncCall('abis__compute_function_tree', inputBufPtr, outputBufPtr);
const outputBuf = Buffer.from(wasm.getMemorySlice(outputBufPtr, outputBufPtr + outputBufSize));
const reader = new BufferReader(outputBuf);
const output = reader.readVector(Fr);
Expand Down