From 197fe73325e86296bf261684f1e8803e9f14762f Mon Sep 17 00:00:00 2001 From: David Banks <47112877+dbanks12@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:05:22 -0500 Subject: [PATCH] chore: align AVM witgen's limits on number of side effects with AVM simulator. Witgen supports phases and rollbacks. (#10329) 1. Aligns side effect limits between TS and CPP 2. Adds Noir test functions to spam side effects & adds that to proving test 2. Adds support to witgen for tx phases 3. Never expects a read hint for nullifier writes. Always just uses the write hint. 4. Adds an argument to finalize and `gen_trace` to skip end-gas assertions (not sure this is the best way, but all the tests pass) 5. renames TS hint vectors to all just be `*_reads` or `*_writes` instead of `*read/update_requests` etc. Work needed in a follow-up PR: - separate the opcode switch-case & phase management/enqueued-calls into separate functions or even separate files in witgen --- .../src/barretenberg/vm/avm/trace/errors.hpp | 2 + .../barretenberg/vm/avm/trace/execution.cpp | 1119 +++++++++-------- .../barretenberg/vm/avm/trace/execution.hpp | 11 +- .../vm/avm/trace/gadgets/merkle_tree.cpp | 9 + .../vm/avm/trace/gadgets/merkle_tree.hpp | 4 + .../src/barretenberg/vm/avm/trace/helper.cpp | 11 + .../src/barretenberg/vm/avm/trace/helper.hpp | 1 + .../vm/avm/trace/public_inputs.hpp | 21 + .../src/barretenberg/vm/avm/trace/trace.cpp | 168 ++- .../src/barretenberg/vm/avm/trace/trace.hpp | 8 +- .../contracts/avm_test_contract/src/main.nr | 36 + .../bb-prover/src/avm_proving.test.ts | 86 +- yarn-project/bb-prover/src/test/index.ts | 1 - yarn-project/bb-prover/src/test/test_avm.ts | 85 -- .../circuits.js/src/structs/avm/avm.ts | 85 +- .../circuits.js/src/tests/factories.ts | 14 +- .../simulator/src/avm/journal/journal.ts | 5 + .../enqueued_call_side_effect_trace.test.ts | 28 +- .../public/enqueued_call_side_effect_trace.ts | 30 +- .../simulator/src/public/fixtures/index.ts | 12 +- .../simulator/src/public/public_tx_context.ts | 2 +- .../simulator/src/public/side_effect_trace.ts | 14 +- .../src/public/transitional_adapters.ts | 191 +-- 23 files changed, 996 insertions(+), 947 deletions(-) delete mode 100644 yarn-project/bb-prover/src/test/test_avm.ts diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp index e31d486e502b..ca121ebefa29 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp @@ -6,6 +6,7 @@ namespace bb::avm_trace { enum class AvmError : uint32_t { NO_ERROR, + REVERT_OPCODE, INVALID_PROGRAM_COUNTER, INVALID_OPCODE, INVALID_TAG_VALUE, @@ -18,6 +19,7 @@ enum class AvmError : uint32_t { CONTRACT_INST_MEM_UNKNOWN, RADIX_OUT_OF_BOUNDS, DUPLICATE_NULLIFIER, + SIDE_EFFECT_LIMIT_REACHED, }; } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp index d5f733a25838..bd8e2e91186d 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp @@ -38,6 +38,22 @@ using namespace bb; std::filesystem::path avm_dump_trace_path; namespace bb::avm_trace { + +std::string to_name(TxExecutionPhase phase) +{ + switch (phase) { + case TxExecutionPhase::SETUP: + return "SETUP"; + case TxExecutionPhase::APP_LOGIC: + return "APP_LOGIC"; + case TxExecutionPhase::TEARDOWN: + return "TEARDOWN"; + default: + throw std::runtime_error("Invalid tx phase"); + break; + } +} + namespace { // The SRS needs to be able to accommodate the circuit subgroup size. @@ -183,7 +199,8 @@ std::tuple Execution::prove(AvmPublicInpu for (const auto& enqueued_call_hints : execution_hints.enqueued_call_hints) { calldata.insert(calldata.end(), enqueued_call_hints.calldata.begin(), enqueued_call_hints.calldata.end()); } - std::vector trace = AVM_TRACK_TIME_V("prove/gen_trace", gen_trace(public_inputs, returndata, execution_hints)); + std::vector trace = AVM_TRACK_TIME_V( + "prove/gen_trace", gen_trace(public_inputs, returndata, execution_hints, /*apply_end_gas_assertions=*/true)); if (!avm_dump_trace_path.empty()) { info("Dumping trace as CSV to: " + avm_dump_trace_path.string()); dump_trace_as_csv(trace, avm_dump_trace_path); @@ -265,7 +282,8 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof) */ std::vector Execution::gen_trace(AvmPublicInputs const& public_inputs, std::vector& returndata, - ExecutionHints const& execution_hints) + ExecutionHints const& execution_hints, + bool apply_end_gas_assertions) { vinfo("------- GENERATING TRACE -------"); @@ -281,578 +299,595 @@ std::vector Execution::gen_trace(AvmPublicInputs const& public_inputs, AvmTraceBuilder trace_builder = Execution::trace_builder_constructor(public_inputs, execution_hints, start_side_effect_counter, calldata); - std::vector public_call_requests; - for (const auto& setup_requests : public_inputs.public_setup_call_requests) { - if (setup_requests.contract_address != 0) { - public_call_requests.push_back(setup_requests); - } - } - size_t setup_counter = public_call_requests.size(); - - for (const auto& app_requests : public_inputs.public_app_logic_call_requests) { - if (app_requests.contract_address != 0) { - public_call_requests.push_back(app_requests); - } - } - // We should not need to guard teardown, but while we are testing with handcrafted txs we do - if (public_inputs.public_teardown_call_request.contract_address != 0) { - public_call_requests.push_back(public_inputs.public_teardown_call_request); - } - // Temporary spot for private non-revertible insertion - std::vector siloed_nullifier; - siloed_nullifier.insert(siloed_nullifier.end(), - public_inputs.accumulated_data.nullifiers.begin(), - public_inputs.accumulated_data.nullifiers.begin() + - public_inputs.previous_non_revertible_accumulated_data_array_lengths.nullifiers); - trace_builder.insert_private_state(siloed_nullifier, {}); + std::vector siloed_nullifiers; + siloed_nullifiers.insert(siloed_nullifiers.end(), + public_inputs.accumulated_data.nullifiers.begin(), + public_inputs.accumulated_data.nullifiers.begin() + + public_inputs.previous_non_revertible_accumulated_data_array_lengths.nullifiers); + trace_builder.insert_private_state(siloed_nullifiers, {}); + trace_builder.checkpoint_non_revertible_state(); + + std::array public_teardown_call_requests{}; + public_teardown_call_requests[0] = public_inputs.public_teardown_call_request; // Loop over all the public call requests uint8_t call_ctx = 0; - for (size_t i = 0; i < public_call_requests.size(); i++) { - - // When we get this, it means we have done our non-revertible setup phase - if (i + 1 == setup_counter) { - // Temporary spot for private revertible insertion - std::vector siloed_nullifiers; - siloed_nullifiers.insert(siloed_nullifiers.end(), - public_inputs.previous_revertible_accumulated_data.nullifiers.begin(), - public_inputs.previous_revertible_accumulated_data.nullifiers.begin() + - public_inputs.previous_revertible_accumulated_data_array_lengths.nullifiers); - trace_builder.insert_private_state(siloed_nullifiers, {}); + auto const phases = { TxExecutionPhase::SETUP, TxExecutionPhase::APP_LOGIC, TxExecutionPhase::TEARDOWN }; + for (auto phase : phases) { + auto call_requests_array = phase == TxExecutionPhase::SETUP ? public_inputs.public_setup_call_requests + : phase == TxExecutionPhase::APP_LOGIC ? public_inputs.public_app_logic_call_requests + : public_teardown_call_requests; + std::vector public_call_requests; + for (const auto& call_request : call_requests_array) { + if (call_request.contract_address != 0) { + public_call_requests.push_back(call_request); + } } - - auto public_call_request = public_call_requests.at(i); - trace_builder.set_public_call_request(public_call_request); - trace_builder.set_call_ptr(call_ctx++); - - // Find the bytecode based on contract address of the public call request - const std::vector& bytecode = - std::ranges::find_if(execution_hints.all_contract_bytecode, [public_call_request](const auto& contract) { - return contract.contract_instance.address == public_call_request.contract_address; - })->bytecode; - info("Found bytecode for contract address: ", public_call_request.contract_address); - - // Set this also on nested call - - // Copied version of pc maintained in trace builder. The value of pc is evolving based - // on opcode logic and therefore is not maintained here. However, the next opcode in the execution - // is determined by this value which require read access to the code below. - uint32_t pc = 0; - uint32_t counter = 0; + info("Beginning execution of phase ", to_name(phase), " (", public_call_requests.size(), " enqueued calls)."); AvmError error = AvmError::NO_ERROR; - while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) { - auto [inst, parse_error] = Deserialization::parse(bytecode, pc); - error = parse_error; - - if (!is_ok(error)) { - break; + for (size_t i = 0; i < public_call_requests.size(); i++) { + + // When we get this, it means we have done our non-revertible setup phase + if (phase == TxExecutionPhase::SETUP) { + // Temporary spot for private revertible insertion + std::vector siloed_nullifiers; + siloed_nullifiers.insert( + siloed_nullifiers.end(), + public_inputs.previous_revertible_accumulated_data.nullifiers.begin(), + public_inputs.previous_revertible_accumulated_data.nullifiers.begin() + + public_inputs.previous_revertible_accumulated_data_array_lengths.nullifiers); + trace_builder.insert_private_state(siloed_nullifiers, {}); } - debug("[PC:" + std::to_string(pc) + "] [IC:" + std::to_string(counter++) + "] " + inst.to_string() + - " (gasLeft l2=" + std::to_string(trace_builder.get_l2_gas_left()) + ")"); - - switch (inst.op_code) { - // Compute - // Compute - Arithmetic - case OpCode::ADD_8: - error = trace_builder.op_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::ADD_8); - break; - case OpCode::ADD_16: - error = trace_builder.op_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::ADD_16); - break; - case OpCode::SUB_8: - error = trace_builder.op_sub(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SUB_8); - break; - case OpCode::SUB_16: - error = trace_builder.op_sub(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SUB_16); - break; - case OpCode::MUL_8: - error = trace_builder.op_mul(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::MUL_8); - break; - case OpCode::MUL_16: - error = trace_builder.op_mul(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::MUL_16); - break; - case OpCode::DIV_8: - error = trace_builder.op_div(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::DIV_8); - break; - case OpCode::DIV_16: - error = trace_builder.op_div(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::DIV_16); - break; - case OpCode::FDIV_8: - error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::FDIV_8); - break; - case OpCode::FDIV_16: - error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::FDIV_16); - break; - case OpCode::EQ_8: - error = trace_builder.op_eq(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::EQ_8); - break; - case OpCode::EQ_16: - error = trace_builder.op_eq(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::EQ_16); - break; - case OpCode::LT_8: - error = trace_builder.op_lt(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LT_8); - break; - case OpCode::LT_16: - error = trace_builder.op_lt(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LT_16); - break; - case OpCode::LTE_8: - error = trace_builder.op_lte(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LTE_8); - break; - case OpCode::LTE_16: - error = trace_builder.op_lte(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LTE_16); - break; - case OpCode::AND_8: - error = trace_builder.op_and(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::AND_8); - break; - case OpCode::AND_16: - error = trace_builder.op_and(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::AND_16); - break; - case OpCode::OR_8: - error = trace_builder.op_or(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::OR_8); - break; - case OpCode::OR_16: - error = trace_builder.op_or(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::OR_16); - break; - case OpCode::XOR_8: - error = trace_builder.op_xor(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::XOR_8); - break; - case OpCode::XOR_16: - error = trace_builder.op_xor(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::XOR_16); - break; - case OpCode::NOT_8: - error = trace_builder.op_not(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::NOT_8); - break; - case OpCode::NOT_16: - error = trace_builder.op_not(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::NOT_16); - break; - case OpCode::SHL_8: - error = trace_builder.op_shl(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHL_8); - break; - case OpCode::SHL_16: - error = trace_builder.op_shl(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHL_16); - break; - case OpCode::SHR_8: - error = trace_builder.op_shr(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHR_8); - break; - case OpCode::SHR_16: - error = trace_builder.op_shr(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHR_16); - break; - - // Compute - Type Conversions - case OpCode::CAST_8: - error = trace_builder.op_cast(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::CAST_8); - break; - case OpCode::CAST_16: - error = trace_builder.op_cast(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::CAST_16); - break; - - // Execution Environment - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6284): support indirect for below - case OpCode::GETENVVAR_16: - error = trace_builder.op_get_env_var(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - - // Execution Environment - Calldata - case OpCode::CALLDATACOPY: - error = trace_builder.op_calldata_copy(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - - case OpCode::RETURNDATASIZE: - error = trace_builder.op_returndata_size(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; + auto public_call_request = public_call_requests.at(i); + trace_builder.set_public_call_request(public_call_request); + trace_builder.set_call_ptr(call_ctx++); + + // Find the bytecode based on contract address of the public call request + const std::vector& bytecode = std::ranges::find_if(execution_hints.all_contract_bytecode, + [public_call_request](const auto& contract) { + return contract.contract_instance.address == + public_call_request.contract_address; + }) + ->bytecode; + info("Found bytecode for contract address: ", public_call_request.contract_address); + + // Set this also on nested call + + // Copied version of pc maintained in trace builder. The value of pc is evolving based + // on opcode logic and therefore is not maintained here. However, the next opcode in the execution + // is determined by this value which require read access to the code below. + uint32_t pc = 0; + uint32_t counter = 0; + while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) { + auto [inst, parse_error] = Deserialization::parse(bytecode, pc); + error = parse_error; + + if (!is_ok(error)) { + break; + } - case OpCode::RETURNDATACOPY: - error = trace_builder.op_returndata_copy(std::get(inst.operands.at(0)), + debug("[PC:" + std::to_string(pc) + "] [IC:" + std::to_string(counter++) + "] " + inst.to_string() + + " (gasLeft l2=" + std::to_string(trace_builder.get_l2_gas_left()) + ")"); + + switch (inst.op_code) { + // Compute + // Compute - Arithmetic + case OpCode::ADD_8: + error = trace_builder.op_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::ADD_8); + break; + case OpCode::ADD_16: + error = trace_builder.op_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::ADD_16); + break; + case OpCode::SUB_8: + error = trace_builder.op_sub(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SUB_8); + break; + case OpCode::SUB_16: + error = trace_builder.op_sub(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SUB_16); + break; + case OpCode::MUL_8: + error = trace_builder.op_mul(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::MUL_8); + break; + case OpCode::MUL_16: + error = trace_builder.op_mul(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::MUL_16); + break; + case OpCode::DIV_8: + error = trace_builder.op_div(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::DIV_8); + break; + case OpCode::DIV_16: + error = trace_builder.op_div(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::DIV_16); + break; + case OpCode::FDIV_8: + error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::FDIV_8); + break; + case OpCode::FDIV_16: + error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::FDIV_16); + break; + case OpCode::EQ_8: + error = trace_builder.op_eq(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::EQ_8); + break; + case OpCode::EQ_16: + error = trace_builder.op_eq(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::EQ_16); + break; + case OpCode::LT_8: + error = trace_builder.op_lt(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LT_8); + break; + case OpCode::LT_16: + error = trace_builder.op_lt(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LT_16); + break; + case OpCode::LTE_8: + error = trace_builder.op_lte(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LTE_8); + break; + case OpCode::LTE_16: + error = trace_builder.op_lte(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LTE_16); + break; + case OpCode::AND_8: + error = trace_builder.op_and(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::AND_8); + break; + case OpCode::AND_16: + error = trace_builder.op_and(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::AND_16); + break; + case OpCode::OR_8: + error = trace_builder.op_or(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::OR_8); + break; + case OpCode::OR_16: + error = trace_builder.op_or(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::OR_16); + break; + case OpCode::XOR_8: + error = trace_builder.op_xor(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::XOR_8); + break; + case OpCode::XOR_16: + error = trace_builder.op_xor(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::XOR_16); + break; + case OpCode::NOT_8: + error = trace_builder.op_not(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::NOT_8); + break; + case OpCode::NOT_16: + error = trace_builder.op_not(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::NOT_16); + break; + case OpCode::SHL_8: + error = trace_builder.op_shl(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHL_8); + break; + case OpCode::SHL_16: + error = trace_builder.op_shl(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHL_16); + break; + case OpCode::SHR_8: + error = trace_builder.op_shr(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHR_8); + break; + case OpCode::SHR_16: + error = trace_builder.op_shr(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHR_16); + break; + + // Compute - Type Conversions + case OpCode::CAST_8: + error = trace_builder.op_cast(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::CAST_8); + break; + case OpCode::CAST_16: + error = trace_builder.op_cast(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::CAST_16); + break; + + // Execution Environment + // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6284): support indirect for below + case OpCode::GETENVVAR_16: + error = trace_builder.op_get_env_var(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; + std::get(inst.operands.at(2))); + break; - // Machine State - Internal Control Flow - case OpCode::JUMP_32: - error = trace_builder.op_jump(std::get(inst.operands.at(0))); - break; - case OpCode::JUMPI_32: - error = trace_builder.op_jumpi(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::INTERNALCALL: - error = trace_builder.op_internal_call(std::get(inst.operands.at(0))); - break; - case OpCode::INTERNALRETURN: - error = trace_builder.op_internal_return(); - break; - - // Machine State - Memory - case OpCode::SET_8: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_8); - break; - } - case OpCode::SET_16: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_16); - break; - } - case OpCode::SET_32: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_32); - break; - } - case OpCode::SET_64: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_64); - break; - } - case OpCode::SET_128: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - uint256_t::from_uint128(std::get(inst.operands.at(3))), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_128); - break; - } - case OpCode::SET_FF: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_FF); - break; - } - case OpCode::MOV_8: - error = trace_builder.op_mov(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::MOV_8); - break; - case OpCode::MOV_16: - error = trace_builder.op_mov(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::MOV_16); - break; + // Execution Environment - Calldata + case OpCode::CALLDATACOPY: + error = trace_builder.op_calldata_copy(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; - // World State - case OpCode::SLOAD: - error = trace_builder.op_sload(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::SSTORE: - error = trace_builder.op_sstore(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::NOTEHASHEXISTS: - error = trace_builder.op_note_hash_exists(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - case OpCode::EMITNOTEHASH: - error = trace_builder.op_emit_note_hash(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; - case OpCode::NULLIFIEREXISTS: - error = trace_builder.op_nullifier_exists(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - case OpCode::EMITNULLIFIER: - error = trace_builder.op_emit_nullifier(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; + case OpCode::RETURNDATASIZE: + error = trace_builder.op_returndata_size(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; - case OpCode::L1TOL2MSGEXISTS: - error = trace_builder.op_l1_to_l2_msg_exists(std::get(inst.operands.at(0)), + case OpCode::RETURNDATACOPY: + error = trace_builder.op_returndata_copy(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2)), std::get(inst.operands.at(3))); - break; - case OpCode::GETCONTRACTINSTANCE: - error = trace_builder.op_get_contract_instance(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; - - // Accrued Substate - case OpCode::EMITUNENCRYPTEDLOG: - error = trace_builder.op_emit_unencrypted_log(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::SENDL2TOL1MSG: - error = trace_builder.op_emit_l2_to_l1_msg(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - - // Control Flow - Contract Calls - case OpCode::CALL: - error = trace_builder.op_call(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; - case OpCode::STATICCALL: - error = trace_builder.op_static_call(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; - case OpCode::RETURN: { - auto ret = trace_builder.op_return(std::get(inst.operands.at(0)), + break; + + // Machine State - Internal Control Flow + case OpCode::JUMP_32: + error = trace_builder.op_jump(std::get(inst.operands.at(0))); + break; + case OpCode::JUMPI_32: + error = trace_builder.op_jumpi(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - case OpCode::REVERT_8: { - info("HIT REVERT_8 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); - auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - case OpCode::REVERT_16: { - info("HIT REVERT_16 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); - auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), + std::get(inst.operands.at(2))); + break; + case OpCode::INTERNALCALL: + error = trace_builder.op_internal_call(std::get(inst.operands.at(0))); + break; + case OpCode::INTERNALRETURN: + error = trace_builder.op_internal_return(); + break; + + // Machine State - Memory + case OpCode::SET_8: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_8); + break; + } + case OpCode::SET_16: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_16); + break; + } + case OpCode::SET_32: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_32); + break; + } + case OpCode::SET_64: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_64); + break; + } + case OpCode::SET_128: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + uint256_t::from_uint128(std::get(inst.operands.at(3))), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_128); + break; + } + case OpCode::SET_FF: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_FF); + break; + } + case OpCode::MOV_8: + error = trace_builder.op_mov(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::MOV_8); + break; + case OpCode::MOV_16: + error = trace_builder.op_mov(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::MOV_16); + break; + + // World State + case OpCode::SLOAD: + error = trace_builder.op_sload(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - - // Misc - case OpCode::DEBUGLOG: - error = trace_builder.op_debug_log(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; - - // Gadgets - case OpCode::POSEIDON2PERM: - error = trace_builder.op_poseidon2_permutation(std::get(inst.operands.at(0)), + break; + case OpCode::SSTORE: + error = trace_builder.op_sstore(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::NOTEHASHEXISTS: + error = trace_builder.op_note_hash_exists(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + case OpCode::EMITNOTEHASH: + error = trace_builder.op_emit_note_hash(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; + case OpCode::NULLIFIEREXISTS: + error = trace_builder.op_nullifier_exists(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + case OpCode::EMITNULLIFIER: + error = trace_builder.op_emit_nullifier(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; + + case OpCode::L1TOL2MSGEXISTS: + error = trace_builder.op_l1_to_l2_msg_exists(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + case OpCode::GETCONTRACTINSTANCE: + error = trace_builder.op_get_contract_instance(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; + + // Accrued Substate + case OpCode::EMITUNENCRYPTEDLOG: + error = trace_builder.op_emit_unencrypted_log(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::SENDL2TOL1MSG: + error = trace_builder.op_emit_l2_to_l1_msg(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2))); + break; + + // Control Flow - Contract Calls + case OpCode::CALL: + error = trace_builder.op_call(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + case OpCode::STATICCALL: + error = trace_builder.op_static_call(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + case OpCode::RETURN: { + auto ret = trace_builder.op_return(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - break; - - case OpCode::SHA256COMPRESSION: - error = trace_builder.op_sha256_compression(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - - case OpCode::KECCAKF1600: - error = trace_builder.op_keccakf1600(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); + break; + } + case OpCode::REVERT_8: { + info("HIT REVERT_8 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); + auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); + + break; + } + case OpCode::REVERT_16: { + info("HIT REVERT_16 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); + auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - break; + break; + } - case OpCode::ECADD: - error = trace_builder.op_ec_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5)), - std::get(inst.operands.at(6)), - std::get(inst.operands.at(7))); - break; - case OpCode::MSM: - error = trace_builder.op_variable_msm(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; + // Misc + case OpCode::DEBUGLOG: + error = trace_builder.op_debug_log(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; + + // Gadgets + case OpCode::POSEIDON2PERM: + error = trace_builder.op_poseidon2_permutation(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + + break; + + case OpCode::SHA256COMPRESSION: + error = trace_builder.op_sha256_compression(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + + case OpCode::KECCAKF1600: + error = trace_builder.op_keccakf1600(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + + break; + + case OpCode::ECADD: + error = trace_builder.op_ec_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5)), + std::get(inst.operands.at(6)), + std::get(inst.operands.at(7))); + break; + case OpCode::MSM: + error = trace_builder.op_variable_msm(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; - // Conversions - case OpCode::TORADIXBE: - error = trace_builder.op_to_radix_be(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; + // Conversions + case OpCode::TORADIXBE: + error = trace_builder.op_to_radix_be(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + + default: + throw_or_abort("Don't know how to execute opcode " + to_hex(inst.op_code) + " at pc " + + std::to_string(pc) + "."); + break; + } + } - default: - throw_or_abort("Don't know how to execute opcode " + to_hex(inst.op_code) + " at pc " + - std::to_string(pc) + "."); + if (!is_ok(error)) { + auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body + std::string reason_prefix = exceptionally_halted(error) ? "exceptional halt" : "REVERT opcode"; + info("AVM enqueued call halted due to ", + reason_prefix, + ". Error: ", + to_name(error), + " at PC: ", + pc, + " IC: ", + error_ic); break; } } - if (!is_ok(error)) { - info("AVM stopped due to exceptional halting condition. Error: ", - to_name(error), - " at PC: ", - pc, - " IC: ", - counter - 1); // Need adjustement as counter increment occurs in loop body + info("Phase ", to_name(phase), " reverted."); + if (phase == TxExecutionPhase::SETUP) { + info("A revert during SETUP phase halts the entire TX"); + break; + } else { + info("Rolling back tree roots to non-revertible checkpoint"); + trace_builder.rollback_to_non_revertible_checkpoint(); + } } } - auto trace = trace_builder.finalize(); + auto trace = trace_builder.finalize(apply_end_gas_assertions); show_trace_info(trace); return trace; diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp index 840b5a690df1..6573e63ab297 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp @@ -13,6 +13,14 @@ namespace bb::avm_trace { +enum class TxExecutionPhase : uint32_t { + SETUP, + APP_LOGIC, + TEARDOWN, +}; + +std::string to_name(TxExecutionPhase phase); + class Execution { public: static constexpr size_t SRS_SIZE = 1 << 22; @@ -31,7 +39,8 @@ class Execution { // Eventually this will be the bytecode of the dispatch function of top-level contract static std::vector gen_trace(AvmPublicInputs const& public_inputs, std::vector& returndata, - ExecutionHints const& execution_hints); + ExecutionHints const& execution_hints, + bool apply_end_gas_assertions = false); // For testing purposes only. static void set_trace_builder_constructor(TraceBuilderConstructor constructor) diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp index c6dc8db7e029..4c549157fc7b 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp @@ -10,6 +10,15 @@ using Poseidon2 = crypto::Poseidon2; * UNCONSTRAINED TREE OPERATIONS **************************************************************************************************/ +void AvmMerkleTreeTraceBuilder::checkpoint_non_revertible_state() +{ + non_revertible_tree_snapshots = tree_snapshots.copy(); +} +void AvmMerkleTreeTraceBuilder::rollback_to_non_revertible_checkpoint() +{ + tree_snapshots = non_revertible_tree_snapshots; +} + FF AvmMerkleTreeTraceBuilder::unconstrained_hash_nullifier_preimage(const NullifierLeafPreimage& preimage) { return Poseidon2::hash({ preimage.nullifier, preimage.next_nullifier, preimage.next_index }); diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp index f6948337b9bc..7b67db893138 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp @@ -27,6 +27,9 @@ class AvmMerkleTreeTraceBuilder { void reset(); + void checkpoint_non_revertible_state(); + void rollback_to_non_revertible_checkpoint(); + bool check_membership( uint32_t clk, const FF& leaf_value, const uint64_t leaf_index, const std::vector& path, const FF& root); @@ -106,6 +109,7 @@ class AvmMerkleTreeTraceBuilder { private: std::vector merkle_check_trace; + TreeSnapshots non_revertible_tree_snapshots; TreeSnapshots tree_snapshots; MerkleEntry compute_root_from_path(uint32_t clk, const FF& leaf_value, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp index e40a90129d54..28a540f69aaa 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp @@ -105,6 +105,8 @@ std::string to_name(AvmError error) switch (error) { case AvmError::NO_ERROR: return "NO ERROR"; + case AvmError::REVERT_OPCODE: + return "REVERT OPCODE"; case AvmError::INVALID_PROGRAM_COUNTER: return "INVALID PROGRAM COUNTER"; case AvmError::INVALID_OPCODE: @@ -127,6 +129,10 @@ std::string to_name(AvmError error) return "CONTRACT INSTANCE MEMBER UNKNOWN"; case AvmError::RADIX_OUT_OF_BOUNDS: return "RADIX OUT OF BOUNDS"; + case AvmError::DUPLICATE_NULLIFIER: + return "DUPLICATE NULLIFIER"; + case AvmError::SIDE_EFFECT_LIMIT_REACHED: + return "SIDE EFFECT LIMIT REACHED"; default: throw std::runtime_error("Invalid error type"); break; @@ -138,6 +144,11 @@ bool is_ok(AvmError error) return error == AvmError::NO_ERROR; } +bool exceptionally_halted(AvmError error) +{ + return error != AvmError::NO_ERROR && error != AvmError::REVERT_OPCODE; +} + /** * * ONLY FOR TESTS - Required by dsl module and therefore cannot be moved to test/helpers.test.cpp diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp index 1f3b845c8e40..1ba7375276b4 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp @@ -235,6 +235,7 @@ std::string to_name(bb::avm_trace::AvmMemoryTag tag); std::string to_name(AvmError error); bool is_ok(AvmError error); +bool exceptionally_halted(AvmError error); // Mutate the inputs void inject_end_gas_values(AvmPublicInputs& public_inputs, std::vector& trace); diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp index a26e4c4f8182..c54ad53793a9 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp @@ -107,6 +107,27 @@ struct TreeSnapshots { return l1_to_l2_message_tree == rhs.l1_to_l2_message_tree && note_hash_tree == rhs.note_hash_tree && nullifier_tree == rhs.nullifier_tree && public_data_tree == rhs.public_data_tree; } + inline TreeSnapshots copy() + { + return { + .l1_to_l2_message_tree = { + .root = l1_to_l2_message_tree.root, + .size = l1_to_l2_message_tree.size, + }, + .note_hash_tree = { + .root = note_hash_tree.root, + .size = note_hash_tree.size, + }, + .nullifier_tree = { + .root = nullifier_tree.root, + .size = nullifier_tree.size, + }, + .public_data_tree = { + .root = public_data_tree.root, + .size = public_data_tree.size, + }, + }; + } }; inline void read(uint8_t const*& it, TreeSnapshots& tree_snapshots) diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp index 97c0e1b1c022..8181ffc2b7a4 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp @@ -37,6 +37,7 @@ #include "barretenberg/vm/avm/trace/opcode.hpp" #include "barretenberg/vm/avm/trace/public_inputs.hpp" #include "barretenberg/vm/avm/trace/trace.hpp" +#include "barretenberg/vm/aztec_constants.hpp" #include "barretenberg/vm/stats.hpp" namespace bb::avm_trace { @@ -135,6 +136,15 @@ bool check_tag_integral(AvmMemoryTag tag) * HELPERS **************************************************************************************************/ +void AvmTraceBuilder::checkpoint_non_revertible_state() +{ + merkle_tree_trace_builder.checkpoint_non_revertible_state(); +} +void AvmTraceBuilder::rollback_to_non_revertible_checkpoint() +{ + merkle_tree_trace_builder.rollback_to_non_revertible_checkpoint(); +} + void AvmTraceBuilder::insert_private_state(const std::vector& siloed_nullifiers, [[maybe_unused]] const std::vector& siloed_note_hashes) { @@ -2626,10 +2636,24 @@ AvmError AvmTraceBuilder::op_sload(uint8_t indirect, uint32_t slot_offset, uint3 AvmError AvmTraceBuilder::op_sstore(uint8_t indirect, uint32_t src_offset, uint32_t slot_offset) { // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto clk = static_cast(main_trace.size()) + 1; - // We keep the first encountered error - AvmError error = AvmError::NO_ERROR; + if (storage_write_counter >= MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_sstore = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::SSTORE); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::SSTORE); + return error; + } + auto [resolved_addrs, res_error] = Addressing<2>::fromWire(indirect, call_ptr).resolve({ src_offset, slot_offset }, mem_trace_builder); auto [resolved_src, resolved_slot] = resolved_addrs; @@ -2671,6 +2695,7 @@ AvmError AvmTraceBuilder::op_sstore(uint8_t indirect, uint32_t src_offset, uint3 .main_ind_addr_a = read_a.indirect_address, .main_internal_return_ptr = internal_return_ptr, .main_mem_addr_a = read_a.direct_address, // direct address incremented at end of the loop + .main_op_err = FF(static_cast(!is_ok(error))), .main_pc = pc, .main_r_in_tag = static_cast(AvmMemoryTag::FF), .main_sel_mem_op_a = 1, @@ -2789,8 +2814,24 @@ AvmError AvmTraceBuilder::op_emit_note_hash(uint8_t indirect, uint32_t note_hash { auto const clk = static_cast(main_trace.size()) + 1; + if (note_hash_write_counter >= MAX_NOTE_HASHES_PER_TX) { + AvmError error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_note_hash = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::EMITNOTEHASH); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITNOTEHASH); + return error; + } + auto [row, error] = create_kernel_output_opcode(indirect, clk, note_hash_offset); row.main_sel_op_emit_note_hash = FF(1); + row.main_op_err = FF(static_cast(!is_ok(error))); AppendTreeHint note_hash_write_hint = execution_hints.note_hash_write_hints.at(note_hash_write_counter++); auto siloed_note_hash = AvmMerkleTreeTraceBuilder::unconstrained_silo_note_hash( @@ -2923,33 +2964,56 @@ AvmError AvmTraceBuilder::op_nullifier_exists(uint8_t indirect, AvmError AvmTraceBuilder::op_emit_nullifier(uint8_t indirect, uint32_t nullifier_offset) { + // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto const clk = static_cast(main_trace.size()) + 1; - auto [row, error] = create_kernel_output_opcode(indirect, clk, nullifier_offset); + if (nullifier_write_counter >= MAX_NULLIFIERS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_nullifier = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::EMITNULLIFIER); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITNULLIFIER); + return error; + } + + auto [row, output_error] = create_kernel_output_opcode(indirect, clk, nullifier_offset); row.main_sel_op_emit_nullifier = FF(1); + if (is_ok(error)) { + error = output_error; + } // Do merkle check FF nullifier_value = row.main_ia; FF siloed_nullifier = AvmMerkleTreeTraceBuilder::unconstrained_silo_nullifier( current_public_call_request.contract_address, nullifier_value); - // This is a little bit fragile - but we use the fact that if we traced a nullifier that already exists (which is - // invalid), we would have stored it under a read hint. - NullifierReadTreeHint nullifier_read_hint = execution_hints.nullifier_read_hints.at(nullifier_read_counter); - bool is_update = merkle_tree_trace_builder.perform_nullifier_read(clk, - nullifier_read_hint.low_leaf_preimage, - nullifier_read_hint.low_leaf_index, - nullifier_read_hint.low_leaf_sibling_path); + NullifierWriteTreeHint nullifier_write_hint = execution_hints.nullifier_write_hints.at(nullifier_write_counter++); + bool is_update = siloed_nullifier == nullifier_write_hint.low_leaf_membership.low_leaf_preimage.next_nullifier; if (is_update) { - // If we are in this branch, then the nullifier already exists in the tree - // WE NEED TO RAISE AN ERROR FLAG HERE - for now we do nothing, except increment the counter - + // hinted low-leaf points to the target nullifier, so it already exists + // prove membership of that low-leaf, which also proves membership of the target nullifier + bool exists = merkle_tree_trace_builder.perform_nullifier_read( + clk, + nullifier_write_hint.low_leaf_membership.low_leaf_preimage, + nullifier_write_hint.low_leaf_membership.low_leaf_index, + nullifier_write_hint.low_leaf_membership.low_leaf_sibling_path); + // if hinted low-leaf that skips the nullifier fails membership check, bad hint! + ASSERT(exists); nullifier_read_counter++; - error = AvmError::DUPLICATE_NULLIFIER; + // Cannot update an existing nullifier, and cannot emit a duplicate. Error! + if (is_ok(error)) { + error = AvmError::DUPLICATE_NULLIFIER; + } } else { - // This is a non-membership proof which means our insertion is valid - NullifierWriteTreeHint nullifier_write_hint = - execution_hints.nullifier_write_hints.at(nullifier_write_counter++); + // hinted low-leaf SKIPS the target nullifier, so it does NOT exist + // prove membership of the low leaf which also proves non-membership of the target nullifier merkle_tree_trace_builder.perform_nullifier_append( clk, nullifier_write_hint.low_leaf_membership.low_leaf_preimage, @@ -2959,6 +3023,8 @@ AvmError AvmTraceBuilder::op_emit_nullifier(uint8_t indirect, uint32_t nullifier nullifier_write_hint.insertion_path); } + row.main_op_err = FF(static_cast(!is_ok(error))); + // Constrain gas cost gas_trace_builder.constrain_gas(clk, OpCode::EMITNULLIFIER); @@ -3228,6 +3294,26 @@ AvmError AvmTraceBuilder::op_emit_unencrypted_log(uint8_t indirect, uint32_t log }; } + // Can't return earlier as we do elsewhere for side-effect-limit because we need + // to at least retrieve log_size first to charge proper gas. + // This means a tag error could occur before side-effect-limit first. + if (is_ok(error) && unencrypted_log_write_counter >= MAX_UNENCRYPTED_LOGS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_unencrypted_log = FF(1), + }; + // Constrain gas cost + gas_trace_builder.constrain_gas(clk, OpCode::EMITUNENCRYPTEDLOG, static_cast(log_size)); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITUNENCRYPTEDLOG); + return error; + } + unencrypted_log_write_counter++; + if (is_ok(error)) { // We need to read the rest of the log_size number of elements for (uint32_t i = 0; i < log_size; i++) { @@ -3282,14 +3368,38 @@ AvmError AvmTraceBuilder::op_emit_unencrypted_log(uint8_t indirect, uint32_t log AvmError AvmTraceBuilder::op_emit_l2_to_l1_msg(uint8_t indirect, uint32_t recipient_offset, uint32_t content_offset) { + // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto const clk = static_cast(main_trace.size()) + 1; + if (l2_to_l1_msg_write_counter >= MAX_L2_TO_L1_MSGS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_l2_to_l1_msg = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::SENDL2TOL1MSG); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::SENDL2TOL1MSG); + return error; + } + l2_to_l1_msg_write_counter++; + // Note: unorthodox order - as seen in L2ToL1Message struct in TS - auto [row, error] = create_kernel_output_opcode_with_metadata( + auto [row, output_error] = create_kernel_output_opcode_with_metadata( indirect, clk, content_offset, AvmMemoryTag::FF, recipient_offset, AvmMemoryTag::FF); + + if (is_ok(error)) { + error = output_error; + } + // Wtite to output // kernel_trace_builder.op_emit_l2_to_l1_msg(clk, side_effect_counter, row.main_ia, row.main_ib); row.main_sel_op_emit_l2_to_l1_msg = FF(1); + row.main_op_err = FF(static_cast(!is_ok(error))); // Constrain gas cost gas_trace_builder.constrain_gas(clk, OpCode::SENDL2TOL1MSG); @@ -3610,6 +3720,10 @@ ReturnDataError AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset pc = UINT32_MAX; // This ensures that no subsequent opcode will be executed. + if (is_ok(error)) { + error = AvmError::REVERT_OPCODE; + } + // op_valid == true otherwise, ret_size == 0 and we would have returned above. return ReturnDataError{ .return_data = returndata, @@ -4327,7 +4441,7 @@ AvmError AvmTraceBuilder::op_to_radix_be(uint8_t indirect, * * @return The main trace */ -std::vector AvmTraceBuilder::finalize() +std::vector AvmTraceBuilder::finalize(bool apply_end_gas_assertions) { // Some sanity checks // Check that the final merkle tree lines up with the public inputs @@ -4596,13 +4710,15 @@ std::vector AvmTraceBuilder::finalize() gas_trace_builder.finalize(main_trace); - // Sanity check that the amount of gas consumed matches what we expect from the public inputs - auto last_l2_gas_remaining = main_trace.back().main_l2_gas_remaining; - auto expected_end_gas_l2 = public_inputs.gas_settings.gas_limits.l2_gas - public_inputs.end_gas_used.l2_gas; - ASSERT(last_l2_gas_remaining == expected_end_gas_l2); - auto last_da_gas_remaining = main_trace.back().main_da_gas_remaining; - auto expected_end_gas_da = public_inputs.gas_settings.gas_limits.da_gas - public_inputs.end_gas_used.da_gas; - ASSERT(last_da_gas_remaining == expected_end_gas_da); + if (apply_end_gas_assertions) { + // Sanity check that the amount of gas consumed matches what we expect from the public inputs + auto last_l2_gas_remaining = main_trace.back().main_l2_gas_remaining; + auto expected_end_gas_l2 = public_inputs.gas_settings.gas_limits.l2_gas - public_inputs.end_gas_used.l2_gas; + ASSERT(last_l2_gas_remaining == expected_end_gas_l2); + auto last_da_gas_remaining = main_trace.back().main_da_gas_remaining; + auto expected_end_gas_da = public_inputs.gas_settings.gas_limits.da_gas - public_inputs.end_gas_used.da_gas; + ASSERT(last_da_gas_remaining == expected_end_gas_da); + } /********************************************************************************************** * KERNEL TRACE INCLUSION diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp index 83ba657b1fcc..2be1f528073d 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp @@ -41,7 +41,7 @@ struct RowWithError { class AvmTraceBuilder { public: - AvmTraceBuilder(AvmPublicInputs public_inputs = {}, + AvmTraceBuilder(AvmPublicInputs public_inputs, ExecutionHints execution_hints = {}, uint32_t side_effect_counter = 0, std::vector calldata = {}); @@ -221,9 +221,11 @@ class AvmTraceBuilder { uint32_t num_limbs, uint8_t output_bits); - std::vector finalize(); + std::vector finalize(bool apply_end_gas_assertions = false); void reset(); + void checkpoint_non_revertible_state(); + void rollback_to_non_revertible_checkpoint(); void insert_private_state(const std::vector& siloed_nullifiers, const std::vector& siloed_note_hashes); // These are used for testing only. @@ -268,8 +270,10 @@ class AvmTraceBuilder { uint32_t nullifier_read_counter = 0; uint32_t nullifier_write_counter = 0; uint32_t l1_to_l2_msg_read_counter = 0; + uint32_t l2_to_l1_msg_write_counter = 0; uint32_t storage_read_counter = 0; uint32_t storage_write_counter = 0; + uint32_t unencrypted_log_write_counter = 0; // These exist due to testing only. bool range_check_required = true; diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index 93c07de02a27..78edfb1d6f4e 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -473,6 +473,42 @@ contract AvmTest { context.push_nullifier(nullifier); } + #[public] + fn n_storage_writes(num: u32) { + for i in 0..num { + context.push_nullifier(i as Field); + storage.map.at(AztecAddress::from_field(i as Field)).write(i); + } + } + + #[public] + fn n_new_note_hashes(num: u32) { + for i in 0..num { + context.push_note_hash(i as Field); + } + } + + #[public] + fn n_new_nullifiers(num: u32) { + for i in 0..num { + context.push_nullifier(i as Field); + } + } + + #[public] + fn n_new_l2_to_l1_msgs(num: u32) { + for i in 0..num { + context.message_portal(EthAddress::from_field(i as Field), i as Field) + } + } + + #[public] + fn n_new_unencrypted_logs(num: u32) { + for i in 0..num { + context.emit_unencrypted_log(/*message=*/ [i as Field]); + } + } + // Use the standard context interface to check for a nullifier #[public] fn nullifier_exists(nullifier: Field) -> bool { diff --git a/yarn-project/bb-prover/src/avm_proving.test.ts b/yarn-project/bb-prover/src/avm_proving.test.ts index 3e0ae84cf228..07024d57608f 100644 --- a/yarn-project/bb-prover/src/avm_proving.test.ts +++ b/yarn-project/bb-prover/src/avm_proving.test.ts @@ -1,4 +1,11 @@ -import { VerificationKeyData } from '@aztec/circuits.js'; +import { + MAX_L2_TO_L1_MSGS_PER_TX, + MAX_NOTE_HASHES_PER_TX, + MAX_NULLIFIERS_PER_TX, + MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, + MAX_UNENCRYPTED_LOGS_PER_TX, + VerificationKeyData, +} from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; import { simulateAvmTestContractGenerateCircuitInputs } from '@aztec/simulator/public/fixtures'; @@ -10,17 +17,78 @@ import path from 'path'; import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js'; import { extractAvmVkData } from './verification_key/verification_key_data.js'; +const TIMEOUT = 180_000; + describe('AVM WitGen, proof generation and verification', () => { - it('Should prove and verify bulk_testing', async () => { - await proveAndVerifyAvmTestContract( - 'bulk_testing', - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), - ); - }, 180_000); + it( + 'Should prove and verify bulk_testing', + async () => { + await proveAndVerifyAvmTestContract( + 'bulk_testing', + /*calldata=*/ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that performs too many storage writes and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_storage_writes', + /*calldata=*/ [new Fr(MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many note hashes and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_note_hashes', + /*calldata=*/ [new Fr(MAX_NOTE_HASHES_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many nullifiers and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_nullifiers', + /*calldata=*/ [new Fr(MAX_NULLIFIERS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many l2tol1 messages and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_l2_to_l1_msgs', + /*calldata=*/ [new Fr(MAX_L2_TO_L1_MSGS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many unencrypted logs and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_unencrypted_logs', + /*calldata=*/ [new Fr(MAX_UNENCRYPTED_LOGS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); }); -async function proveAndVerifyAvmTestContract(functionName: string, calldata: Fr[] = []) { - const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs(functionName, calldata); +async function proveAndVerifyAvmTestContract(functionName: string, calldata: Fr[] = [], expectRevert = false) { + const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs(functionName, calldata, expectRevert); const internalLogger = createDebugLogger('aztec:avm-proving-test'); const logger = (msg: string, _data?: any) => internalLogger.verbose(msg); diff --git a/yarn-project/bb-prover/src/test/index.ts b/yarn-project/bb-prover/src/test/index.ts index 555536e8cb7d..3f84ad27da1a 100644 --- a/yarn-project/bb-prover/src/test/index.ts +++ b/yarn-project/bb-prover/src/test/index.ts @@ -1,3 +1,2 @@ export * from './test_circuit_prover.js'; export * from './test_verifier.js'; -export * from './test_avm.js'; diff --git a/yarn-project/bb-prover/src/test/test_avm.ts b/yarn-project/bb-prover/src/test/test_avm.ts deleted file mode 100644 index 4cbac8bb1c4b..000000000000 --- a/yarn-project/bb-prover/src/test/test_avm.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { - AztecAddress, - ContractStorageRead, - ContractStorageUpdateRequest, - Gas, - GlobalVariables, - Header, - L2ToL1Message, - LogHash, - MAX_ENQUEUED_CALLS_PER_CALL, - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - MAX_L2_TO_L1_MSGS_PER_CALL, - MAX_NOTE_HASHES_PER_CALL, - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - MAX_NULLIFIERS_PER_CALL, - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - MAX_PUBLIC_DATA_READS_PER_CALL, - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - MAX_UNENCRYPTED_LOGS_PER_CALL, - NoteHash, - Nullifier, - PublicCircuitPublicInputs, - PublicInnerCallRequest, - ReadRequest, - RevertCode, - TreeLeafReadRequest, -} from '@aztec/circuits.js'; -import { computeVarArgsHash } from '@aztec/circuits.js/hash'; -import { padArrayEnd } from '@aztec/foundation/collection'; -import { type PublicFunctionCallResult } from '@aztec/simulator'; - -// TODO: pub somewhere more usable - copied from abstract phase manager -export function getPublicInputs(result: PublicFunctionCallResult): PublicCircuitPublicInputs { - return PublicCircuitPublicInputs.from({ - callContext: result.executionRequest.callContext, - proverAddress: AztecAddress.ZERO, - argsHash: computeVarArgsHash(result.executionRequest.args), - noteHashes: padArrayEnd(result.noteHashes, NoteHash.empty(), MAX_NOTE_HASHES_PER_CALL), - nullifiers: padArrayEnd(result.nullifiers, Nullifier.empty(), MAX_NULLIFIERS_PER_CALL), - l2ToL1Msgs: padArrayEnd(result.l2ToL1Messages, L2ToL1Message.empty(), MAX_L2_TO_L1_MSGS_PER_CALL), - startSideEffectCounter: result.startSideEffectCounter, - endSideEffectCounter: result.endSideEffectCounter, - returnsHash: computeVarArgsHash(result.returnValues), - noteHashReadRequests: padArrayEnd( - result.noteHashReadRequests, - TreeLeafReadRequest.empty(), - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - ), - nullifierReadRequests: padArrayEnd( - result.nullifierReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - ), - nullifierNonExistentReadRequests: padArrayEnd( - result.nullifierNonExistentReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - ), - l1ToL2MsgReadRequests: padArrayEnd( - result.l1ToL2MsgReadRequests, - TreeLeafReadRequest.empty(), - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - ), - contractStorageReads: padArrayEnd( - result.contractStorageReads, - ContractStorageRead.empty(), - MAX_PUBLIC_DATA_READS_PER_CALL, - ), - contractStorageUpdateRequests: padArrayEnd( - result.contractStorageUpdateRequests, - ContractStorageUpdateRequest.empty(), - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - ), - publicCallRequests: padArrayEnd([], PublicInnerCallRequest.empty(), MAX_ENQUEUED_CALLS_PER_CALL), - unencryptedLogsHashes: padArrayEnd(result.unencryptedLogsHashes, LogHash.empty(), MAX_UNENCRYPTED_LOGS_PER_CALL), - historicalHeader: Header.empty(), - globalVariables: GlobalVariables.empty(), - startGasLeft: Gas.from(result.startGasLeft), - endGasLeft: Gas.from(result.endGasLeft), - transactionFee: result.transactionFee, - // TODO(@just-mitch): need better mapping from simulator to revert code. - revertCode: result.reverted ? RevertCode.APP_LOGIC_REVERTED : RevertCode.OK, - }); -} diff --git a/yarn-project/circuits.js/src/structs/avm/avm.ts b/yarn-project/circuits.js/src/structs/avm/avm.ts index ab16da8ce37d..e50308e2b73e 100644 --- a/yarn-project/circuits.js/src/structs/avm/avm.ts +++ b/yarn-project/circuits.js/src/structs/avm/avm.ts @@ -851,13 +851,13 @@ export class AvmExecutionHints { public readonly contractInstances: Vector; public readonly contractBytecodeHints: Vector; - public readonly storageReadRequest: Vector; - public readonly storageUpdateRequest: Vector; - public readonly nullifierReadRequest: Vector; - public readonly nullifierWriteHints: Vector; - public readonly noteHashReadRequest: Vector; - public readonly noteHashWriteRequest: Vector; - public readonly l1ToL2MessageReadRequest: Vector; + public readonly publicDataReads: Vector; + public readonly publicDataWrites: Vector; + public readonly nullifierReads: Vector; + public readonly nullifierWrites: Vector; + public readonly noteHashReads: Vector; + public readonly noteHashWrites: Vector; + public readonly l1ToL2MessageReads: Vector; constructor( enqueuedCalls: AvmEnqueuedCallHint[], @@ -868,13 +868,13 @@ export class AvmExecutionHints { externalCalls: AvmExternalCallHint[], contractInstances: AvmContractInstanceHint[], contractBytecodeHints: AvmContractBytecodeHints[], - storageReadRequest: AvmPublicDataReadTreeHint[], - storageUpdateRequest: AvmPublicDataWriteTreeHint[], - nullifierReadRequest: AvmNullifierReadTreeHint[], - nullifierWriteHints: AvmNullifierWriteTreeHint[], - noteHashReadRequest: AvmAppendTreeHint[], - noteHashWriteRequest: AvmAppendTreeHint[], - l1ToL2MessageReadRequest: AvmAppendTreeHint[], + publicDataReads: AvmPublicDataReadTreeHint[], + publicDataWrites: AvmPublicDataWriteTreeHint[], + nullifierReads: AvmNullifierReadTreeHint[], + nullifierWrites: AvmNullifierWriteTreeHint[], + noteHashReads: AvmAppendTreeHint[], + noteHashWrites: AvmAppendTreeHint[], + l1ToL2MessageReads: AvmAppendTreeHint[], ) { this.enqueuedCalls = new Vector(enqueuedCalls); this.storageValues = new Vector(storageValues); @@ -884,14 +884,13 @@ export class AvmExecutionHints { this.externalCalls = new Vector(externalCalls); this.contractInstances = new Vector(contractInstances); this.contractBytecodeHints = new Vector(contractBytecodeHints); - this.storageReadRequest = new Vector(storageReadRequest); - this.storageUpdateRequest = new Vector(storageUpdateRequest); - this.noteHashReadRequest = new Vector(noteHashReadRequest); - this.nullifierReadRequest = new Vector(nullifierReadRequest); - this.nullifierWriteHints = new Vector(nullifierWriteHints); - this.noteHashReadRequest = new Vector(noteHashReadRequest); - this.noteHashWriteRequest = new Vector(noteHashWriteRequest); - this.l1ToL2MessageReadRequest = new Vector(l1ToL2MessageReadRequest); + this.publicDataReads = new Vector(publicDataReads); + this.publicDataWrites = new Vector(publicDataWrites); + this.nullifierReads = new Vector(nullifierReads); + this.nullifierWrites = new Vector(nullifierWrites); + this.noteHashReads = new Vector(noteHashReads); + this.noteHashWrites = new Vector(noteHashWrites); + this.l1ToL2MessageReads = new Vector(l1ToL2MessageReads); } /** @@ -932,13 +931,13 @@ export class AvmExecutionHints { this.externalCalls.items.length == 0 && this.contractInstances.items.length == 0 && this.contractBytecodeHints.items.length == 0 && - this.storageReadRequest.items.length == 0 && - this.storageUpdateRequest.items.length == 0 && - this.nullifierReadRequest.items.length == 0 && - this.nullifierWriteHints.items.length == 0 && - this.noteHashReadRequest.items.length == 0 && - this.noteHashWriteRequest.items.length == 0 && - this.l1ToL2MessageReadRequest.items.length == 0 + this.publicDataReads.items.length == 0 && + this.publicDataWrites.items.length == 0 && + this.nullifierReads.items.length == 0 && + this.nullifierWrites.items.length == 0 && + this.noteHashReads.items.length == 0 && + this.noteHashWrites.items.length == 0 && + this.l1ToL2MessageReads.items.length == 0 ); } @@ -957,13 +956,13 @@ export class AvmExecutionHints { fields.externalCalls.items, fields.contractInstances.items, fields.contractBytecodeHints.items, - fields.storageReadRequest.items, - fields.storageUpdateRequest.items, - fields.nullifierReadRequest.items, - fields.nullifierWriteHints.items, - fields.noteHashReadRequest.items, - fields.noteHashWriteRequest.items, - fields.l1ToL2MessageReadRequest.items, + fields.publicDataReads.items, + fields.publicDataWrites.items, + fields.nullifierReads.items, + fields.nullifierWrites.items, + fields.noteHashReads.items, + fields.noteHashWrites.items, + fields.l1ToL2MessageReads.items, ); } @@ -982,13 +981,13 @@ export class AvmExecutionHints { fields.externalCalls, fields.contractInstances, fields.contractBytecodeHints, - fields.storageReadRequest, - fields.storageUpdateRequest, - fields.nullifierReadRequest, - fields.nullifierWriteHints, - fields.noteHashReadRequest, - fields.noteHashWriteRequest, - fields.l1ToL2MessageReadRequest, + fields.publicDataReads, + fields.publicDataWrites, + fields.nullifierReads, + fields.nullifierWrites, + fields.noteHashReads, + fields.noteHashWrites, + fields.l1ToL2MessageReads, ] as const; } diff --git a/yarn-project/circuits.js/src/tests/factories.ts b/yarn-project/circuits.js/src/tests/factories.ts index 1fce585ae3bc..7e5806736db8 100644 --- a/yarn-project/circuits.js/src/tests/factories.ts +++ b/yarn-project/circuits.js/src/tests/factories.ts @@ -1404,13 +1404,13 @@ export function makeAvmExecutionHints( externalCalls: makeVector(baseLength + 4, makeAvmExternalCallHint, seed + 0x4600), contractInstances: makeVector(baseLength + 5, makeAvmContractInstanceHint, seed + 0x4700), contractBytecodeHints: makeVector(baseLength + 6, makeAvmBytecodeHints, seed + 0x4800), - storageReadRequest: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900), - storageUpdateRequest: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00), - nullifierReadRequest: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00), - nullifierWriteHints: makeVector(baseLength + 10, makeAvmNullifierInsertionTreeHints, seed + 0x4c00), - noteHashReadRequest: makeVector(baseLength + 11, makeAvmTreeHints, seed + 0x4d00), - noteHashWriteRequest: makeVector(baseLength + 12, makeAvmTreeHints, seed + 0x4e00), - l1ToL2MessageReadRequest: makeVector(baseLength + 13, makeAvmTreeHints, seed + 0x4f00), + publicDataReads: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900), + publicDataWrites: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00), + nullifierReads: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00), + nullifierWrites: makeVector(baseLength + 10, makeAvmNullifierInsertionTreeHints, seed + 0x4c00), + noteHashReads: makeVector(baseLength + 11, makeAvmTreeHints, seed + 0x4d00), + noteHashWrites: makeVector(baseLength + 12, makeAvmTreeHints, seed + 0x4e00), + l1ToL2MessageReads: makeVector(baseLength + 13, makeAvmTreeHints, seed + 0x4f00), ...overrides, }); } diff --git a/yarn-project/simulator/src/avm/journal/journal.ts b/yarn-project/simulator/src/avm/journal/journal.ts index 9a3ffa5273a3..15ad24709d5b 100644 --- a/yarn-project/simulator/src/avm/journal/journal.ts +++ b/yarn-project/simulator/src/avm/journal/journal.ts @@ -379,6 +379,11 @@ export class AvmPersistableStateManager { // Cache pending nullifiers for later access await this.nullifiers.append(siloedNullifier); // We append the new nullifier + this.log.debug( + `Nullifier tree root before insertion ${this.merkleTrees.treeMap + .get(MerkleTreeId.NULLIFIER_TREE)! + .getRoot()}`, + ); const appendResult = await this.merkleTrees.appendNullifier(siloedNullifier); this.log.debug( `Nullifier tree root after insertion ${this.merkleTrees.treeMap.get(MerkleTreeId.NULLIFIER_TREE)!.getRoot()}`, diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts index 6f84f4de2adb..d21f38dee710 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts @@ -59,7 +59,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getCounter()).toBe(startCounterPlus1); const expected = new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, siblingPath); - expect(trace.getAvmCircuitHints().storageReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().publicDataReads.items).toEqual([expected]); }); it('Should trace storage writes', () => { @@ -84,14 +84,14 @@ describe('Enqueued-call Side Effect Trace', () => { const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath); const expectedHint = new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, siblingPath); - expect(trace.getAvmCircuitHints().storageUpdateRequest.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().publicDataWrites.items).toEqual([expectedHint]); }); it('Should trace note hash checks', () => { const exists = true; trace.traceNoteHashCheck(address, utxo, leafIndex, exists, siblingPath); const expected = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().noteHashReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().noteHashReads.items).toEqual([expected]); }); it('Should trace note hashes', () => { @@ -102,7 +102,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getSideEffects().noteHashes).toEqual(expected); const expectedHint = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().noteHashWriteRequest.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().noteHashWrites.items).toEqual([expectedHint]); }); it('Should trace nullifier checks', () => { @@ -112,7 +112,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getCounter()).toBe(startCounterPlus1); const expected = new AvmNullifierReadTreeHint(lowLeafPreimage, leafIndex, siblingPath); - expect(trace.getAvmCircuitHints().nullifierReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().nullifierReads.items).toEqual([expected]); }); it('Should trace nullifiers', () => { @@ -125,14 +125,14 @@ describe('Enqueued-call Side Effect Trace', () => { const readHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath); const expectedHint = new AvmNullifierWriteTreeHint(readHint, siblingPath); - expect(trace.getAvmCircuitHints().nullifierWriteHints.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().nullifierWrites.items).toEqual([expectedHint]); }); it('Should trace L1ToL2 Message checks', () => { const exists = true; trace.traceL1ToL2MessageCheck(address, utxo, leafIndex, exists, siblingPath); const expected = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().l1ToL2MessageReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().l1ToL2MessageReads.items).toEqual([expected]); }); it('Should trace new L2ToL1 messages', () => { @@ -321,13 +321,13 @@ describe('Enqueued-call Side Effect Trace', () => { expect(parentHints.externalCalls.items).toEqual(childHints.externalCalls.items); expect(parentHints.contractInstances.items).toEqual(childHints.contractInstances.items); expect(parentHints.contractBytecodeHints.items).toEqual(childHints.contractBytecodeHints.items); - expect(parentHints.storageReadRequest.items).toEqual(childHints.storageReadRequest.items); - expect(parentHints.storageUpdateRequest.items).toEqual(childHints.storageUpdateRequest.items); - expect(parentHints.nullifierReadRequest.items).toEqual(childHints.nullifierReadRequest.items); - expect(parentHints.nullifierWriteHints.items).toEqual(childHints.nullifierWriteHints.items); - expect(parentHints.noteHashReadRequest.items).toEqual(childHints.noteHashReadRequest.items); - expect(parentHints.noteHashWriteRequest.items).toEqual(childHints.noteHashWriteRequest.items); - expect(parentHints.l1ToL2MessageReadRequest.items).toEqual(childHints.l1ToL2MessageReadRequest.items); + expect(parentHints.publicDataReads.items).toEqual(childHints.publicDataReads.items); + expect(parentHints.publicDataWrites.items).toEqual(childHints.publicDataWrites.items); + expect(parentHints.nullifierReads.items).toEqual(childHints.nullifierReads.items); + expect(parentHints.nullifierWrites.items).toEqual(childHints.nullifierWrites.items); + expect(parentHints.noteHashReads.items).toEqual(childHints.noteHashReads.items); + expect(parentHints.noteHashWrites.items).toEqual(childHints.noteHashWrites.items); + expect(parentHints.l1ToL2MessageReads.items).toEqual(childHints.l1ToL2MessageReads.items); }); }); }); diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts index 84e85adcd640..a7e24ac55208 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts @@ -179,15 +179,13 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI this.avmCircuitHints.contractInstances.items.push(...forkedTrace.avmCircuitHints.contractInstances.items); this.avmCircuitHints.contractBytecodeHints.items.push(...forkedTrace.avmCircuitHints.contractBytecodeHints.items); - this.avmCircuitHints.storageReadRequest.items.push(...forkedTrace.avmCircuitHints.storageReadRequest.items); - this.avmCircuitHints.storageUpdateRequest.items.push(...forkedTrace.avmCircuitHints.storageUpdateRequest.items); - this.avmCircuitHints.nullifierReadRequest.items.push(...forkedTrace.avmCircuitHints.nullifierReadRequest.items); - this.avmCircuitHints.nullifierWriteHints.items.push(...forkedTrace.avmCircuitHints.nullifierWriteHints.items); - this.avmCircuitHints.noteHashReadRequest.items.push(...forkedTrace.avmCircuitHints.noteHashReadRequest.items); - this.avmCircuitHints.noteHashWriteRequest.items.push(...forkedTrace.avmCircuitHints.noteHashWriteRequest.items); - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push( - ...forkedTrace.avmCircuitHints.l1ToL2MessageReadRequest.items, - ); + this.avmCircuitHints.publicDataReads.items.push(...forkedTrace.avmCircuitHints.publicDataReads.items); + this.avmCircuitHints.publicDataWrites.items.push(...forkedTrace.avmCircuitHints.publicDataWrites.items); + this.avmCircuitHints.nullifierReads.items.push(...forkedTrace.avmCircuitHints.nullifierReads.items); + this.avmCircuitHints.nullifierWrites.items.push(...forkedTrace.avmCircuitHints.nullifierWrites.items); + this.avmCircuitHints.noteHashReads.items.push(...forkedTrace.avmCircuitHints.noteHashReads.items); + this.avmCircuitHints.noteHashWrites.items.push(...forkedTrace.avmCircuitHints.noteHashWrites.items); + this.avmCircuitHints.l1ToL2MessageReads.items.push(...forkedTrace.avmCircuitHints.l1ToL2MessageReads.items); } public getCounter() { @@ -211,7 +209,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI assert(leafPreimage.value.equals(value), 'Value mismatch when tracing in public data write'); } - this.avmCircuitHints.storageReadRequest.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); + this.avmCircuitHints.publicDataReads.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); this.log.debug(`SLOAD cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); this.incrementSideEffectCounter(); } @@ -245,7 +243,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI // New hinting const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.storageUpdateRequest.items.push( + this.avmCircuitHints.publicDataWrites.items.push( new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, insertionPath), ); @@ -264,7 +262,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI path: Fr[] = emptyNoteHashPath(), ) { // New Hinting - this.avmCircuitHints.noteHashReadRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashReads.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); // NOTE: counter does not increment for note hash checks (because it doesn't rely on pending note hashes) } @@ -282,7 +280,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI //const siloedNoteHash = siloNoteHash(contractAddress, noteHash); this.noteHashes.push(new NoteHash(noteHash, this.sideEffectCounter).scope(contractAddress)); this.log.debug(`NEW_NOTE_HASH cnt: ${this.sideEffectCounter}`); - this.avmCircuitHints.noteHashWriteRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashWrites.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); this.incrementSideEffectCounter(); } @@ -293,7 +291,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI lowLeafIndex: Fr = Fr.zero(), lowLeafPath: Fr[] = emptyNullifierPath(), ) { - this.avmCircuitHints.nullifierReadRequest.items.push( + this.avmCircuitHints.nullifierReads.items.push( new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath), ); this.log.debug(`NULLIFIER_EXISTS cnt: ${this.sideEffectCounter}`); @@ -314,7 +312,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI this.nullifiers.push(new Nullifier(siloedNullifier, this.sideEffectCounter, /*noteHash=*/ Fr.ZERO)); const lowLeafReadHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.nullifierWriteHints.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); + this.avmCircuitHints.nullifierWrites.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); this.log.debug(`NEW_NULLIFIER cnt: ${this.sideEffectCounter}`); this.incrementSideEffectCounter(); } @@ -327,7 +325,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI _exists: boolean, path: Fr[] = emptyL1ToL2MessagePath(), ) { - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); + this.avmCircuitHints.l1ToL2MessageReads.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); } public traceNewL2ToL1Message(contractAddress: AztecAddress, recipient: Fr, content: Fr) { diff --git a/yarn-project/simulator/src/public/fixtures/index.ts b/yarn-project/simulator/src/public/fixtures/index.ts index 512cbf93d30b..68eec22d6e13 100644 --- a/yarn-project/simulator/src/public/fixtures/index.ts +++ b/yarn-project/simulator/src/public/fixtures/index.ts @@ -34,14 +34,10 @@ import { MerkleTrees } from '@aztec/world-state'; import { strict as assert } from 'assert'; -/** - * If assertionErrString is set, we expect a (non exceptional halting) revert due to a failing assertion and - * we check that the revert reason error contains this string. However, the circuit must correctly prove the - * execution. - */ export async function simulateAvmTestContractGenerateCircuitInputs( functionName: string, calldata: Fr[] = [], + expectRevert: boolean = false, assertionErrString?: string, ): Promise { const sender = AztecAddress.random(); @@ -80,13 +76,15 @@ export async function simulateAvmTestContractGenerateCircuitInputs( const avmResult = await simulator.simulate(tx); - if (assertionErrString == undefined) { + if (!expectRevert) { expect(avmResult.revertCode.isOK()).toBe(true); } else { // Explicit revert when an assertion failed. expect(avmResult.revertCode.isOK()).toBe(false); expect(avmResult.revertReason).toBeDefined(); - expect(avmResult.revertReason?.getMessage()).toContain(assertionErrString); + if (assertionErrString !== undefined) { + expect(avmResult.revertReason?.getMessage()).toContain(assertionErrString); + } } const avmCircuitInputs: AvmCircuitInputs = avmResult.avmProvingRequest.inputs; diff --git a/yarn-project/simulator/src/public/public_tx_context.ts b/yarn-project/simulator/src/public/public_tx_context.ts index 13ab1ad6d8c2..94057597a186 100644 --- a/yarn-project/simulator/src/public/public_tx_context.ts +++ b/yarn-project/simulator/src/public/public_tx_context.ts @@ -91,7 +91,7 @@ export class PublicTxContext { const previousAccumulatedDataArrayLengths = new SideEffectArrayLengths( /*publicDataWrites*/ 0, countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.noteHashes), - countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.nullifiers), + /*nullifiers=*/ 0, countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.l2ToL1Msgs), /*unencryptedLogsHashes*/ 0, ); diff --git a/yarn-project/simulator/src/public/side_effect_trace.ts b/yarn-project/simulator/src/public/side_effect_trace.ts index 474e3ff155dd..8e9f93256d07 100644 --- a/yarn-project/simulator/src/public/side_effect_trace.ts +++ b/yarn-project/simulator/src/public/side_effect_trace.ts @@ -138,7 +138,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New hinting - this.avmCircuitHints.storageReadRequest.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); + this.avmCircuitHints.publicDataReads.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); this.log.debug(`SLOAD cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); this.incrementSideEffectCounter(); @@ -168,7 +168,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { // New hinting const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.storageUpdateRequest.items.push( + this.avmCircuitHints.publicDataWrites.items.push( new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, insertionPath), ); this.log.debug(`SSTORE cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); @@ -193,7 +193,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { new AvmKeyValueHint(/*key=*/ new Fr(leafIndex), /*value=*/ exists ? Fr.ONE : Fr.ZERO), ); // New Hinting - this.avmCircuitHints.noteHashReadRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashReads.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); // NOTE: counter does not increment for note hash checks (because it doesn't rely on pending note hashes) } @@ -210,7 +210,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { this.log.debug(`NEW_NOTE_HASH cnt: ${this.sideEffectCounter}`); // New Hinting - this.avmCircuitHints.noteHashWriteRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashWrites.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); this.incrementSideEffectCounter(); } @@ -237,7 +237,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New Hints - this.avmCircuitHints.nullifierReadRequest.items.push( + this.avmCircuitHints.nullifierReads.items.push( new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath), ); this.log.debug(`NULLIFIER_EXISTS cnt: ${this.sideEffectCounter}`); @@ -259,7 +259,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { this.nullifiers.push(new Nullifier(siloedNullifier, this.sideEffectCounter, /*noteHash=*/ Fr.ZERO)); // New hinting const lowLeafReadHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.nullifierWriteHints.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); + this.avmCircuitHints.nullifierWrites.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); this.log.debug(`NEW_NULLIFIER cnt: ${this.sideEffectCounter}`); this.incrementSideEffectCounter(); } @@ -282,7 +282,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New Hinting - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); + this.avmCircuitHints.l1ToL2MessageReads.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); // NOTE: counter does not increment for l1tol2 message checks (because it doesn't rely on pending messages) } diff --git a/yarn-project/simulator/src/public/transitional_adapters.ts b/yarn-project/simulator/src/public/transitional_adapters.ts index 398e530e3fd3..09ec0094110f 100644 --- a/yarn-project/simulator/src/public/transitional_adapters.ts +++ b/yarn-project/simulator/src/public/transitional_adapters.ts @@ -1,57 +1,28 @@ -import { type AvmProvingRequest, ProvingRequestType, type PublicExecutionRequest } from '@aztec/circuit-types'; import { - AvmCircuitInputs, - AvmCircuitPublicInputs, - AztecAddress, - ContractStorageRead, - ContractStorageUpdateRequest, - Fr, - Gas, + type AvmCircuitPublicInputs, + type Fr, + type Gas, type GasSettings, type GlobalVariables, - type Header, - L2ToL1Message, - LogHash, - MAX_ENQUEUED_CALLS_PER_CALL, - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - MAX_L2_TO_L1_MSGS_PER_CALL, MAX_L2_TO_L1_MSGS_PER_TX, - MAX_NOTE_HASHES_PER_CALL, MAX_NOTE_HASHES_PER_TX, - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - MAX_NULLIFIERS_PER_CALL, - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - MAX_PUBLIC_DATA_READS_PER_CALL, - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, - MAX_UNENCRYPTED_LOGS_PER_CALL, - NoteHash, - Nullifier, PrivateToAvmAccumulatedData, PrivateToAvmAccumulatedDataArrayLengths, type PrivateToPublicAccumulatedData, PublicCallRequest, - PublicCircuitPublicInputs, PublicDataWrite, - PublicInnerCallRequest, - ReadRequest, - RevertCode, + type RevertCode, type StateReference, - TreeLeafReadRequest, TreeSnapshots, countAccumulatedItems, mergeAccumulatedData, } from '@aztec/circuits.js'; -import { computeNoteHashNonce, computeUniqueNoteHash, computeVarArgsHash, siloNoteHash } from '@aztec/circuits.js/hash'; +import { computeNoteHashNonce, computeUniqueNoteHash, siloNoteHash } from '@aztec/circuits.js/hash'; import { padArrayEnd } from '@aztec/foundation/collection'; import { assertLength } from '@aztec/foundation/serialize'; -import { AvmFinalizedCallResult } from '../avm/avm_contract_call_result.js'; -import { AvmExecutionEnvironment } from '../avm/avm_execution_environment.js'; -import { type AvmPersistableStateManager } from '../avm/journal/journal.js'; import { type PublicEnqueuedCallSideEffectTrace } from './enqueued_call_side_effect_trace.js'; -import { type EnqueuedPublicCallExecutionResult, type PublicFunctionCallResult } from './execution.js'; export function generateAvmCircuitPublicInputs( trace: PublicEnqueuedCallSideEffectTrace, @@ -176,155 +147,3 @@ export function generateAvmCircuitPublicInputs( //console.log(`AvmCircuitPublicInputs:\n${inspect(avmCircuitPublicInputs)}`); return avmCircuitPublicInputs; } - -export function generateAvmProvingRequest( - real: boolean, - fnName: string, - stateManager: AvmPersistableStateManager, - historicalHeader: Header, - globalVariables: GlobalVariables, - executionRequest: PublicExecutionRequest, - result: EnqueuedPublicCallExecutionResult, - allocatedGas: Gas, - transactionFee: Fr, -): AvmProvingRequest { - const avmExecutionEnv = new AvmExecutionEnvironment( - executionRequest.callContext.contractAddress, - executionRequest.callContext.msgSender, - executionRequest.callContext.functionSelector, - /*contractCallDepth=*/ Fr.zero(), - transactionFee, - globalVariables, - executionRequest.callContext.isStaticCall, - executionRequest.args, - ); - - const avmCallResult = new AvmFinalizedCallResult(result.reverted, result.returnValues, result.endGasLeft); - - // Generate an AVM proving request - let avmProvingRequest: AvmProvingRequest; - if (real) { - const deprecatedFunctionCallResult = stateManager.trace.toPublicFunctionCallResult( - avmExecutionEnv, - /*startGasLeft=*/ allocatedGas, - Buffer.alloc(0), - avmCallResult, - fnName, - ); - const publicInputs = getPublicCircuitPublicInputs(historicalHeader, globalVariables, deprecatedFunctionCallResult); - avmProvingRequest = makeAvmProvingRequest(publicInputs, deprecatedFunctionCallResult); - } else { - avmProvingRequest = emptyAvmProvingRequest(); - } - return avmProvingRequest; -} - -function emptyAvmProvingRequest(): AvmProvingRequest { - return { - type: ProvingRequestType.PUBLIC_VM, - inputs: AvmCircuitInputs.empty(), - }; -} -function makeAvmProvingRequest(inputs: PublicCircuitPublicInputs, result: PublicFunctionCallResult): AvmProvingRequest { - return { - type: ProvingRequestType.PUBLIC_VM, - inputs: new AvmCircuitInputs( - result.functionName, - result.calldata, - inputs, - result.avmCircuitHints, - AvmCircuitPublicInputs.empty(), - ), - }; -} - -function getPublicCircuitPublicInputs( - historicalHeader: Header, - globalVariables: GlobalVariables, - result: PublicFunctionCallResult, -) { - const header = historicalHeader.clone(); // don't modify the original - header.state.partial.publicDataTree.root = Fr.zero(); // AVM doesn't check this yet - - return PublicCircuitPublicInputs.from({ - callContext: result.executionRequest.callContext, - proverAddress: AztecAddress.ZERO, - argsHash: computeVarArgsHash(result.executionRequest.args), - noteHashes: padArrayEnd( - result.noteHashes, - NoteHash.empty(), - MAX_NOTE_HASHES_PER_CALL, - `Too many note hashes. Got ${result.noteHashes.length} with max being ${MAX_NOTE_HASHES_PER_CALL}`, - ), - nullifiers: padArrayEnd( - result.nullifiers, - Nullifier.empty(), - MAX_NULLIFIERS_PER_CALL, - `Too many nullifiers. Got ${result.nullifiers.length} with max being ${MAX_NULLIFIERS_PER_CALL}`, - ), - l2ToL1Msgs: padArrayEnd( - result.l2ToL1Messages, - L2ToL1Message.empty(), - MAX_L2_TO_L1_MSGS_PER_CALL, - `Too many L2 to L1 messages. Got ${result.l2ToL1Messages.length} with max being ${MAX_L2_TO_L1_MSGS_PER_CALL}`, - ), - startSideEffectCounter: result.startSideEffectCounter, - endSideEffectCounter: result.endSideEffectCounter, - returnsHash: computeVarArgsHash(result.returnValues), - noteHashReadRequests: padArrayEnd( - result.noteHashReadRequests, - TreeLeafReadRequest.empty(), - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - `Too many note hash read requests. Got ${result.noteHashReadRequests.length} with max being ${MAX_NOTE_HASH_READ_REQUESTS_PER_CALL}`, - ), - nullifierReadRequests: padArrayEnd( - result.nullifierReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - `Too many nullifier read requests. Got ${result.nullifierReadRequests.length} with max being ${MAX_NULLIFIER_READ_REQUESTS_PER_CALL}`, - ), - nullifierNonExistentReadRequests: padArrayEnd( - result.nullifierNonExistentReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - `Too many nullifier non-existent read requests. Got ${result.nullifierNonExistentReadRequests.length} with max being ${MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL}`, - ), - l1ToL2MsgReadRequests: padArrayEnd( - result.l1ToL2MsgReadRequests, - TreeLeafReadRequest.empty(), - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - `Too many L1 to L2 message read requests. Got ${result.l1ToL2MsgReadRequests.length} with max being ${MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL}`, - ), - contractStorageReads: padArrayEnd( - result.contractStorageReads, - ContractStorageRead.empty(), - MAX_PUBLIC_DATA_READS_PER_CALL, - `Too many public data reads. Got ${result.contractStorageReads.length} with max being ${MAX_PUBLIC_DATA_READS_PER_CALL}`, - ), - contractStorageUpdateRequests: padArrayEnd( - result.contractStorageUpdateRequests, - ContractStorageUpdateRequest.empty(), - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - `Too many public data update requests. Got ${result.contractStorageUpdateRequests.length} with max being ${MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL}`, - ), - publicCallRequests: padArrayEnd( - result.publicCallRequests, - PublicInnerCallRequest.empty(), - MAX_ENQUEUED_CALLS_PER_CALL, - `Too many public call requests. Got ${result.publicCallRequests.length} with max being ${MAX_ENQUEUED_CALLS_PER_CALL}`, - ), - unencryptedLogsHashes: padArrayEnd( - result.unencryptedLogsHashes, - LogHash.empty(), - MAX_UNENCRYPTED_LOGS_PER_CALL, - `Too many unencrypted logs. Got ${result.unencryptedLogsHashes.length} with max being ${MAX_UNENCRYPTED_LOGS_PER_CALL}`, - ), - historicalHeader: header, - globalVariables: globalVariables, - startGasLeft: Gas.from(result.startGasLeft), - endGasLeft: Gas.from(result.endGasLeft), - transactionFee: result.transactionFee, - // TODO(@just-mitch): need better mapping from simulator to revert code. - revertCode: result.reverted ? RevertCode.APP_LOGIC_REVERTED : RevertCode.OK, - }); -}