Skip to content

Commit

Permalink
feat(avm): calldata gadget preliminaries (#7227)
Browse files Browse the repository at this point in the history
First preliminary work for issue #7211.
Added a new public calldata column and passes calldata file to the avm
verifier.
  • Loading branch information
jeanmon authored Jun 28, 2024
1 parent b3409c4 commit 79e8588
Show file tree
Hide file tree
Showing 25 changed files with 269 additions and 126 deletions.
3 changes: 3 additions & 0 deletions barretenberg/cpp/pil/avm/main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace main(256);
pol constant sel_first = [1] + [0]*; // Used mostly to toggle off the first row consisting
// only in first element of shifted polynomials.

//===== PUBLIC COLUMNS=========================================================
pol public calldata;

//===== KERNEL INPUTS =========================================================
// Kernel lookup selector opcodes
pol commit sel_q_kernel_lookup;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
[[maybe_unused]] auto kernel_kernel_value_out = View(new_term.kernel_kernel_value_out); \
[[maybe_unused]] auto kernel_kernel_side_effect_out = View(new_term.kernel_kernel_side_effect_out); \
[[maybe_unused]] auto kernel_kernel_metadata_out = View(new_term.kernel_kernel_metadata_out); \
[[maybe_unused]] auto main_calldata = View(new_term.main_calldata); \
[[maybe_unused]] auto alu_a_hi = View(new_term.alu_a_hi); \
[[maybe_unused]] auto alu_a_lo = View(new_term.alu_a_lo); \
[[maybe_unused]] auto alu_b_hi = View(new_term.alu_b_hi); \
Expand Down
28 changes: 19 additions & 9 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_execution.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "barretenberg/vm/avm_trace/avm_execution.hpp"
#include "barretenberg/bb/log.hpp"
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/vm/avm_trace/avm_common.hpp"
#include "barretenberg/vm/avm_trace/avm_deserialization.hpp"
#include "barretenberg/vm/avm_trace/avm_helper.hpp"
Expand Down Expand Up @@ -78,10 +79,11 @@ std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<u
auto prover = composer.create_prover(circuit_builder);
auto verifier = composer.create_verifier(circuit_builder);

// The proof starts with the serialized public inputs
// Proof structure: public_inputs | calldata_size | calldata | raw proof
HonkProof proof(public_inputs_vec);
proof.emplace_back(calldata.size());
proof.insert(proof.end(), calldata.begin(), calldata.end());
auto raw_proof = prover.construct_proof();
// append the raw proof after the public inputs
proof.insert(proof.end(), raw_proof.begin(), raw_proof.end());
// TODO(#4887): Might need to return PCS vk when full verify is supported
return std::make_tuple(*verifier.key, proof);
Expand Down Expand Up @@ -261,14 +263,23 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof)
// crs_factory_);
// output_state.pcs_verification_key = std::move(pcs_verification_key);

// Proof structure: public_inputs | calldata_size | calldata | raw proof
std::vector<FF> public_inputs_vec;
std::vector<FF> calldata;
std::vector<FF> raw_proof;
std::copy(
proof.begin(), proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, std::back_inserter(public_inputs_vec));
std::copy(proof.begin() + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH, proof.end(), std::back_inserter(raw_proof));

// This can be made nicer using BB's serialize::read, probably.
const auto public_inputs_offset = proof.begin();
const auto calldata_size_offset = public_inputs_offset + PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH;
const auto calldata_offset = calldata_size_offset + 1;
const auto raw_proof_offset = calldata_offset + static_cast<int64_t>(uint64_t(*calldata_size_offset));

std::copy(public_inputs_offset, calldata_size_offset, std::back_inserter(public_inputs_vec));
std::copy(calldata_offset, raw_proof_offset, std::back_inserter(calldata));
std::copy(raw_proof_offset, proof.end(), std::back_inserter(raw_proof));

VmPublicInputs public_inputs = convert_public_inputs(public_inputs_vec);
std::vector<std::vector<FF>> public_inputs_columns = copy_public_inputs_columns(public_inputs);
std::vector<std::vector<FF>> public_inputs_columns = copy_public_inputs_columns(public_inputs, calldata);
return verifier.verify_proof(raw_proof, public_inputs_columns);
}

Expand Down Expand Up @@ -309,7 +320,7 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
uint32_t start_side_effect_counter =
!public_inputs_vec.empty() ? static_cast<uint32_t>(public_inputs_vec[PCPI_START_SIDE_EFFECT_COUNTER_OFFSET])
: 0;
AvmTraceBuilder trace_builder(public_inputs, execution_hints, start_side_effect_counter);
AvmTraceBuilder trace_builder(public_inputs, execution_hints, start_side_effect_counter, calldata);

// Copied version of pc maintained in trace builder. The value of pc is evolving based
// on opcode logic and therefore is not maintained here. However, the next opcode in the execution
Expand Down Expand Up @@ -436,8 +447,7 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
trace_builder.op_calldata_copy(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
calldata);
std::get<uint32_t>(inst.operands.at(3)));
break;
// Machine State - Gas
case OpCode::L2GASLEFT:
Expand Down
6 changes: 4 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ bool is_operand_indirect(uint8_t ind_value, uint8_t operand_idx)
return static_cast<bool>((ind_value & (1 << operand_idx)) >> operand_idx);
}

std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs)
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs,
std::vector<FF> const& calldata)
{
// We convert to a vector as the pil generated verifier is generic and unaware of the KERNEL_INPUTS_LENGTH
// For each of the public input vectors
Expand All @@ -158,7 +159,8 @@ std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& pu
return { std::move(public_inputs_kernel_inputs),
std::move(public_inputs_kernel_value_outputs),
std::move(public_inputs_kernel_side_effect_outputs),
std::move(public_inputs_kernel_metadata_outputs) };
std::move(public_inputs_kernel_metadata_outputs),
calldata };
}

} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bool is_operand_indirect(uint8_t ind_value, uint8_t operand_idx);
// There are 4 public input columns, one for inputs, and 3 for the kernel outputs {value, side effect counter, metadata}
// The verifier is generic, and so accepts vectors of these values rather than the fixed length arrays that are used
// during circuit building. This method copies each array into a vector to be used by the verifier.
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs);
std::vector<std::vector<FF>> copy_public_inputs_columns(VmPublicInputs const& public_inputs,
std::vector<FF> const& calldata);

} // namespace bb::avm_trace
51 changes: 39 additions & 12 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ namespace bb::avm_trace {
*/
AvmTraceBuilder::AvmTraceBuilder(VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter)
uint32_t side_effect_counter,
std::vector<FF> calldata)
// NOTE: we initialise the environment builder here as it requires public inputs
: kernel_trace_builder(std::move(public_inputs))
, calldata(std::move(calldata))
, side_effect_counter(side_effect_counter)
, initial_side_effect_counter(side_effect_counter)
, execution_hints(std::move(execution_hints))
Expand Down Expand Up @@ -1886,10 +1888,8 @@ void AvmTraceBuilder::op_div(
* @param cd_offset The starting index of the region in calldata to be copied.
* @param copy_size The number of finite field elements to be copied into memory.
* @param dst_offset The starting index of memory where calldata will be copied to.
* @param call_data_mem The vector containing calldata.
*/
void AvmTraceBuilder::op_calldata_copy(
uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset, std::vector<FF> const& call_data_mem)
void AvmTraceBuilder::op_calldata_copy(uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset)
{
// We parallelize storing memory operations in chunk of 3, i.e., 1 per intermediate register.
// The variable pos is an index pointing to the first storing operation (pertaining to intermediate
Expand All @@ -1912,7 +1912,7 @@ void AvmTraceBuilder::op_calldata_copy(
uint32_t rwc(0);
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

FF ia = call_data_mem.at(cd_offset + pos);
FF ia = calldata.at(cd_offset + pos);
uint32_t mem_op_a(1);
uint32_t rwa = 1;

Expand All @@ -1934,7 +1934,7 @@ void AvmTraceBuilder::op_calldata_copy(
call_ptr, clk, IntermRegister::IA, mem_addr_a, ia, AvmMemoryTag::U0, AvmMemoryTag::FF);

if (copy_size - pos > 1) {
ib = call_data_mem.at(cd_offset + pos + 1);
ib = calldata.at(cd_offset + pos + 1);
mem_op_b = 1;
mem_addr_b = direct_dst_offset + pos + 1;
rwb = 1;
Expand All @@ -1945,7 +1945,7 @@ void AvmTraceBuilder::op_calldata_copy(
}

if (copy_size - pos > 2) {
ic = call_data_mem.at(cd_offset + pos + 2);
ic = calldata.at(cd_offset + pos + 2);
mem_op_c = 1;
mem_addr_c = direct_dst_offset + pos + 2;
rwc = 1;
Expand Down Expand Up @@ -3762,7 +3762,9 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c

main_trace.at(*trace_size - 1).main_sel_last = FF(1);

// Memory trace inclusion
/**********************************************************************************************
* MEMORY TRACE INCLUSION
**********************************************************************************************/

// We compute in the main loop the timestamp and global address for next row.
// Perform initialization for index 0 outside of the loop provided that mem trace exists.
Expand Down Expand Up @@ -3866,7 +3868,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

// Alu trace inclusion
/**********************************************************************************************
* ALU TRACE INCLUSION
**********************************************************************************************/

for (size_t i = 0; i < alu_trace_size; i++) {
auto const& src = alu_trace.at(i);
auto& dest = main_trace.at(i);
Expand Down Expand Up @@ -4013,6 +4018,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/**********************************************************************************************
* GADGET TABLES INCLUSION
**********************************************************************************************/

// Add Conversion Gadget table
for (size_t i = 0; i < conv_trace_size; i++) {
auto const& src = conv_trace.at(i);
Expand Down Expand Up @@ -4067,6 +4076,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
dest.pedersen_sel_pedersen = FF(1);
}

/**********************************************************************************************
* BINARY TRACE INCLUSION
**********************************************************************************************/

// Add Binary Trace table
for (size_t i = 0; i < bin_trace_size; i++) {
auto const& src = bin_trace.at(i);
Expand Down Expand Up @@ -4132,7 +4145,9 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/////////// GAS ACCOUNTING //////////////////////////
/**********************************************************************************************
* GAS TRACE INCLUSION
**********************************************************************************************/

// Add the gas cost table to the main trace
// TODO: do i need a way to produce an interupt that will stop the execution of the trace when the gas left
Expand Down Expand Up @@ -4222,11 +4237,14 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
dest.main_da_gas_remaining = current_da_gas_remaining;
}

/////////// END OF GAS ACCOUNTING //////////////////////////

// Adding extra row for the shifted values at the top of the execution trace.
Row first_row = Row{ .main_sel_first = FF(1), .mem_lastAccess = FF(1) };
main_trace.insert(main_trace.begin(), first_row);

/**********************************************************************************************
* RANGE CHECKS AND SELECTORS INCLUSION
**********************************************************************************************/

auto const old_trace_size = main_trace.size();

auto new_trace_size = range_check_required ? old_trace_size
Expand Down Expand Up @@ -4316,6 +4334,10 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
}
}

/**********************************************************************************************
* KERNEL TRACE INCLUSION
**********************************************************************************************/

// Write the kernel trace into the main trace
// 1. The write offsets are constrained to be non changing over the entire trace, so we fill in the values
// until we
Expand Down Expand Up @@ -4494,6 +4516,11 @@ std::vector<Row> AvmTraceBuilder::finalize(uint32_t min_trace_size, bool range_c
std::get<KERNEL_OUTPUTS_METADATA>(kernel_trace_builder.public_inputs).at(i);
}

// calldata column inclusion
for (size_t i = 0; i < calldata.size(); i++) {
main_trace.at(i).main_calldata = calldata.at(i);
}

// Get tag_err counts from the mem_trace_builder
if (range_check_required) {
finalise_mem_trace_lookup_counts();
Expand Down
11 changes: 5 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class AvmTraceBuilder {
public:
AvmTraceBuilder(VmPublicInputs public_inputs = {},
ExecutionHints execution_hints = {},
uint32_t side_effect_counter = 0);
uint32_t side_effect_counter = 0,
std::vector<FF> calldata = {});

std::vector<Row> finalize(uint32_t min_trace_size = 0, bool range_check_required = ENABLE_PROVING);
void reset();
Expand Down Expand Up @@ -158,11 +159,7 @@ class AvmTraceBuilder {
// CALLDATACOPY opcode with direct/indirect memory access, i.e.,
// direct: M[dst_offset:dst_offset+copy_size] = calldata[cd_offset:cd_offset+copy_size]
// indirect: M[M[dst_offset]:M[dst_offset]+copy_size] = calldata[cd_offset:cd_offset+copy_size]
void op_calldata_copy(uint8_t indirect,
uint32_t cd_offset,
uint32_t copy_size,
uint32_t dst_offset,
std::vector<FF> const& call_data_mem);
void op_calldata_copy(uint8_t indirect, uint32_t cd_offset, uint32_t copy_size, uint32_t dst_offset);

// REVERT Opcode (that just call return under the hood for now)
std::vector<FF> op_revert(uint8_t indirect, uint32_t ret_offset, uint32_t ret_size);
Expand Down Expand Up @@ -241,6 +238,8 @@ class AvmTraceBuilder {
AvmPedersenTraceBuilder pedersen_trace_builder;
AvmEccTraceBuilder ecc_trace_builder;

std::vector<FF> calldata{};

/**
* @brief Create a kernel lookup opcode object
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ template <typename FF> std::vector<std::string> AvmFullRow<FF>::names()
"kernel_kernel_value_out",
"kernel_kernel_side_effect_out",
"kernel_kernel_metadata_out",
"main_calldata",
"alu_a_hi",
"alu_a_lo",
"alu_b_hi",
Expand Down Expand Up @@ -412,18 +413,19 @@ template <typename FF> std::ostream& operator<<(std::ostream& os, AvmFullRow<FF>
<< field_to_string(row.main_clk) << "," << field_to_string(row.main_sel_first) << ","
<< field_to_string(row.kernel_kernel_inputs) << "," << field_to_string(row.kernel_kernel_value_out) << ","
<< field_to_string(row.kernel_kernel_side_effect_out) << ","
<< field_to_string(row.kernel_kernel_metadata_out) << "," << field_to_string(row.alu_a_hi) << ","
<< field_to_string(row.alu_a_lo) << "," << field_to_string(row.alu_b_hi) << ","
<< field_to_string(row.alu_b_lo) << "," << field_to_string(row.alu_borrow) << ","
<< field_to_string(row.alu_cf) << "," << field_to_string(row.alu_clk) << ","
<< field_to_string(row.alu_cmp_rng_ctr) << "," << field_to_string(row.alu_div_u16_r0) << ","
<< field_to_string(row.alu_div_u16_r1) << "," << field_to_string(row.alu_div_u16_r2) << ","
<< field_to_string(row.alu_div_u16_r3) << "," << field_to_string(row.alu_div_u16_r4) << ","
<< field_to_string(row.alu_div_u16_r5) << "," << field_to_string(row.alu_div_u16_r6) << ","
<< field_to_string(row.alu_div_u16_r7) << "," << field_to_string(row.alu_divisor_hi) << ","
<< field_to_string(row.alu_divisor_lo) << "," << field_to_string(row.alu_ff_tag) << ","
<< field_to_string(row.alu_ia) << "," << field_to_string(row.alu_ib) << "," << field_to_string(row.alu_ic)
<< "," << field_to_string(row.alu_in_tag) << "," << field_to_string(row.alu_op_add) << ","
<< field_to_string(row.kernel_kernel_metadata_out) << "," << field_to_string(row.main_calldata) << ","
<< field_to_string(row.alu_a_hi) << "," << field_to_string(row.alu_a_lo) << ","
<< field_to_string(row.alu_b_hi) << "," << field_to_string(row.alu_b_lo) << ","
<< field_to_string(row.alu_borrow) << "," << field_to_string(row.alu_cf) << ","
<< field_to_string(row.alu_clk) << "," << field_to_string(row.alu_cmp_rng_ctr) << ","
<< field_to_string(row.alu_div_u16_r0) << "," << field_to_string(row.alu_div_u16_r1) << ","
<< field_to_string(row.alu_div_u16_r2) << "," << field_to_string(row.alu_div_u16_r3) << ","
<< field_to_string(row.alu_div_u16_r4) << "," << field_to_string(row.alu_div_u16_r5) << ","
<< field_to_string(row.alu_div_u16_r6) << "," << field_to_string(row.alu_div_u16_r7) << ","
<< field_to_string(row.alu_divisor_hi) << "," << field_to_string(row.alu_divisor_lo) << ","
<< field_to_string(row.alu_ff_tag) << "," << field_to_string(row.alu_ia) << ","
<< field_to_string(row.alu_ib) << "," << field_to_string(row.alu_ic) << ","
<< field_to_string(row.alu_in_tag) << "," << field_to_string(row.alu_op_add) << ","
<< field_to_string(row.alu_op_cast) << "," << field_to_string(row.alu_op_cast_prev) << ","
<< field_to_string(row.alu_op_div) << "," << field_to_string(row.alu_op_div_a_lt_b) << ","
<< field_to_string(row.alu_op_div_std) << "," << field_to_string(row.alu_op_eq) << ","
Expand Down
Loading

0 comments on commit 79e8588

Please sign in to comment.