Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(avm): re-enable proof in some unit tests #6056

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,17 +556,14 @@ void avm_prove(const std::filesystem::path& bytecode_path,
bool avm_verify(const std::filesystem::path& proof_path)
{
std::filesystem::path vk_path = proof_path.parent_path() / "vk";

// Actual verification temporarily stopped (#4954)
// std::vector<fr> const proof = many_from_buffer<fr>(read_file(proof_path));
//
// std::vector<uint8_t> vk_bytes = read_file(vk_path);
// auto circuit_size = from_buffer<size_t>(vk_bytes, 0);
// auto _num_public_inputs = from_buffer<size_t>(vk_bytes, sizeof(size_t));
// auto vk = AvmFlavor::VerificationKey(circuit_size, num_public_inputs);
//
// std::cout << avm_trace::Execution::verify(vk, proof);
// return avm_trace::Execution::verify(vk, proof);
std::vector<fr> const proof = many_from_buffer<fr>(read_file(proof_path));
std::vector<uint8_t> vk_bytes = read_file(vk_path);
auto circuit_size = from_buffer<size_t>(vk_bytes, 0);
auto num_public_inputs = from_buffer<size_t>(vk_bytes, sizeof(size_t));
auto vk = AvmFlavor::VerificationKey(circuit_size, num_public_inputs);

std::cout << avm_trace::Execution::verify(vk, proof);
return avm_trace::Execution::verify(vk, proof);

std::cout << 1;
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,8 @@ namespace bb::avm_trace {
* @param bytecode A vector of bytes representing the bytecode to execute.
* @param calldata expressed as a vector of finite field elements.
* @throws runtime_error exception when the bytecode is invalid.
* @return A zk proof of the execution.
* @return The verifier key and zk proof of the execution.
*/
HonkProof Execution::run_and_prove(std::vector<uint8_t> const& bytecode, std::vector<FF> const& calldata)
{
auto instructions = Deserialization::parse(bytecode);
auto trace = gen_trace(instructions, calldata);
auto circuit_builder = bb::AvmCircuitBuilder();
circuit_builder.set_trace(std::move(trace));

auto composer = AvmComposer();
auto prover = composer.create_prover(circuit_builder);
auto verifier = composer.create_verifier(circuit_builder);
auto proof = prover.construct_proof();
return proof;
}

std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata)
{
Expand All @@ -51,9 +37,6 @@ std::tuple<AvmFlavor::VerificationKey, HonkProof> Execution::prove(std::vector<u
auto circuit_builder = bb::AvmCircuitBuilder();
circuit_builder.set_trace(std::move(trace));

// Temporarily use this until #4954 is resolved
assert(circuit_builder.check_circuit());

auto composer = AvmComposer();
auto prover = composer.create_prover(circuit_builder);
auto verifier = composer.create_verifier(circuit_builder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class Execution {

static std::vector<Row> gen_trace(std::vector<Instruction> const& instructions,
std::vector<FF> const& calldata = {});
static bb::HonkProof run_and_prove(std::vector<uint8_t> const& bytecode, std::vector<FF> const& calldata = {});

static std::tuple<AvmFlavor::VerificationKey, bb::HonkProof> prove(std::vector<uint8_t> const& bytecode,
std::vector<FF> const& calldata = {});
static bool verify(AvmFlavor::VerificationKey vk, HonkProof const& proof);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ TEST_F(AvmArithmeticTestsFF, mixedOperationsWithError)
trace_builder.halt();

auto trace = trace_builder.finalize();
validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}

// Test of equality on FF elements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ TEST_P(AvmBitwiseTestsAnd, AllAndTest)
FF ff_b = FF(uint256_t::from_uint128(b));
FF ff_output = FF(uint256_t::from_uint128(output));
common_validate_bit_op(trace, 0, ff_a, ff_b, ff_output, FF(0), FF(1), FF(2), mem_tag);
validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}
INSTANTIATE_TEST_SUITE_P(AvmBitwiseTests,
AvmBitwiseTestsAnd,
Expand Down
13 changes: 10 additions & 3 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_cast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class AvmCastTests : public ::testing::Test {
uint32_t src_address,
uint32_t dst_address,
AvmMemoryTag src_tag,
AvmMemoryTag dst_tag
AvmMemoryTag dst_tag,
bool force_proof = true

)
{
Expand Down Expand Up @@ -102,7 +103,13 @@ class AvmCastTests : public ::testing::Test {
alu_row_next,
AllOf(Field("op_cast", &Row::avm_alu_op_cast, 0), Field("op_cast_prev", &Row::avm_alu_op_cast_prev, 1)));

validate_trace(std::move(trace));
// We still want the ability to enable proving through the environment variable and therefore we do not pass
// the boolean variable force_proof to validate_trace second argument.
if (force_proof) {
validate_trace(std::move(trace), true);
} else {
validate_trace(std::move(trace));
}
}
};

Expand Down Expand Up @@ -190,7 +197,7 @@ TEST_F(AvmCastTests, indirectAddrTruncationU64ToU8)
trace = trace_builder.finalize();
gen_indices();

validate_cast_trace(256'000'000'203UL, 203, 10, 11, AvmMemoryTag::U64, AvmMemoryTag::U8);
validate_cast_trace(256'000'000'203UL, 203, 10, 11, AvmMemoryTag::U64, AvmMemoryTag::U8, true);
}

TEST_F(AvmCastTests, indirectAddrWrongResolutionU64ToU8)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <cstddef>
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ TEST_F(AvmControlFlowTests, simpleCall)
EXPECT_EQ(halt_row->avm_main_pc, FF(CALL_ADDRESS));
EXPECT_EQ(halt_row->avm_main_internal_return_ptr, FF(AvmTraceBuilder::CALLSTACK_OFFSET + 1));
}
validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}

TEST_F(AvmControlFlowTests, simpleJump)
Expand Down
46 changes: 10 additions & 36 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_execution.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
#include "barretenberg/vm/avm_trace/avm_common.hpp"
#include "barretenberg/vm/avm_trace/avm_deserialization.hpp"
#include "barretenberg/vm/avm_trace/avm_opcode.hpp"
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <string>
#include <utility>

namespace tests_avm {

Expand All @@ -17,26 +12,6 @@ using namespace testing;

using bb::utils::hex_to_bytes;

namespace {

void gen_proof_and_validate(std::vector<uint8_t> const& bytecode,
std::vector<Row>&& trace,
std::vector<FF> const& calldata)
{
auto circuit_builder = AvmCircuitBuilder();
circuit_builder.set_trace(std::move(trace));
EXPECT_TRUE(circuit_builder.check_circuit());

auto composer = AvmComposer();
auto verifier = composer.create_verifier(circuit_builder);

auto proof = avm_trace::Execution::run_and_prove(bytecode, calldata);

// TODO(#4944): uncomment the following line to revive full verification
// EXPECT_TRUE(verifier.verify_proof(proof));
}
} // namespace

class AvmExecutionTests : public ::testing::Test {
public:
AvmTraceBuilder trace_builder;
Expand Down Expand Up @@ -84,7 +59,7 @@ TEST_F(AvmExecutionTests, basicAddReturn)
ElementsAre(VariantWith<uint8_t>(0), VariantWith<uint32_t>(0), VariantWith<uint32_t>(0)))));

auto trace = Execution::gen_trace(instructions);
gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace), true);
}

// Positive test for SET and SUB opcodes
Expand Down Expand Up @@ -149,8 +124,7 @@ TEST_F(AvmExecutionTests, setAndSubOpcodes)
// Find the first row enabling the subtraction selector
auto row = std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_sub == 1; });
EXPECT_EQ(row->avm_main_ic, 10000); // 47123 - 37123 = 10000

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace), true);
}

// Positive test for multiple MUL opcodes
Expand Down Expand Up @@ -230,7 +204,7 @@ TEST_F(AvmExecutionTests, powerWithMulOpcodes)
trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_mul == 1 && r.avm_main_pc == 13; });
EXPECT_EQ(row->avm_main_ic, 244140625); // 5^12 = 244140625

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test about a single internal_call and internal_return
Expand Down Expand Up @@ -297,7 +271,7 @@ TEST_F(AvmExecutionTests, simpleInternalCall)
auto row = std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_add == 1; });
EXPECT_EQ(row->avm_main_ic, 345567789);

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test with some nested internall calls
Expand Down Expand Up @@ -377,7 +351,7 @@ TEST_F(AvmExecutionTests, nestedInternalCalls)
EXPECT_EQ(row->avm_main_ic, 187);
EXPECT_EQ(row->avm_main_pc, 4);

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test with JUMP and CALLDATACOPY
Expand Down Expand Up @@ -451,7 +425,7 @@ TEST_F(AvmExecutionTests, jumpAndCalldatacopy)
// It must have failed as subtraction was "jumped over".
EXPECT_EQ(row, trace.end());

gen_proof_and_validate(bytecode, std::move(trace), std::vector<FF>{ 13, 156 });
validate_trace(std::move(trace));
}

// Positive test with MOV.
Expand Down Expand Up @@ -499,7 +473,7 @@ TEST_F(AvmExecutionTests, movOpcode)
EXPECT_EQ(row->avm_main_ia, 19);
EXPECT_EQ(row->avm_main_ic, 19);

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test with CMOV.
Expand Down Expand Up @@ -555,7 +529,7 @@ TEST_F(AvmExecutionTests, cmovOpcode)
EXPECT_EQ(row->avm_main_ic, 3);
EXPECT_EQ(row->avm_main_id, 5);

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test with indirect MOV.
Expand Down Expand Up @@ -603,7 +577,7 @@ TEST_F(AvmExecutionTests, indMovOpcode)
EXPECT_EQ(row->avm_main_ia, 255);
EXPECT_EQ(row->avm_main_ic, 255);

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Positive test for SET and CAST opcodes
Expand Down Expand Up @@ -644,7 +618,7 @@ TEST_F(AvmExecutionTests, setAndCastOpcodes)
auto row = std::ranges::find_if(trace.begin(), trace.end(), [](Row r) { return r.avm_main_sel_op_cast == 1; });
EXPECT_EQ(row->avm_main_ic, 19); // 0XB813 --> 0X13 = 19

gen_proof_and_validate(bytecode, std::move(trace), {});
validate_trace(std::move(trace));
}

// Negative test detecting an invalid opcode byte.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TEST_F(AvmIndirectMemTests, allIndirectAdd)
EXPECT_EQ(row->avm_main_mem_op_b, FF(1));
EXPECT_EQ(row->avm_main_mem_op_c, FF(1));

validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}

// Testing a subtraction operation with direct input operands a, b, and an indirect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ TEST_F(AvmMemOpcodeTests, indirectMovInvalidAddressTag)
Field(&Row::avm_mem_r_in_tag, static_cast<uint32_t>(AvmMemoryTag::U32)),
Field(&Row::avm_mem_ind_op_c, 1)));

validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}

/******************************************************************************
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST_F(AvmMemoryTests, mismatchedTagAddOperation)
EXPECT_EQ(row->avm_mem_r_in_tag, FF(static_cast<uint32_t>(AvmMemoryTag::U8)));
EXPECT_EQ(row->avm_mem_tag, FF(static_cast<uint32_t>(AvmMemoryTag::FF)));

validate_trace(std::move(trace));
validate_trace(std::move(trace), true);
}

// Testing an equality operation with a mismatched memory tag.
Expand Down
Loading