From 3ad6dd91c5bd30ca3b8b522855e9c04106e0ec9f Mon Sep 17 00:00:00 2001 From: Facundo Date: Thu, 15 Aug 2024 16:49:47 +0100 Subject: [PATCH] refactor(avm): separate binary and bytes finalization (#8010) --- .../vm/avm/trace/binary_trace.cpp | 48 +++++- .../vm/avm/trace/binary_trace.hpp | 11 +- .../barretenberg/vm/avm/trace/fixed_bytes.cpp | 92 ++++++++++++ .../barretenberg/vm/avm/trace/fixed_bytes.hpp | 25 ++++ .../src/barretenberg/vm/avm/trace/trace.cpp | 139 +++--------------- 5 files changed, 192 insertions(+), 123 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp create mode 100644 barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp index dd0c0452203..8031cf1c9c9 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp @@ -10,11 +10,6 @@ namespace bb::avm_trace { -std::vector AvmBinaryTraceBuilder::finalize() -{ - return std::move(binary_trace); -} - void AvmBinaryTraceBuilder::reset() { binary_trace.clear(); @@ -166,4 +161,47 @@ FF AvmBinaryTraceBuilder::op_xor(FF const& a, FF const& b, AvmMemoryTag instr_ta return uint256_t::from_uint128(c_uint128); } +void AvmBinaryTraceBuilder::finalize(std::vector>& main_trace) +{ + for (size_t i = 0; i < size(); i++) { + auto const& src = binary_trace.at(i); + auto& dest = main_trace.at(i); + dest.binary_clk = src.binary_clk; + dest.binary_sel_bin = static_cast(src.bin_sel); + dest.binary_acc_ia = src.acc_ia; + dest.binary_acc_ib = src.acc_ib; + dest.binary_acc_ic = src.acc_ic; + dest.binary_in_tag = src.in_tag; + dest.binary_op_id = src.op_id; + dest.binary_ia_bytes = src.bin_ia_bytes; + dest.binary_ib_bytes = src.bin_ib_bytes; + dest.binary_ic_bytes = src.bin_ic_bytes; + dest.binary_start = FF(static_cast(src.start)); + dest.binary_mem_tag_ctr = src.mem_tag_ctr; + dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv; + } + + reset(); +} + +void AvmBinaryTraceBuilder::finalize_lookups(std::vector>& main_trace) +{ + for (auto const& [clk, count] : byte_operation_counter) { + main_trace.at(clk).lookup_byte_operations_counts = count; + } + + for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) { + // The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5] + main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1]; + } +} + +void AvmBinaryTraceBuilder::finalize_lookups_for_testing(std::vector>& main_trace) +{ + for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) { + // The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5] + main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1]; + } +} + } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.hpp index bcbd2adcc49..67394be7ae5 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.hpp @@ -1,6 +1,7 @@ #pragma once #include "barretenberg/numeric/uint128/uint128.hpp" +#include "barretenberg/vm/avm/generated/full_row.hpp" #include "barretenberg/vm/avm/trace/common.hpp" #include @@ -32,9 +33,15 @@ class AvmBinaryTraceBuilder { std::unordered_map byte_length_counter; AvmBinaryTraceBuilder() = default; + + size_t size() const { return binary_trace.size(); } void reset(); - // Finalize the trace - std::vector finalize(); + + // These two have to be separate because the lookups need to be finalized + // after the extra first row is inserted in the main trace. + void finalize(std::vector>& main_trace); + void finalize_lookups(std::vector>& main_trace); + void finalize_lookups_for_testing(std::vector>& main_trace); FF op_and(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk); FF op_or(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk); diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp new file mode 100644 index 00000000000..b3fa426956b --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp @@ -0,0 +1,92 @@ +#include "barretenberg/vm/avm/trace/fixed_bytes.hpp" + +namespace bb::avm_trace { + +// Singleton. +const FixedBytesTable& FixedBytesTable::get() +{ + static FixedBytesTable table; + return table; +} + +void FixedBytesTable::finalize(std::vector>& main_trace) const +{ + if (main_trace.size() < 3 * (1 << 16)) { + main_trace.resize(3 * (1 << 16)); + } + // Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id. + for (uint32_t op_id = 0; op_id < 3; op_id++) { + for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) { + for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) { + auto a = static_cast(input_a); + auto b = static_cast(input_b); + + // Derive a unique row index given op_id, a, and b. + auto main_trace_index = (op_id << 16) + (input_a << 8) + b; + + main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1); + main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id; + main_trace.at(main_trace_index).byte_lookup_table_input_a = a; + main_trace.at(main_trace_index).byte_lookup_table_input_b = b; + } + } + } + + finalize_byte_length(main_trace); +} + +void FixedBytesTable::finalize_for_testing(std::vector>& main_trace, + const std::unordered_map& byte_operation_counter) const +{ + // Generate ByteLength Lookup table of instruction tags to the number of bytes + // {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16} + for (auto const& [clk, count] : byte_operation_counter) { + // from the clk we can derive the a and b inputs + auto b = static_cast(clk); + auto a = static_cast(clk >> 8); + auto op_id = static_cast(clk >> 16); + uint8_t bit_op = 0; + if (op_id == 0) { + bit_op = a & b; + } else if (op_id == 1) { + bit_op = a | b; + } else { + bit_op = a ^ b; + } + if (clk > (main_trace.size() - 1)) { + main_trace.push_back(AvmFullRow{ + .byte_lookup_sel_bin = FF(1), + .byte_lookup_table_input_a = a, + .byte_lookup_table_input_b = b, + .byte_lookup_table_op_id = op_id, + .byte_lookup_table_output = bit_op, + .main_clk = FF(clk), + .lookup_byte_operations_counts = count, + }); + } else { + main_trace.at(clk).lookup_byte_operations_counts = count; + main_trace.at(clk).byte_lookup_sel_bin = FF(1); + main_trace.at(clk).byte_lookup_table_op_id = op_id; + main_trace.at(clk).byte_lookup_table_input_a = a; + main_trace.at(clk).byte_lookup_table_input_b = b; + main_trace.at(clk).byte_lookup_table_output = bit_op; + } + // Add the counter value stored throughout the execution + } + + finalize_byte_length(main_trace); +} + +void FixedBytesTable::finalize_byte_length(std::vector>& main_trace) +{ + // Generate ByteLength Lookup table of instruction tags to the number of bytes + // {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16} + for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) { + // The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range 1,5] + main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1); + main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1; + main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast(1 << avm_in_tag); + } +} + +} // namespace bb::avm_trace \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp new file mode 100644 index 00000000000..ee1b5ffd5fe --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/vm/avm/trace/common.hpp" +#include "barretenberg/vm/avm/trace/opcode.hpp" + +namespace bb::avm_trace { + +class FixedBytesTable { + public: + static const FixedBytesTable& get(); + + void finalize(std::vector>& main_trace) const; + void finalize_for_testing(std::vector>& main_trace, + const std::unordered_map& byte_operation_counter) const; + + private: + FixedBytesTable() = default; + static void finalize_byte_length(std::vector>& main_trace); +}; + +} // namespace bb::avm_trace \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp index 4d287e0b453..fe6d4382dc0 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp @@ -20,6 +20,7 @@ #include "barretenberg/numeric/uint256/uint256.hpp" #include "barretenberg/polynomials/univariate.hpp" #include "barretenberg/vm/avm/trace/common.hpp" +#include "barretenberg/vm/avm/trace/fixed_bytes.hpp" #include "barretenberg/vm/avm/trace/fixed_gas.hpp" #include "barretenberg/vm/avm/trace/fixed_powers.hpp" #include "barretenberg/vm/avm/trace/gadgets/slice_trace.hpp" @@ -34,47 +35,6 @@ namespace bb::avm_trace { * HELPERS IN ANONYMOUS NAMESPACE **************************************************************************************************/ namespace { -// WARNING: FOR TESTING ONLY -// Generates the minimal lookup table for the binary trace -uint32_t finalize_bin_trace_lookup_for_testing(std::vector& main_trace, AvmBinaryTraceBuilder& bin_trace_builder) -{ - // Generate ByteLength Lookup table of instruction tags to the number of bytes - // {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16} - for (auto const& [clk, count] : bin_trace_builder.byte_operation_counter) { - // from the clk we can derive the a and b inputs - auto b = static_cast(clk); - auto a = static_cast(clk >> 8); - auto op_id = static_cast(clk >> 16); - uint8_t bit_op = 0; - if (op_id == 0) { - bit_op = a & b; - } else if (op_id == 1) { - bit_op = a | b; - } else { - bit_op = a ^ b; - } - if (clk > (main_trace.size() - 1)) { - main_trace.push_back(Row{ - .byte_lookup_sel_bin = FF(1), - .byte_lookup_table_input_a = a, - .byte_lookup_table_input_b = b, - .byte_lookup_table_op_id = op_id, - .byte_lookup_table_output = bit_op, - .main_clk = FF(clk), - .lookup_byte_operations_counts = count, - }); - } else { - main_trace.at(clk).lookup_byte_operations_counts = count; - main_trace.at(clk).byte_lookup_sel_bin = FF(1); - main_trace.at(clk).byte_lookup_table_op_id = op_id; - main_trace.at(clk).byte_lookup_table_input_a = a; - main_trace.at(clk).byte_lookup_table_input_b = b; - main_trace.at(clk).byte_lookup_table_output = bit_op; - } - // Add the counter value stored throughout the execution - } - return static_cast(main_trace.size()); -} constexpr size_t L2_HI_GAS_COUNTS_IDX = 0; constexpr size_t L2_LO_GAS_COUNTS_IDX = 1; @@ -3459,7 +3419,6 @@ std::vector AvmTraceBuilder::finalize(bool range_check_required) auto poseidon2_trace = poseidon2_trace_builder.finalize(); auto keccak_trace = keccak_trace_builder.finalize(); auto pedersen_trace = pedersen_trace_builder.finalize(); - auto bin_trace = bin_trace_builder.finalize(); auto gas_trace = gas_trace_builder.finalize(); auto slice_trace = slice_trace_builder.finalize(); const auto& fixed_gas_table = FixedGasTable::get(); @@ -3471,7 +3430,7 @@ std::vector AvmTraceBuilder::finalize(bool range_check_required) size_t poseidon2_trace_size = poseidon2_trace.size(); size_t keccak_trace_size = keccak_trace.size(); size_t pedersen_trace_size = pedersen_trace.size(); - size_t bin_trace_size = bin_trace.size(); + size_t bin_trace_size = bin_trace_builder.size(); size_t gas_trace_size = gas_trace.size(); size_t slice_trace_size = slice_trace.size(); @@ -3480,18 +3439,14 @@ std::vector AvmTraceBuilder::finalize(bool range_check_required) std::unordered_map mem_rng_check_mid_counts; std::unordered_map mem_rng_check_hi_counts; - // Main Trace needs to be at least as big as the biggest subtrace. - // If the bin_trace_size has entries, we need the main_trace to be as big as our byte lookup table (3 * - // 2**16 long) - size_t const lookup_table_size = (bin_trace_size > 0 && range_check_required) ? 3 * (1 << 16) : 0; // Range check size is 1 less than it needs to be since we insert a "first row" at the top of the trace at the // end, with clk 0 (this doubles as our range check) size_t const range_check_size = range_check_required ? UINT16_MAX : 0; - std::vector trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, - range_check_size, conv_trace_size, lookup_table_size, - sha256_trace_size, poseidon2_trace_size, pedersen_trace_size, - gas_trace_size + 1, KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, - fixed_gas_table.size(), slice_trace_size, calldata.size() }; + std::vector trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size, + range_check_size, conv_trace_size, sha256_trace_size, + poseidon2_trace_size, pedersen_trace_size, gas_trace_size + 1, + KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, fixed_gas_table.size(), + slice_trace_size, calldata.size() }; vinfo("Trace sizes before padding:", "\n\tmain_trace_size: ", main_trace_size, @@ -3870,70 +3825,7 @@ std::vector AvmTraceBuilder::finalize(bool range_check_required) * BINARY TRACE INCLUSION **********************************************************************************************/ - // Add Binary Trace table - for (size_t i = 0; i < bin_trace_size; i++) { - auto const& src = bin_trace.at(i); - auto& dest = main_trace.at(i); - dest.binary_clk = src.binary_clk; - dest.binary_sel_bin = static_cast(src.bin_sel); - dest.binary_acc_ia = src.acc_ia; - dest.binary_acc_ib = src.acc_ib; - dest.binary_acc_ic = src.acc_ic; - dest.binary_in_tag = src.in_tag; - dest.binary_op_id = src.op_id; - dest.binary_ia_bytes = src.bin_ia_bytes; - dest.binary_ib_bytes = src.bin_ib_bytes; - dest.binary_ic_bytes = src.bin_ic_bytes; - dest.binary_start = FF(static_cast(src.start)); - dest.binary_mem_tag_ctr = src.mem_tag_ctr; - dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv; - } - - // Only generate precomputed byte tables if we are actually going to use them in this main trace. - if (bin_trace_size > 0) { - if (!range_check_required) { - finalize_bin_trace_lookup_for_testing(main_trace, bin_trace_builder); - } else { - // Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id. - for (uint32_t op_id = 0; op_id < 3; op_id++) { - for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) { - for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) { - auto a = static_cast(input_a); - auto b = static_cast(input_b); - - // Derive a unique row index given op_id, a, and b. - auto main_trace_index = (op_id << 16) + (input_a << 8) + b; - - main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1); - main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id; - main_trace.at(main_trace_index).byte_lookup_table_input_a = a; - main_trace.at(main_trace_index).byte_lookup_table_input_b = b; - // Add the counter value stored throughout the execution - main_trace.at(main_trace_index).lookup_byte_operations_counts = - bin_trace_builder.byte_operation_counter[main_trace_index]; - if (op_id == 0) { - main_trace.at(main_trace_index).byte_lookup_table_output = a & b; - } else if (op_id == 1) { - main_trace.at(main_trace_index).byte_lookup_table_output = a | b; - } else { - main_trace.at(main_trace_index).byte_lookup_table_output = a ^ b; - } - } - } - } - } - // Generate ByteLength Lookup table of instruction tags to the number of bytes - // {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16} - for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) { - // The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range - // [1,5] - main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1); - main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1; - main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast(pow(2, avm_in_tag)); - main_trace.at(avm_in_tag).lookup_byte_lengths_counts = - bin_trace_builder.byte_length_counter[avm_in_tag + 1]; - } - } + bin_trace_builder.finalize(main_trace); /********************************************************************************************** * GAS TRACE INCLUSION @@ -4015,6 +3907,21 @@ std::vector AvmTraceBuilder::finalize(bool range_check_required) Row first_row = Row{ .main_sel_first = FF(1), .mem_lastAccess = FF(1) }; main_trace.insert(main_trace.begin(), first_row); + /********************************************************************************************** + * BYTES TRACE INCLUSION + **********************************************************************************************/ + + // Only generate precomputed byte tables if we are actually going to use them in this main trace. + if (bin_trace_size > 0) { + if (!range_check_required) { + FixedBytesTable::get().finalize_for_testing(main_trace, bin_trace_builder.byte_operation_counter); + bin_trace_builder.finalize_lookups_for_testing(main_trace); + } else { + FixedBytesTable::get().finalize(main_trace); + bin_trace_builder.finalize_lookups(main_trace); + } + } + /********************************************************************************************** * RANGE CHECKS AND SELECTORS INCLUSION **********************************************************************************************/