From 3308054637c9ae5e827d31bdfd6f0612f737137f Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Fri, 29 Sep 2023 22:13:27 +0100 Subject: [PATCH] chore: refactor integration test (#2901) --- .../browser/compile_prove_verify.test.ts | 186 ++++++++---------- 1 file changed, 84 insertions(+), 102 deletions(-) diff --git a/compiler/integration-tests/test/integration/browser/compile_prove_verify.test.ts b/compiler/integration-tests/test/integration/browser/compile_prove_verify.test.ts index f45734f6327..155fbbf02fa 100644 --- a/compiler/integration-tests/test/integration/browser/compile_prove_verify.test.ts +++ b/compiler/integration-tests/test/integration/browser/compile_prove_verify.test.ts @@ -3,22 +3,16 @@ import { TEST_LOG_LEVEL } from '../../environment.js'; import { Logger } from 'tslog'; import { initializeResolver } from '@noir-lang/source-resolver'; import newCompiler, { compile, init_log_level as compilerLogLevel } from '@noir-lang/noir_wasm'; -import { acvm, abi } from '@noir-lang/noir_js'; +import { acvm, abi, generateWitness, acirToUint8Array } from '@noir-lang/noir_js'; import { Barretenberg, RawBuffer, Crs } from '@aztec/bb.js'; -import { decompressSync as gunzip } from 'fflate'; import { ethers } from 'ethers'; import * as TOML from 'smol-toml'; -const mnemonic = 'test test test test test test test test test test test junk'; const provider = new ethers.JsonRpcProvider('http://localhost:8545'); -const walletMnemonic = ethers.Wallet.fromPhrase(mnemonic); -const wallet = walletMnemonic.connect(provider); const logger = new Logger({ name: 'test', minLevel: TEST_LOG_LEVEL }); -const { default: initACVM, executeCircuit, compressWitness } = acvm; -const { default: newABICoder, abiEncode } = abi; - -type WitnessMap = acvm.WitnessMap; +const { default: initACVM } = acvm; +const { default: newABICoder } = abi; await newCompiler(); await newABICoder(); @@ -26,8 +20,9 @@ await initACVM(); compilerLogLevel('INFO'); -async function getFile(url: URL): Promise { - const response = await fetch(url); +async function getFile(file_path: string): Promise { + const file_url = new URL(file_path, import.meta.url); + const response = await fetch(file_url); if (!response.ok) throw new Error('Network response was not OK'); @@ -58,45 +53,64 @@ const suite = Mocha.Suite.create(mocha.suite, 'Noir end to end test'); suite.timeout(60 * 20e3); //20mins -test_cases.forEach((testInfo) => { - const test_name = testInfo.case.split('/').pop(); - const caseLogger = logger.getSubLogger({ - prefix: [test_name], - }); - const mochaTest = new Mocha.Test(`${test_name} (Compile, Execute, Prove, Verify)`, async () => { - const base_relative_path = '../../../../..'; - const test_case = testInfo.case; +const api = await Barretenberg.new(numberOfThreads); +await api.commonInitSlabAllocator(CIRCUIT_SIZE); - const noir_source_url = new URL(`${base_relative_path}/${test_case}/src/main.nr`, import.meta.url); - const prover_toml_url = new URL(`${base_relative_path}/${test_case}/Prover.toml`, import.meta.url); - const compiled_contract_url = new URL(`${base_relative_path}/${testInfo.compiled}`, import.meta.url); - const deploy_information_url = new URL(`${base_relative_path}/${testInfo.deployInformation}`, import.meta.url); +// Plus 1 needed! +const crs = await Crs.new(CIRCUIT_SIZE + 1); +await api.srsInitSrs(new RawBuffer(crs.getG1Data()), crs.numPoints, new RawBuffer(crs.getG2Data())); - const noir_source = await getFile(noir_source_url); - const prover_toml = await getFile(prover_toml_url); - const compiled_contract = await getFile(compiled_contract_url); - const deploy_information = await getFile(deploy_information_url); +const acirComposer = await api.acirNewAcirComposer(CIRCUIT_SIZE); - const { abi } = JSON.parse(compiled_contract); - const { deployedTo } = JSON.parse(deploy_information); +async function getCircuit(noirSource: string) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + initializeResolver((id: string) => { + logger.debug('source-resolver: resolving:', id); + return noirSource; + }); - const contract = new ethers.Contract(deployedTo, abi, wallet); + return compile({}); +} - expect(noir_source).to.be.a.string; +function separatePublicInputsFromProof( + proof: Uint8Array, + numPublicInputs: number, +): { proof: Uint8Array; publicInputs: Uint8Array[] } { + const publicInputs = Array.from({ length: numPublicInputs }, (_, i) => { + const offset = i * FIELD_ELEMENT_BYTES; + return proof.slice(offset, offset + FIELD_ELEMENT_BYTES); + }); + const slicedProof = proof.slice(numPublicInputs * FIELD_ELEMENT_BYTES); - initializeResolver((id: string) => { - caseLogger.debug('source-resolver: resolving:', id); - return noir_source; - }); + return { + proof: slicedProof, + publicInputs, + }; +} - const inputs = TOML.parse(prover_toml); +async function generateProof(base64Bytecode: string, witnessUint8Array: Uint8Array, optimizeForRecursion: boolean) { + const acirUint8Array = acirToUint8Array(base64Bytecode); + // This took ~6.5 minutes! + return api.acirCreateProof(acirComposer, acirUint8Array, witnessUint8Array, optimizeForRecursion); +} - expect(inputs, 'Prover.toml').to.be.an('object'); +async function verifyProof(proof: Uint8Array, optimizeForRecursion: boolean) { + await api.acirInitVerificationKey(acirComposer); + const verified = await api.acirVerifyProof(acirComposer, proof, optimizeForRecursion); + return verified; +} - let compile_output; +test_cases.forEach((testInfo) => { + const test_name = testInfo.case.split('/').pop(); + const mochaTest = new Mocha.Test(`${test_name} (Compile, Execute, Prove, Verify)`, async () => { + const base_relative_path = '../../../../..'; + const test_case = testInfo.case; + + const noir_source = await getFile(`${base_relative_path}/${test_case}/src/main.nr`); + let compile_output; try { - compile_output = await compile({}); + compile_output = await getCircuit(noir_source); expect(await compile_output, 'Compile output ').to.be.an('object'); } catch (e) { @@ -104,72 +118,40 @@ test_cases.forEach((testInfo) => { throw e; } - let witnessMap: WitnessMap; - try { - witnessMap = abiEncode(compile_output.abi, inputs, null); - } catch (e) { - expect(e, 'Abi Encoding Step').to.not.be.an('error'); - throw e; - } + const prover_toml = await getFile(`${base_relative_path}/${test_case}/Prover.toml`); + const inputs = TOML.parse(prover_toml); - let solvedWitness: WitnessMap; - let compressedByteCode; - try { - compressedByteCode = Uint8Array.from(atob(compile_output.circuit), (c) => c.charCodeAt(0)); + const witnessArray: Uint8Array = await generateWitness( + { + bytecode: compile_output.circuit, + abi: compile_output.abi, + }, + inputs, + ); - solvedWitness = await executeCircuit(compressedByteCode, witnessMap, () => { - throw Error('unexpected oracle'); - }); - } catch (e) { - expect(e, 'Abi Encoding Step').to.not.be.an('error'); - throw e; - } + // JS Proving - try { - const compressedWitness = compressWitness(solvedWitness); - const acirUint8Array = gunzip(compressedByteCode); - const witnessUint8Array = gunzip(compressedWitness); - - const isRecursive = false; - const api = await Barretenberg.new(numberOfThreads); - await api.commonInitSlabAllocator(CIRCUIT_SIZE); - - // Plus 1 needed! - const crs = await Crs.new(CIRCUIT_SIZE + 1); - await api.srsInitSrs(new RawBuffer(crs.getG1Data()), crs.numPoints, new RawBuffer(crs.getG2Data())); - - const acirComposer = await api.acirNewAcirComposer(CIRCUIT_SIZE); - - // This took ~6.5 minutes! - const proof = await api.acirCreateProof(acirComposer, acirUint8Array, witnessUint8Array, isRecursive); - - // And this took ~5 minutes! - const verified = await api.acirVerifyProof(acirComposer, proof, isRecursive); - - expect(verified, 'Proof fails verification in JS').to.be.true; - - try { - let result; - if (testInfo.numPublicInputs === 0) { - result = await contract.verify(proof, []); - } else { - const publicInputs = Array.from({ length: testInfo.numPublicInputs }, (_, i) => { - const offset = i * FIELD_ELEMENT_BYTES; - return proof.slice(offset, offset + FIELD_ELEMENT_BYTES); - }); - const slicedProof = proof.slice(testInfo.numPublicInputs * FIELD_ELEMENT_BYTES); - result = await contract.verify(slicedProof, publicInputs); - } - - expect(result).to.be.true; - } catch (error) { - console.error('Error while submitting the proof:', error); - throw error; - } - } catch (e) { - expect(e, 'Proving and Verifying').to.not.be.an('error'); - throw e; - } + const isRecursive = false; + const proofWithPublicInputs = await generateProof(compile_output.circuit, witnessArray, isRecursive); + + // JS verification + + const verified = await verifyProof(proofWithPublicInputs, isRecursive); + expect(verified, 'Proof fails verification in JS').to.be.true; + + // Smart contract verification + + const compiled_contract = await getFile(`${base_relative_path}/${testInfo.compiled}`); + const deploy_information = await getFile(`${base_relative_path}/${testInfo.deployInformation}`); + + const { abi } = JSON.parse(compiled_contract); + const { deployedTo } = JSON.parse(deploy_information); + const contract = new ethers.Contract(deployedTo, abi, provider); + + const { proof, publicInputs } = separatePublicInputsFromProof(proofWithPublicInputs, testInfo.numPublicInputs); + const result = await contract.verify(proof, publicInputs); + + expect(result).to.be.true; }); suite.addTest(mochaTest);