Skip to content

Commit

Permalink
feat(bb): add ability to write pk to file or stdout (#3335)
Browse files Browse the repository at this point in the history
A vendor succinct is looking to consume noir, in order to not need to
compile the circuit each time nargo prove is run.

A followup pr will come to allow other flows to consume the pk
  • Loading branch information
Maddiaa0 authored Nov 20, 2023
1 parent 7e89ff3 commit c99862c
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 25 deletions.
1 change: 1 addition & 0 deletions barretenberg/acir_tests/flows/all_cmds.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ FLAGS="-c $CRS_PATH $VFLAG"
$BIN gates $FLAGS $BFLAG > /dev/null
$BIN prove -o proof $FLAGS $BFLAG
$BIN write_vk -o vk $FLAGS $BFLAG
$BIN write_pk -o pk $FLAGS $BFLAG
$BIN verify -k vk -p proof $FLAGS

# Check supplemental functions.
Expand Down
67 changes: 44 additions & 23 deletions barretenberg/cpp/src/barretenberg/bb/main.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "config.hpp"
#include "get_bytecode.hpp"
#include "get_crs.hpp"
Expand Down Expand Up @@ -183,7 +184,7 @@ bool verify(const std::string& proof_path, bool recursive, const std::string& vk
* @param bytecodePath Path to the file containing the serialized circuit
* @param outputPath Path to write the verification key to
*/
void writeVk(const std::string& bytecodePath, const std::string& outputPath)
void write_vk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto acir_composer = init(constraint_system);
Expand All @@ -199,6 +200,22 @@ void writeVk(const std::string& bytecodePath, const std::string& outputPath)
}
}

void write_pk(const std::string& bytecodePath, const std::string& outputPath)
{
auto constraint_system = get_constraint_system(bytecodePath);
auto acir_composer = init(constraint_system);
auto pk = acir_composer.init_proving_key(constraint_system);
auto serialized_pk = to_buffer(*pk);

if (outputPath == "-") {
writeRawBytesToStdout(serialized_pk);
vinfo("pk written to stdout");
} else {
write_file(outputPath, serialized_pk);
vinfo("pk written to: ", outputPath);
}
}

/**
* @brief Writes a Solidity verifier contract for an ACIR circuit to a file
*
Expand Down Expand Up @@ -253,7 +270,7 @@ void contract(const std::string& output_path, const std::string& vk_path)
* @param vk_path Path to the file containing the serialized verification key
* @param output_path Path to write the proof to
*/
void proofAsFields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path)
void proof_as_fields(const std::string& proof_path, std::string const& vk_path, const std::string& output_path)
{
auto acir_composer = init();
auto vk_data = from_buffer<plonk::verification_key_data>(read_file(vk_path));
Expand Down Expand Up @@ -282,7 +299,7 @@ void proofAsFields(const std::string& proof_path, std::string const& vk_path, co
* @param vk_path Path to the file containing the serialized verification key
* @param output_path Path to write the verification key to
*/
void vkAsFields(const std::string& vk_path, const std::string& output_path)
void vk_as_fields(const std::string& vk_path, const std::string& output_path)
{
auto acir_composer = init();
auto vk_data = from_buffer<plonk::verification_key_data>(read_file(vk_path));
Expand Down Expand Up @@ -311,7 +328,7 @@ void vkAsFields(const std::string& vk_path, const std::string& output_path)
*
* @param output_path Path to write the information to
*/
void acvmInfo(const std::string& output_path)
void acvm_info(const std::string& output_path)
{

const char* jsonData = R"({
Expand All @@ -335,12 +352,12 @@ void acvmInfo(const std::string& output_path)
}
}

bool flagPresent(std::vector<std::string>& args, const std::string& flag)
bool flag_present(std::vector<std::string>& args, const std::string& flag)
{
return std::find(args.begin(), args.end(), flag) != args.end();
}

std::string getOption(std::vector<std::string>& args, const std::string& option, const std::string& defaultValue)
std::string get_option(std::vector<std::string>& args, const std::string& option, const std::string& defaultValue)
{
auto itr = std::find(args.begin(), args.end(), option);
return (itr != args.end() && std::next(itr) != args.end()) ? *(std::next(itr)) : defaultValue;
Expand All @@ -350,7 +367,7 @@ int main(int argc, char* argv[])
{
try {
std::vector<std::string> args(argv + 1, argv + argc);
verbose = flagPresent(args, "-v") || flagPresent(args, "--verbose");
verbose = flag_present(args, "-v") || flag_present(args, "--verbose");

if (args.empty()) {
std::cerr << "No command provided.\n";
Expand All @@ -359,46 +376,50 @@ int main(int argc, char* argv[])

std::string command = args[0];

std::string bytecode_path = getOption(args, "-b", "./target/acir.gz");
std::string witness_path = getOption(args, "-w", "./target/witness.gz");
std::string proof_path = getOption(args, "-p", "./proofs/proof");
std::string vk_path = getOption(args, "-k", "./target/vk");
CRS_PATH = getOption(args, "-c", "./crs");
bool recursive = flagPresent(args, "-r") || flagPresent(args, "--recursive");
std::string bytecode_path = get_option(args, "-b", "./target/acir.gz");
std::string witness_path = get_option(args, "-w", "./target/witness.gz");
std::string proof_path = get_option(args, "-p", "./proofs/proof");
std::string vk_path = get_option(args, "-k", "./target/vk");
std::string pk_path = get_option(args, "-r", "./target/pk");
CRS_PATH = get_option(args, "-c", "./crs");
bool recursive = flag_present(args, "-r") || flag_present(args, "--recursive");

// Skip CRS initialization for any command which doesn't require the CRS.
if (command == "--version") {
writeStringToStdout(BB_VERSION);
return 0;
}
if (command == "info") {
std::string output_path = getOption(args, "-o", "info.json");
acvmInfo(output_path);
std::string output_path = get_option(args, "-o", "info.json");
acvm_info(output_path);
return 0;
}

if (command == "prove_and_verify") {
return proveAndVerify(bytecode_path, witness_path, recursive) ? 0 : 1;
}
if (command == "prove") {
std::string output_path = getOption(args, "-o", "./proofs/proof");
std::string output_path = get_option(args, "-o", "./proofs/proof");
prove(bytecode_path, witness_path, recursive, output_path);
} else if (command == "gates") {
gateCount(bytecode_path);
} else if (command == "verify") {
return verify(proof_path, recursive, vk_path) ? 0 : 1;
} else if (command == "contract") {
std::string output_path = getOption(args, "-o", "./target/contract.sol");
std::string output_path = get_option(args, "-o", "./target/contract.sol");
contract(output_path, vk_path);
} else if (command == "write_vk") {
std::string output_path = getOption(args, "-o", "./target/vk");
writeVk(bytecode_path, output_path);
std::string output_path = get_option(args, "-o", "./target/vk");
write_vk(bytecode_path, output_path);
} else if (command == "write_pk") {
std::string output_path = get_option(args, "-o", "./target/pk");
write_pk(bytecode_path, output_path);
} else if (command == "proof_as_fields") {
std::string output_path = getOption(args, "-o", proof_path + "_fields.json");
proofAsFields(proof_path, vk_path, output_path);
std::string output_path = get_option(args, "-o", proof_path + "_fields.json");
proof_as_fields(proof_path, vk_path, output_path);
} else if (command == "vk_as_fields") {
std::string output_path = getOption(args, "-o", vk_path + "_fields.json");
vkAsFields(vk_path, output_path);
std::string output_path = get_option(args, "-o", vk_path + "_fields.json");
vk_as_fields(vk_path, output_path);
} else {
std::cerr << "Unknown command: " << command << "\n";
return 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/dsl/acir_format/recursion_constraint.hpp"
#include "barretenberg/dsl/types.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk/proof_system/verification_key/sol_gen.hpp"
#include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp"
Expand All @@ -30,12 +31,14 @@ void AcirComposer::create_circuit(acir_format::acir_format& constraint_system)
vinfo("gates: ", builder_.get_total_circuit_size());
}

void AcirComposer::init_proving_key(acir_format::acir_format& constraint_system)
std::shared_ptr<proof_system::plonk::proving_key> AcirComposer::init_proving_key(
acir_format::acir_format& constraint_system)
{
create_circuit(constraint_system);
acir_format::Composer composer;
vinfo("computing proving key...");
proving_key_ = composer.compute_proving_key(builder_);
return proving_key_;
}

std::vector<uint8_t> AcirComposer::create_proof(acir_format::acir_format& constraint_system,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AcirComposer {

void create_circuit(acir_format::acir_format& constraint_system);

void init_proving_key(acir_format::acir_format& constraint_system);
std::shared_ptr<proof_system::plonk::proving_key> init_proving_key(acir_format::acir_format& constraint_system);

std::vector<uint8_t> create_proof(acir_format::acir_format& constraint_system,
acir_format::WitnessVector& witness,
Expand Down
10 changes: 10 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "barretenberg/common/serialize.hpp"
#include "barretenberg/common/slab_allocator.hpp"
#include "barretenberg/dsl/acir_format/acir_format.hpp"
#include "barretenberg/plonk/proof_system/proving_key/serialize.hpp"
#include "barretenberg/plonk/proof_system/verification_key/verification_key.hpp"
#include "barretenberg/srs/global_crs.hpp"
#include <cstdint>
Expand Down Expand Up @@ -73,6 +74,15 @@ WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** o
*out = to_heap_buffer(to_buffer(*vk));
}

WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out)
{
auto acir_composer = reinterpret_cast<acir_proofs::AcirComposer*>(*acir_composer_ptr);
auto constraint_system = acir_format::circuit_buf_to_acir_format(from_buffer<std::vector<uint8_t>>(acir_vec));
auto pk = acir_composer->init_proving_key(constraint_system);
// We flatten to a vector<uint8_t> first, as that's how we treat it on the calling side.
*out = to_heap_buffer(to_buffer(*pk));
}

WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr,
uint8_t const* proof_buf,
bool const* is_recursive,
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/dsl/acir_proofs/c_bind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ WASM_EXPORT void acir_init_verification_key(in_ptr acir_composer_ptr);

WASM_EXPORT void acir_get_verification_key(in_ptr acir_composer_ptr, uint8_t** out);

WASM_EXPORT void acir_get_proving_key(in_ptr acir_composer_ptr, uint8_t const* acir_vec, uint8_t** out);

WASM_EXPORT void acir_verify_proof(in_ptr acir_composer_ptr,
uint8_t const* proof_buf,
bool const* is_recursive,
Expand Down
12 changes: 12 additions & 0 deletions barretenberg/ts/src/barretenberg_api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ export class BarretenbergApi {
return out[0];
}

async acirGetProvingKey(acirComposerPtr: Ptr, constraintSystemBuf: Uint8Array): Promise<Uint8Array> {
const inArgs = [acirComposerPtr, constraintSystemBuf].map(serializeBufferable);
const outTypes: OutputType[] = [BufferDeserializer()];
const result = await this.wasm.callWasmExport(
'acir_get_proving_key',
inArgs,
outTypes.map(t => t.SIZE_IN_BYTES),
);
const out = result.map((r, i) => outTypes[i].fromBuffer(r));
return out[0];
}

async acirVerifyProof(acirComposerPtr: Ptr, proofBuf: Uint8Array, isRecursive: boolean): Promise<boolean> {
const inArgs = [acirComposerPtr, proofBuf, isRecursive].map(serializeBufferable);
const outTypes: OutputType[] = [BoolDeserializer()];
Expand Down
29 changes: 29 additions & 0 deletions barretenberg/ts/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,25 @@ export async function writeVk(bytecodePath: string, crsPath: string, outputPath:
}
}

export async function writePk(bytecodePath: string, crsPath: string, outputPath: string) {
const { api, acirComposer } = await init(bytecodePath, crsPath);
try {
debug('initing proving key...');
const bytecode = getBytecode(bytecodePath);
const pk = await api.acirGetProvingKey(acirComposer, bytecode);

if (outputPath === '-') {
process.stdout.write(pk);
debug(`pk written to stdout`);
} else {
writeFileSync(outputPath, pk);
debug(`pk written to: ${outputPath}`);
}
} finally {
await api.destroy();
}
}

export async function proofAsFields(proofPath: string, vkPath: string, outputPath: string) {
const { api, acirComposer } = await initLite();

Expand Down Expand Up @@ -347,6 +366,16 @@ program
await writeVk(bytecodePath, crsPath, outputPath);
});

program
.command('write_pk')
.description('Output proving key.')
.option('-b, --bytecode-path <path>', 'Specify the bytecode path', './target/acir.gz')
.requiredOption('-o, --output-path <path>', 'Specify the path to write the key')
.action(async ({ bytecodePath, outputPath, crsPath }) => {
handleGlobalOptions();
await writePk(bytecodePath, crsPath, outputPath);
});

program
.command('proof_as_fields')
.description('Return the proof as fields elements')
Expand Down

0 comments on commit c99862c

Please sign in to comment.