Skip to content

Commit

Permalink
Update prover client tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spalladino committed Nov 28, 2024
1 parent 3397f50 commit cc7d996
Show file tree
Hide file tree
Showing 20 changed files with 266 additions and 214 deletions.
9 changes: 1 addition & 8 deletions yarn-project/circuit-types/src/test/factories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,7 @@ export function makeBloatedProcessedTx({
privateOnly?: boolean;
} = {}) {
seed *= 0x1000; // Avoid clashing with the previous mock values if seed only increases by 1.

if (!header) {
if (db) {
header = db.getInitialHeader();
} else {
header = makeHeader(seed);
}
}
header ??= db?.getInitialHeader() ?? makeHeader(seed);

const txConstantData = TxConstantData.empty();
txConstantData.historicalHeader = header;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import { OutboxAbi, RollupAbi } from '@aztec/l1-artifacts';
import { SHA256Trunc, StandardTree } from '@aztec/merkle-tree';
import { getVKTreeRoot } from '@aztec/noir-protocol-circuits-types';
import { protocolContractTreeRoot } from '@aztec/protocol-contracts';
import { LightweightBlockBuilder } from '@aztec/prover-client/block-builder';
import { L1Publisher } from '@aztec/sequencer-client';
import { NoopTelemetryClient } from '@aztec/telemetry-client/noop';
import {
Expand Down Expand Up @@ -51,7 +52,6 @@ import {
} from 'viem';
import { type PrivateKeyAccount, privateKeyToAccount } from 'viem/accounts';

import { LightweightBlockBuilder } from '../../../sequencer-client/src/block_builder/light.js';
import { sendL1ToL2Message } from '../fixtures/l1_to_l2_messaging.js';
import { setupL1Contracts } from '../fixtures/utils.js';

Expand Down
20 changes: 19 additions & 1 deletion yarn-project/foundation/src/collection/array.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { compactArray, removeArrayPaddingEnd, times, unique } from './array.js';
import { compactArray, maxBy, removeArrayPaddingEnd, times, unique } from './array.js';

describe('times', () => {
it('should return an array with the result from all executions', () => {
Expand Down Expand Up @@ -61,3 +61,21 @@ describe('unique', () => {
expect(unique([1n, 2n, 1n])).toEqual([1n, 2n]);
});
});

describe('maxBy', () => {
it('returns the max value', () => {
expect(maxBy([1, 2, 3], x => x)).toEqual(3);
});

it('returns the first max value', () => {
expect(maxBy([{ a: 1 }, { a: 3, b: 1 }, { a: 3, b: 2 }], ({ a }) => a)).toEqual({ a: 3, b: 1 });
});

it('returns undefined for an empty array', () => {
expect(maxBy([], x => x)).toBeUndefined();
});

it('applies the mapping function', () => {
expect(maxBy([1, 2, 3], x => -x)).toEqual(1);
});
});
24 changes: 24 additions & 0 deletions yarn-project/foundation/src/collection/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ export function times<T>(n: number, fn: (i: number) => T): T[] {
return [...Array(n).keys()].map(i => fn(i));
}

/**
* Executes the given async function n times and returns the results in an array. Awaits each execution before starting the next one.
* @param n - How many times to repeat.
* @param fn - Mapper from index to value.
* @returns The array with the result from all executions.
*/
export async function timesAsync<T>(n: number, fn: (i: number) => Promise<T>): Promise<T[]> {
const results: T[] = [];
for (let i = 0; i < n; i++) {
results.push(await fn(i));
}
return results;
}

/**
* Returns the serialized size of all non-empty items in an array.
* @param arr - Array
Expand Down Expand Up @@ -121,3 +135,13 @@ export function areArraysEqual<T>(a: T[], b: T[], eq: (a: T, b: T) => boolean =
}
return true;
}

/**
* Returns the element of the array that has the maximum value of the given function.
* In case of a tie, returns the first element with the maximum value.
* @param arr - The array.
* @param fn - The function to get the value to compare.
*/
export function maxBy<T>(arr: T[], fn: (x: T) => number): T | undefined {
return arr.reduce((max, x) => (fn(x) > fn(max) ? x : max), arr[0]);
}
20 changes: 20 additions & 0 deletions yarn-project/prover-client/src/block_builder/light.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,23 @@ export class LightweightBlockBuilderFactory {
return new LightweightBlockBuilder(db, this.telemetry ?? new NoopTelemetryClient());
}
}

/**
* Creates a block builder under the hood with the given txs and messages and creates a block.
* Automatically adds padding txs to get to a minimum of 2 txs in the block.
* @param db - A db fork to use for block building.
*/
export async function buildBlock(
txs: ProcessedTx[],
globalVariables: GlobalVariables,
l1ToL2Messages: Fr[],
db: MerkleTreeWriteOperations,
telemetry: TelemetryClient = new NoopTelemetryClient(),
) {
const builder = new LightweightBlockBuilder(db, telemetry);
await builder.startNewBlock(Math.max(txs.length, 2), globalVariables, l1ToL2Messages);
for (const tx of txs) {
await builder.addNewTx(tx);
}
return await builder.setBlockCompleted();
}
3 changes: 0 additions & 3 deletions yarn-project/prover-client/src/mocks/fixtures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ export async function getSimulationProvider(
return new WASMSimulator();
}

export const makeBloatedProcessedTxWithVKRoot = (builderDb: MerkleTreeReadOperations, seed = 0x1) =>
makeBloatedProcessedTx({ db: builderDb, vkTreeRoot: getVKTreeRoot(), protocolContractTreeRoot, seed });

// Updates the expectedDb trees based on the new note hashes, contracts, and nullifiers from these txs
export const updateExpectedTreesFromTxs = async (db: MerkleTreeWriteOperations, txs: ProcessedTx[]) => {
await db.appendLeaves(
Expand Down
91 changes: 62 additions & 29 deletions yarn-project/prover-client/src/mocks/test_context.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import { type BBProverConfig } from '@aztec/bb-prover';
import {
type ForkMerkleTreeWriteOperations,
type MerkleTreeWriteOperations,
type ProcessedTx,
type ProcessedTxHandler,
type PublicExecutionRequest,
type ServerCircuitProver,
type Tx,
type TxValidator,
} from '@aztec/circuit-types';
import { makeBloatedProcessedTx } from '@aztec/circuit-types/test';
import { type Gas, type GlobalVariables, Header } from '@aztec/circuits.js';
import { times } from '@aztec/foundation/collection';
import { Fr } from '@aztec/foundation/fields';
import { type DebugLogger } from '@aztec/foundation/log';
import { openTmpStore } from '@aztec/kv-store/utils';
import { getVKTreeRoot } from '@aztec/noir-protocol-circuits-types';
import { protocolContractTreeRoot } from '@aztec/protocol-contracts';
import {
PublicProcessor,
PublicTxSimulator,
Expand All @@ -21,30 +22,31 @@ import {
type WorldStateDB,
} from '@aztec/simulator';
import { NoopTelemetryClient } from '@aztec/telemetry-client/noop';
import { MerkleTrees } from '@aztec/world-state';
import { type MerkleTreeAdminDatabase } from '@aztec/world-state';
import { NativeWorldStateService } from '@aztec/world-state/native';

import { jest } from '@jest/globals';
import * as fs from 'fs/promises';
import { type MockProxy, mock } from 'jest-mock-extended';
import { mock } from 'jest-mock-extended';

import { TestCircuitProver } from '../../../bb-prover/src/test/test_circuit_prover.js';
import { AvmFinalizedCallResult } from '../../../simulator/src/avm/avm_contract_call_result.js';
import { type AvmPersistableStateManager } from '../../../simulator/src/avm/journal/journal.js';
import { buildBlock } from '../block_builder/light.js';
import { ProvingOrchestrator } from '../orchestrator/index.js';
import { MemoryProvingQueue } from '../prover-agent/memory-proving-queue.js';
import { ProverAgent } from '../prover-agent/prover-agent.js';
import { getEnvironmentConfig, getSimulationProvider, makeGlobals } from './fixtures.js';

export class TestContext {
private headers: Map<number, Header> = new Map();

constructor(
public publicTxSimulator: PublicTxSimulator,
public worldStateDB: MockProxy<WorldStateDB>,
public worldState: MerkleTreeAdminDatabase,
public publicProcessor: PublicProcessor,
public simulationProvider: SimulationProvider,
public globalVariables: GlobalVariables,
public actualDb: MerkleTreeWriteOperations,
public forksProvider: ForkMerkleTreeWriteOperations,
public prover: ServerCircuitProver,
public proverAgent: ProverAgent,
public orchestrator: ProvingOrchestrator,
Expand All @@ -59,11 +61,10 @@ export class TestContext {

static async new(
logger: DebugLogger,
worldState: 'native' | 'legacy' = 'native',
proverCount = 4,
createProver: (bbConfig: BBProverConfig) => Promise<ServerCircuitProver> = _ =>
Promise.resolve(new TestCircuitProver(new NoopTelemetryClient(), new WASMSimulator())),
blockNumber = 3,
blockNumber = 1,
) {
const directoriesToCleanup: string[] = [];
const globalVariables = makeGlobals(blockNumber);
Expand All @@ -72,21 +73,9 @@ export class TestContext {
const telemetry = new NoopTelemetryClient();

// Separated dbs for public processor and prover - see public_processor for context
let publicDb: MerkleTreeWriteOperations;
let proverDb: MerkleTreeWriteOperations;
let forksProvider: ForkMerkleTreeWriteOperations;

if (worldState === 'native') {
const ws = await NativeWorldStateService.tmp();
publicDb = await ws.fork();
proverDb = await ws.fork();
forksProvider = ws;
} else {
const ws = await MerkleTrees.new(openTmpStore(), telemetry);
publicDb = await ws.getLatest();
proverDb = await ws.getLatest();
forksProvider = ws;
}
const ws = await NativeWorldStateService.tmp();
const publicDb = await ws.fork();

worldStateDB.getMerkleInterface.mockReturnValue(publicDb);

const publicTxSimulator = new PublicTxSimulator(publicDb, worldStateDB, telemetry, globalVariables);
Expand Down Expand Up @@ -123,20 +112,18 @@ export class TestContext {
}

const queue = new MemoryProvingQueue(telemetry);
const orchestrator = new ProvingOrchestrator(forksProvider, queue, telemetry, Fr.ZERO);
const orchestrator = new ProvingOrchestrator(ws, queue, telemetry, Fr.ZERO);
const agent = new ProverAgent(localProver, proverCount);

queue.start();
agent.start(queue);

return new this(
publicTxSimulator,
worldStateDB,
ws,
processor,
simulationProvider,
globalVariables,
proverDb,
forksProvider,
localProver,
agent,
orchestrator,
Expand All @@ -146,13 +133,59 @@ export class TestContext {
);
}

public getFork() {
return this.worldState.fork();
}

public getHeader(blockNumber: 0): Header;
public getHeader(blockNumber: number): Header | undefined;
public getHeader(blockNumber = 0) {
return blockNumber === 0 ? this.worldState.getCommitted().getInitialHeader() : this.headers.get(blockNumber);
}

async cleanup() {
await this.proverAgent.stop();
for (const dir of this.directoriesToCleanup.filter(x => x !== '')) {
await fs.rm(dir, { recursive: true, force: true });
}
}

public makeProcessedTx(opts?: Parameters<typeof makeBloatedProcessedTx>[0]): ProcessedTx;
public makeProcessedTx(seed?: number): ProcessedTx;
public makeProcessedTx(seedOrOpts?: Parameters<typeof makeBloatedProcessedTx>[0] | number): ProcessedTx {
const opts = typeof seedOrOpts === 'number' ? { seed: seedOrOpts } : seedOrOpts;
const blockNum = (opts?.globalVariables ?? this.globalVariables).blockNumber.toNumber();
const header = this.getHeader(blockNum - 1);
return makeBloatedProcessedTx({
header,
vkTreeRoot: getVKTreeRoot(),
protocolContractTreeRoot,
globalVariables: this.globalVariables,
...opts,
});
}

/** Creates a block with the given number of txs and adds it to world-state */
public async makePendingBlock(
numTxs: number,
numMsgs: number = 0,
blockNumOrGlobals: GlobalVariables | number = this.globalVariables,
makeProcessedTxOpts: (index: number) => Partial<Parameters<typeof makeBloatedProcessedTx>[0]> = () => ({}),
) {
const globalVariables = typeof blockNumOrGlobals === 'number' ? makeGlobals(blockNumOrGlobals) : blockNumOrGlobals;
const blockNum = globalVariables.blockNumber.toNumber();
const db = await this.worldState.fork();
const msgs = times(numMsgs, i => new Fr(blockNum * 100 + i));
const txs = times(numTxs, i =>
this.makeProcessedTx({ seed: i + blockNum * 1000, globalVariables, ...makeProcessedTxOpts(i) }),
);

const block = await buildBlock(txs, globalVariables, msgs, db);
this.headers.set(blockNum, block.header);
await this.worldState.handleL2BlockAndMessages(block, msgs);
return { block, txs, msgs };
}

public async processPublicFunctions(
txs: Tx[],
maxTransactions: number,
Expand Down
54 changes: 30 additions & 24 deletions yarn-project/prover-client/src/orchestrator/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import {
makeEmptyRecursiveProof,
} from '@aztec/circuits.js';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { maxBy, padArrayEnd } from '@aztec/foundation/collection';
import { AbortError } from '@aztec/foundation/error';
import { createDebugLogger } from '@aztec/foundation/log';
import { promiseWithResolvers } from '@aztec/foundation/promise';
Expand Down Expand Up @@ -162,7 +162,7 @@ export class ProvingOrchestrator implements EpochProver {
}

logger.info(
`Starting block ${globalVariables.blockNumber} for slot ${globalVariables.slotNumber} with ${numTxs} transactions`,
`Starting block ${globalVariables.blockNumber.toNumber()} for slot ${globalVariables.slotNumber.toNumber()} with ${numTxs} transactions`,
);

// Fork world state at the end of the immediately previous block
Expand Down Expand Up @@ -235,34 +235,39 @@ export class ProvingOrchestrator implements EpochProver {
}))
public async addNewTx(tx: ProcessedTx): Promise<void> {
const blockNumber = tx.constants.globalVariables.blockNumber.toNumber();
try {
const provingState = this.provingState?.getBlockProvingStateByBlockNumber(blockNumber);
if (!provingState) {
throw new Error(`Block proving state for ${blockNumber} not found`);
}

const provingState = this.provingState?.getBlockProvingStateByBlockNumber(blockNumber);
if (!provingState) {
throw new Error(`Block proving state for ${blockNumber} not found`);
}

if (!provingState.isAcceptingTransactions()) {
throw new Error(`Rollup not accepting further transactions`);
}
if (!provingState.isAcceptingTransactions()) {
throw new Error(`Rollup not accepting further transactions`);
}

if (!provingState.verifyState()) {
throw new Error(`Invalid proving state when adding a tx`);
}
if (!provingState.verifyState()) {
throw new Error(`Invalid proving state when adding a tx`);
}

validateTx(tx);
validateTx(tx);

logger.info(`Received transaction: ${tx.hash}`);
logger.info(`Received transaction: ${tx.hash}`);

if (tx.isEmpty) {
logger.warn(`Ignoring empty transaction ${tx.hash} - it will not be added to this block`);
return;
}
if (tx.isEmpty) {
logger.warn(`Ignoring empty transaction ${tx.hash} - it will not be added to this block`);
return;
}

const [hints, treeSnapshots] = await this.prepareTransaction(tx, provingState);
this.enqueueFirstProofs(hints, treeSnapshots, tx, provingState);
const [hints, treeSnapshots] = await this.prepareTransaction(tx, provingState);
this.enqueueFirstProofs(hints, treeSnapshots, tx, provingState);

if (provingState.transactionsReceived === provingState.totalNumTxs) {
logger.verbose(`All transactions received for block ${provingState.globalVariables.blockNumber}.`);
if (provingState.transactionsReceived === provingState.totalNumTxs) {
logger.verbose(`All transactions received for block ${provingState.globalVariables.blockNumber}.`);
}
} catch (err: any) {
throw new Error(`Error adding transaction ${tx.hash.toString()} to block ${blockNumber}: ${err.message}`, {
cause: err,
});
}
}

Expand Down Expand Up @@ -348,7 +353,8 @@ export class ProvingOrchestrator implements EpochProver {
})
private padEpoch(): Promise<void> {
const provingState = this.provingState!;
const lastBlock = provingState.blocks.at(-1)?.block;
logger.warn(`DA BLOCKS: ${provingState.blocks.map(b => `${b.blockNumber}: ${!!b.block}`).join(', ')}`);
const lastBlock = maxBy(provingState.blocks, b => b.blockNumber)?.block;
if (!lastBlock) {
return Promise.reject(new Error(`Epoch needs at least one completed block in order to be padded`));
}
Expand Down
Loading

0 comments on commit cc7d996

Please sign in to comment.