Skip to content

Commit

Permalink
refactor: orchestrator talks directly to broker
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr committed Dec 13, 2024
1 parent 37e1999 commit 440a8ef
Show file tree
Hide file tree
Showing 22 changed files with 365 additions and 525 deletions.
6 changes: 0 additions & 6 deletions yarn-project/circuit-types/src/interfaces/prover-broker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,6 @@ export interface ProvingJobProducer {
*/
cancelProvingJob(id: ProvingJobId): Promise<void>;

/**
* Cleans up after a job has completed. Throws if the job is in-progress
* @param id - The ID of the job to cancel
*/
cleanUpProvingJobState(id: ProvingJobId): Promise<void>;

/**
* Returns the current status fof the proving job
* @param id - The ID of the job to get the status of
Expand Down
35 changes: 1 addition & 34 deletions yarn-project/circuit-types/src/interfaces/prover-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import { z } from 'zod';
import { type TxHash } from '../tx/tx_hash.js';
import { type EpochProver } from './epoch-prover.js';
import { type ProvingJobConsumer } from './prover-broker.js';
import { type ProvingJobStatus } from './proving-job.js';

export type ActualProverConfig = {
/** Whether to construct real proofs */
Expand All @@ -24,9 +23,6 @@ export type ProverConfig = ActualProverConfig & {
nodeUrl?: string;
/** Identifier of the prover */
proverId: Fr;
/** Where to store temporary data */
cacheDir?: string;

proverAgentCount: number;
};

Expand All @@ -35,7 +31,6 @@ export const ProverConfigSchema = z.object({
realProofs: z.boolean(),
proverId: schemas.Fr,
proverTestDelayMs: z.number(),
cacheDir: z.string().optional(),
proverAgentCount: z.number(),
}) satisfies ZodFor<ProverConfig>;

Expand All @@ -60,11 +55,6 @@ export const proverConfigMappings: ConfigMappingsType<ProverConfig> = {
description: 'Artificial delay to introduce to all operations to the test prover.',
...numberConfigHelper(0),
},
cacheDir: {
env: 'PROVER_CACHE_DIR',
description: 'Where to store cache data generated while proving',
defaultValue: '/tmp/aztec-prover',
},
proverAgentCount: {
env: 'PROVER_AGENT_COUNT',
description: 'The number of prover agents to start',
Expand All @@ -76,35 +66,12 @@ function parseProverId(str: string) {
return Fr.fromHexString(str.startsWith('0x') ? str : Buffer.from(str, 'utf8').toString('hex'));
}

/**
* A database where the proving orchestrator can store intermediate results
*/
export interface ProverCache {
/**
* Saves the status of a proving job
* @param jobId - The job ID
* @param status - The status of the proof
*/
setProvingJobStatus(jobId: string, status: ProvingJobStatus): Promise<void>;

/**
* Retrieves the status of a proving job (if known)
* @param jobId - The job ID
*/
getProvingJobStatus(jobId: string): Promise<ProvingJobStatus>;

/**
* Closes the cache
*/
close(): Promise<void>;
}

/**
* The interface to the prover client.
* Provides the ability to generate proofs and build rollups.
*/
export interface EpochProverManager {
createEpochProver(cache?: ProverCache): EpochProver;
createEpochProver(): EpochProver;

start(): Promise<void>;

Expand Down
1 change: 0 additions & 1 deletion yarn-project/foundation/src/config/env_var.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ export type EnvVar =
| 'PROVER_REAL_PROOFS'
| 'PROVER_REQUIRED_CONFIRMATIONS'
| 'PROVER_TEST_DELAY_MS'
| 'PROVER_CACHE_DIR'
| 'PXE_L2_STARTING_BLOCK'
| 'PXE_PROVER_ENABLED'
| 'QUOTE_PROVIDER_BASIS_POINT_FEE'
Expand Down
4 changes: 4 additions & 0 deletions yarn-project/foundation/src/string/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ export function pluralize(str: string, count: number | bigint, plural?: string):
export function count(count: number | bigint, str: string, plural?: string): string {
return `${count} ${pluralize(str, count, plural)}`;
}

export function truncate(str: string, length: number = 64): string {
return str.length > length ? str.slice(0, length) + '...' : str;
}
1 change: 0 additions & 1 deletion yarn-project/prover-client/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ export { EpochProverManager } from '@aztec/circuit-types';

export * from './prover-client/index.js';
export * from './config.js';
export * from './proving_broker/prover_cache/memory.js';
13 changes: 3 additions & 10 deletions yarn-project/prover-client/src/prover-client/prover-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
type EpochProver,
type EpochProverManager,
type ForkMerkleTreeOperations,
type ProverCache,
type ProvingJobBroker,
type ProvingJobConsumer,
type ProvingJobProducer,
Expand All @@ -16,22 +15,17 @@ import { createLogger } from '@aztec/foundation/log';
import { NativeACVMSimulator } from '@aztec/simulator';
import { type TelemetryClient } from '@aztec/telemetry-client';

import { join } from 'path';

import { type ProverClientConfig } from '../config.js';
import { ProvingOrchestrator } from '../orchestrator/orchestrator.js';
import { CachingBrokerFacade } from '../proving_broker/caching_broker_facade.js';
import { BrokerCircuitProverFacade } from '../proving_broker/broker_prover_facade.js';
import { InlineProofStore } from '../proving_broker/proof_store.js';
import { InMemoryProverCache } from '../proving_broker/prover_cache/memory.js';
import { ProvingAgent } from '../proving_broker/proving_agent.js';

/** Manages proving of epochs by orchestrating the proving of individual blocks relying on a pool of prover agents. */
export class ProverClient implements EpochProverManager {
private running = false;
private agents: ProvingAgent[] = [];

private cacheDir?: string;

private constructor(
private config: ProverClientConfig,
private worldState: ForkMerkleTreeOperations,
Expand All @@ -42,13 +36,12 @@ export class ProverClient implements EpochProverManager {
) {
// TODO(palla/prover-node): Cache the paddingTx here, and not in each proving orchestrator,
// so it can be reused across multiple ones and not recomputed every time.
this.cacheDir = this.config.cacheDir ? join(this.config.cacheDir, `tx_prover_${this.config.proverId}`) : undefined;
}

public createEpochProver(cache: ProverCache = new InMemoryProverCache()): EpochProver {
public createEpochProver(): EpochProver {
return new ProvingOrchestrator(
this.worldState,
new CachingBrokerFacade(this.orchestratorClient, cache),
new BrokerCircuitProverFacade(this.orchestratorClient),
this.telemetry,
this.config.proverId,
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import { TestCircuitProver } from '@aztec/bb-prover';
import { ProvingJobProducer, makePublicInputsAndRecursiveProof } from '@aztec/circuit-types';
import { RECURSIVE_PROOF_LENGTH, VerificationKeyData, makeRecursiveProof } from '@aztec/circuits.js';
import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing';
import { AbortError } from '@aztec/foundation/error';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';
import { MockProxy } from 'jest-mock-extended';
import { mock } from 'jest-mock-extended';

import { MockProver, TestBroker } from '../test/mock_prover.js';
import { BrokerCircuitProverFacade } from './broker_prover_facade.js';
import { InlineProofStore } from './proof_store.js';

describe('BrokerCircuitProverFacade', () => {
let facade: BrokerCircuitProverFacade;
let proofStore: InlineProofStore;
let broker: TestBroker;
let prover: MockProver;
let agentPollInterval: number;

beforeEach(async () => {
proofStore = new InlineProofStore();
prover = new MockProver();
agentPollInterval = 100;
broker = new TestBroker(2, prover, proofStore, agentPollInterval);
facade = new BrokerCircuitProverFacade(broker, proofStore);

await broker.start();
});

it('sends jobs to the broker', async () => {
const inputs = makeBaseParityInputs();
const controller = new AbortController();

jest.spyOn(broker, 'enqueueProvingJob');
jest.spyOn(prover, 'getBaseParityProof');

await expect(facade.getBaseParityProof(inputs, controller.signal, 42)).resolves.toBeDefined();

expect(broker.enqueueProvingJob).toHaveBeenCalled();
expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs, expect.anything(), 42);
});

it('handles multiple calls for the same job', async () => {
const inputs = makeBaseParityInputs();
const controller = new AbortController();
const promises: Promise<any>[] = [];

const resultPromise = promiseWithResolvers<any>();
jest.spyOn(broker, 'enqueueProvingJob');
jest.spyOn(prover, 'getBaseParityProof').mockReturnValue(resultPromise.promise);

// send N identical proof requests
const CALLS = 50;
for (let i = 0; i < CALLS; i++) {
promises.push(facade.getBaseParityProof(inputs, controller.signal, 42));
}

await sleep(agentPollInterval);
// the broker should have received all of them
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(CALLS);

// but really, it should have only enqueued just one
expect(prover.getBaseParityProof).toHaveBeenCalledTimes(1);
expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs, expect.anything(), 42);

// now we have 50 promises all waiting on the same result
// resolve the proof
const result = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
resultPromise.resolve(result);

// enqueue another N requests for the same jobs
for (let i = 0; i < CALLS; i++) {
promises.push(facade.getBaseParityProof(inputs, controller.signal, 42));
}

await sleep(agentPollInterval);
// the broker will have received the new requests
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(2 * CALLS);
// but no new jobs where created
expect(prover.getBaseParityProof).toHaveBeenCalledTimes(1);

// and all 2 * N requests will have been resolved with the same result
for (const promise of promises) {
await expect(promise).resolves.toEqual(result);
}
});

it('handles proof errors', async () => {
const inputs = makeBaseParityInputs();
const controller = new AbortController();
const promises: Promise<any>[] = [];

const resultPromise = promiseWithResolvers<any>();
jest.spyOn(broker, 'enqueueProvingJob');
jest.spyOn(prover, 'getBaseParityProof').mockReturnValue(resultPromise.promise);

// send N identical proof requests
const CALLS = 50;
for (let i = 0; i < CALLS; i++) {
// wrap the error in a resolved promises so that we don't have unhandled rejections
promises.push(facade.getBaseParityProof(inputs, controller.signal, 42).catch(err => ({ err })));
}

await sleep(agentPollInterval);
// the broker should have received all of them
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(CALLS);

// but really, it should have only enqueued just one
expect(prover.getBaseParityProof).toHaveBeenCalledTimes(1);
expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs, expect.anything(), 42);

resultPromise.reject(new Error('TEST ERROR'));

// enqueue another N requests for the same jobs
for (let i = 0; i < CALLS; i++) {
promises.push(facade.getBaseParityProof(inputs, controller.signal, 42).catch(err => ({ err })));
}

await sleep(agentPollInterval);
// the broker will have received the new requests
expect(broker.enqueueProvingJob).toHaveBeenCalledTimes(2 * CALLS);
// but no new jobs where created
expect(prover.getBaseParityProof).toHaveBeenCalledTimes(1);

// and all 2 * N requests will have been resolved with the same result
for (const promise of promises) {
await expect(promise).resolves.toEqual({ err: new Error('TEST ERROR') });
}
});

it('handles aborts', async () => {
const inputs = makeBaseParityInputs();
const controller = new AbortController();

const resultPromise = promiseWithResolvers<any>();
jest.spyOn(broker, 'enqueueProvingJob');
jest.spyOn(prover, 'getBaseParityProof').mockReturnValue(resultPromise.promise);

const promise = facade.getBaseParityProof(inputs, controller.signal, 42).catch(err => ({ err }));

await sleep(agentPollInterval);
expect(prover.getBaseParityProof).toHaveBeenCalled();

controller.abort();

await expect(promise).resolves.toEqual({ err: new Error('Aborted') });
});
});
Loading

0 comments on commit 440a8ef

Please sign in to comment.