Skip to content

Commit

Permalink
feat(avm): wip
Browse files Browse the repository at this point in the history
  • Loading branch information
fcarreiro committed Sep 18, 2024
1 parent b4e2fd0 commit a80882e
Show file tree
Hide file tree
Showing 42 changed files with 265 additions and 157 deletions.
7 changes: 4 additions & 3 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "barretenberg/dsl/acir_format/proof_surgeon.hpp"
#include "barretenberg/dsl/acir_proofs/honk_contract.hpp"
#include "barretenberg/honk/proof_system/types/proof.hpp"
#include "barretenberg/numeric/bitop/get_msb.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk_honk_shared/types/aggregation_object_type.hpp"
#include "barretenberg/serialize/cbind.hpp"
Expand Down Expand Up @@ -966,8 +967,8 @@ void avm_prove(const std::filesystem::path& bytecode_path,
std::vector<fr> vk_as_fields = verification_key.to_field_elements();

vinfo("vk fields size: ", vk_as_fields.size());
vinfo("circuit size: ", vk_as_fields[0]);
vinfo("num of pub inputs: ", vk_as_fields[1]);
vinfo("circuit size: ", static_cast<size_t>(vk_as_fields[0]));
vinfo("num of pub inputs: ", static_cast<size_t>(vk_as_fields[1]));

std::string vk_json = to_json(vk_as_fields);
const auto proof_path = output_path / "proof";
Expand Down Expand Up @@ -1016,7 +1017,7 @@ bool avm_verify(const std::filesystem::path& proof_path, const std::filesystem::
std::span vk_span(vk_as_fields);

vinfo("vk fields size: ", vk_as_fields.size());
vinfo("circuit size: ", circuit_size);
vinfo("circuit size: ", circuit_size, " (next or eq power: 2^", numeric::round_up_power_2(circuit_size), ")");
vinfo("num of pub inputs: ", num_public_inputs);

// Each commitment (precomputed entity) is represented as 2 Fq field elements.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "barretenberg/common/throw_or_abort.hpp"
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/honk/proof_system/logderivative_library.hpp"
#include "barretenberg/numeric/bitop/get_msb.hpp"
#include "barretenberg/relations/generic_lookup/generic_lookup_relation.hpp"
#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp"
#include "barretenberg/vm/stats.hpp"
Expand All @@ -16,28 +17,31 @@ namespace bb {

AvmCircuitBuilder::ProverPolynomials AvmCircuitBuilder::compute_polynomials() const
{
const auto num_rows = get_circuit_subgroup_size();
const size_t circuit_subgroup_size = get_circuit_subgroup_size();
// FIXME: Either some algo or the Polynomial class seems to require this to be a power of 2.
const size_t num_rows = numeric::round_up_power_2(get_num_gates());
ProverPolynomials polys;

// Allocate mem for each column
AVM_TRACK_TIME("circuit_builder/init_polys_to_be_shifted", ({
for (auto& poly : polys.get_to_be_shifted()) {
poly = Polynomial{ /*memory size*/ num_rows - 1,
/*largest possible index*/ num_rows,
/*largest possible index*/ circuit_subgroup_size,
/*make shiftable with offset*/ 1 };
}
}));
// catch-all with fully formed polynomials
AVM_TRACK_TIME("circuit_builder/init_polys_unshifted", ({
auto unshifted = polys.get_unshifted();
bb::parallel_for(unshifted.size(), [&](size_t i) {
auto& poly = unshifted[i];
if (poly.is_empty()) {
// Not set above
poly = Polynomial{ /*memory size*/ num_rows, /*largest possible index*/ num_rows };
}
});
}));
AVM_TRACK_TIME(
"circuit_builder/init_polys_unshifted", ({
auto unshifted = polys.get_unshifted();
bb::parallel_for(unshifted.size(), [&](size_t i) {
auto& poly = unshifted[i];
if (poly.is_empty()) {
// Not set above
poly = Polynomial{ /*memory size*/ num_rows, /*largest possible index*/ circuit_subgroup_size };
}
});
}));

AVM_TRACK_TIME(
"circuit_builder/set_polys_unshifted", ({
Expand Down Expand Up @@ -721,7 +725,8 @@ bool AvmCircuitBuilder::check_circuit() const
};

auto polys = compute_polynomials();
const size_t num_rows = polys.get_polynomial_size();
// We'll only check up to the generated trace which might be << than the circuit subgroup size.
const size_t num_rows = get_num_gates();

// Checks that we will run.
using SignalErrorFn = const std::function<void(const std::string&)>&;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,38 @@ namespace bb {

class AvmCircuitBuilder {
public:
// Do not use this constant directly, use get_circuit_subgroup_size() instead.
constexpr static size_t CIRCUIT_SUBGROUP_SIZE = 1 << 21;

using Flavor = bb::AvmFlavor;
using FF = Flavor::FF;
using Row = AvmFullRow<FF>;

// TODO: template
using Polynomial = Flavor::Polynomial;
using ProverPolynomials = Flavor::ProverPolynomials;

std::vector<Row> rows;

void set_trace(std::vector<Row>&& trace) { rows = std::move(trace); }
void set_trace(std::vector<Row>&& trace)
{
rows = std::move(trace);
num_rows = rows.size();
}
void clear_trace()
{
rows.clear();
rows.shrink_to_fit();
num_rows = 0;
}

ProverPolynomials compute_polynomials() const;

bool check_circuit() const;

size_t get_num_gates() const { return rows.size(); }
size_t get_num_gates() const { return num_rows; }

size_t get_circuit_subgroup_size() const
{
const size_t num_rows = get_num_gates();
const auto num_rows_log2 = static_cast<size_t>(numeric::get_msb64(num_rows));
size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1));
return num_rows_pow2;
}
size_t get_circuit_subgroup_size() const { return CIRCUIT_SUBGROUP_SIZE; }

private:
size_t num_rows = 0;
std::vector<Row> rows;
};

} // namespace bb
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ std::shared_ptr<Flavor::ProvingKey> AvmComposer::compute_proving_key(CircuitCons
return proving_key;
}

// Initialize proving_key
const size_t subgroup_size = circuit_constructor.get_circuit_subgroup_size();
proving_key = std::make_shared<Flavor::ProvingKey>(subgroup_size, 0);

return proving_key;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class AvmComposer {
// The commitment key is passed to the prover but also used herein to compute the verfication key commitments
std::shared_ptr<CommitmentKey> commitment_key;

AggregationObjectPubInputIndices recursive_proof_public_input_indices;
bool contains_recursive_proof = false;
bool computed_witness = false;

AvmComposer() { crs_factory_ = bb::srs::get_bn254_crs_factory(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ bool AvmVerifier::verify_proof(const HonkProof& proof,
CommitmentLabels commitment_labels;

const auto circuit_size = transcript->template receive_from_prover<uint32_t>("circuit_size");

if (circuit_size != key->circuit_size) {
vinfo("Circuit size mismatch: expected", key->circuit_size, " got ", circuit_size);
return false;
Expand Down Expand Up @@ -102,7 +101,7 @@ bool AvmVerifier::verify_proof(const HonkProof& proof,

// If Sumcheck did not verify, return false
if (!sumcheck_verified.has_value() || !sumcheck_verified.value()) {
vinfo("Sumcheck failed");
vinfo("Sumcheck verification failed");
return false;
}

Expand Down Expand Up @@ -156,10 +155,10 @@ bool AvmVerifier::verify_proof(const HonkProof& proof,
transcript);

auto pairing_points = PCS::reduce_verify(opening_claim, transcript);
bool zeromoprh_verified = key->pcs_verification_key->pairing_check(pairing_points[0], pairing_points[1]);
auto zeromorph_verified = key->pcs_verification_key->pairing_check(pairing_points[0], pairing_points[1]);

if (!zeromoprh_verified) {
vinfo("ZeroMorph failed");
if (!zeromorph_verified) {
vinfo("ZeroMorph verification failed");
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ class AvmArithmeticTests : public ::testing::Test {
public:
AvmArithmeticTests()
: public_inputs(generate_base_public_inputs())
, trace_builder(AvmTraceBuilder(public_inputs))
, trace_builder(
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false))
{
srs::init_crs_factory("../srs_db/ignition");
}
Expand All @@ -215,7 +216,9 @@ class AvmArithmeticTests : public ::testing::Test {

void gen_trace_builder(std::vector<FF> const& calldata)
{
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
}

// Generate a trace with an EQ opcode operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ class AvmBitwiseTests : public ::testing::Test {
public:
AvmBitwiseTests()
: public_inputs(generate_base_public_inputs())
, trace_builder(AvmTraceBuilder(public_inputs))
, trace_builder(
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false))
{
srs::init_crs_factory("../srs_db/ignition");
}
Expand Down
27 changes: 20 additions & 7 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/cast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class AvmCastTests : public ::testing::Test {
public:
AvmCastTests()
: public_inputs(generate_base_public_inputs())
, trace_builder(AvmTraceBuilder(public_inputs))
, trace_builder(
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false))
{
srs::init_crs_factory("../srs_db/ignition");
}
Expand Down Expand Up @@ -162,7 +163,9 @@ TEST_F(AvmCastTests, noTruncationFFToU32)
TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
{
calldata = { FF::modulus - 1 };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_set(0, 1, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
Expand All @@ -176,7 +179,9 @@ TEST_F(AvmCastTests, truncationFFToU16ModMinus1)
TEST_F(AvmCastTests, truncationFFToU16ModMinus2)
{
calldata = { FF::modulus_minus_two };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_set(0, 1, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
Expand Down Expand Up @@ -288,7 +293,9 @@ TEST_F(AvmCastNegativeTests, wrongOutputAluIc)
TEST_F(AvmCastNegativeTests, wrongLimbDecompositionInput)
{
calldata = { FF::modulus_minus_two };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.op_return(0, 0, 0);
Expand All @@ -313,7 +320,9 @@ TEST_F(AvmCastNegativeTests, wrongPSubALo)
TEST_F(AvmCastNegativeTests, wrongPSubAHi)
{
calldata = { FF::modulus_minus_two - 987 };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.op_return(0, 0, 0);
Expand Down Expand Up @@ -351,7 +360,9 @@ TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionLo)
TEST_F(AvmCastNegativeTests, wrongRangeCheckDecompositionHi)
{
calldata = { FF::modulus_minus_two - 987 };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U16);
trace_builder.op_return(0, 0, 0);
Expand Down Expand Up @@ -389,7 +400,9 @@ TEST_F(AvmCastNegativeTests, wrongCopySubLoForRangeCheck)
TEST_F(AvmCastNegativeTests, wrongCopySubHiForRangeCheck)
{
std::vector<FF> const calldata = { FF::modulus_minus_two - 972836 };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 1, 0);
trace_builder.op_cast(0, 0, 1, AvmMemoryTag::U128);
trace_builder.op_return(0, 0, 0);
Expand Down
19 changes: 14 additions & 5 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/comparison.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ class AvmCmpTests : public ::testing::Test {
public:
AvmCmpTests()
: public_inputs(generate_base_public_inputs())
, trace_builder(AvmTraceBuilder(public_inputs))
, trace_builder(
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false))
{
srs::init_crs_factory("../srs_db/ignition");
}
Expand All @@ -110,7 +111,9 @@ TEST_P(AvmCmpTestsLT, ParamTest)

if (mem_tag == AvmMemoryTag::FF) {
calldata = { a, b };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 2, 0);
} else {
trace_builder.op_set(0, a, 0, mem_tag);
Expand Down Expand Up @@ -146,7 +149,9 @@ TEST_P(AvmCmpTestsLTE, ParamTest)

if (mem_tag == AvmMemoryTag::FF) {
calldata = { a, b };
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata);
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, calldata)
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 2, 0);
} else {
trace_builder.op_set(0, a, 0, mem_tag);
Expand Down Expand Up @@ -324,7 +329,9 @@ TEST_P(AvmCmpNegativeTestsLT, ParamTest)
const auto [failure_string, failure_mode] = failure;
const auto [a, b, output] = params;

trace_builder = AvmTraceBuilder(public_inputs, {}, 0, std::vector<FF>{ a, b, output });
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, std::vector<FF>{ a, b, output })
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 3, 0);
trace_builder.op_lt(0, 0, 1, 2, AvmMemoryTag::FF);
trace_builder.op_return(0, 0, 0);
Expand All @@ -343,7 +350,9 @@ TEST_P(AvmCmpNegativeTestsLTE, ParamTest)
const auto [failure, params] = GetParam();
const auto [failure_string, failure_mode] = failure;
const auto [a, b, output] = params;
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, std::vector<FF>{ a, b, output });
trace_builder = AvmTraceBuilder(public_inputs, {}, 0, std::vector<FF>{ a, b, output })
.set_full_precomputed_tables(false)
.set_range_check_required(false);
trace_builder.op_calldata_copy(0, 0, 3, 0);
trace_builder.op_lte(0, 0, 1, 2, AvmMemoryTag::FF);
trace_builder.op_return(0, 0, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class AvmControlFlowTests : public ::testing::Test {
public:
AvmControlFlowTests()
: public_inputs(generate_base_public_inputs())
, trace_builder(AvmTraceBuilder(public_inputs))
, trace_builder(
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false))
{
srs::init_crs_factory("../srs_db/ignition");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,18 @@ class AvmExecutionTests : public ::testing::Test {
VmPublicInputs public_inputs;

AvmExecutionTests()
: public_inputs_vec(PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH){};
: public_inputs_vec(PUBLIC_CIRCUIT_PUBLIC_INPUTS_LENGTH)
{
Execution::set_trace_builder_constructor([](VmPublicInputs public_inputs,
ExecutionHints execution_hints,
uint32_t side_effect_counter,
std::vector<FF> calldata) {
return AvmTraceBuilder(
std::move(public_inputs), std::move(execution_hints), side_effect_counter, std::move(calldata))
.set_full_precomputed_tables(false)
.set_range_check_required(false);
});
};

protected:
const FixedGasTable& GAS_COST_TABLE = FixedGasTable::get();
Expand Down
3 changes: 2 additions & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/tests/gas.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void test_gas(StartGas startGas, OpcodesFunc apply_opcodes, CheckFunc check_trac

VmPublicInputs public_inputs;
std::get<0>(public_inputs) = kernel_inputs;
AvmTraceBuilder trace_builder(public_inputs);
auto trace_builder =
AvmTraceBuilder(public_inputs).set_full_precomputed_tables(false).set_range_check_required(false);

// We should return a value of 1 for the sender, as it exists at index 0
apply_opcodes(trace_builder);
Expand Down
Loading

0 comments on commit a80882e

Please sign in to comment.