diff --git a/yarn-project/foundation/src/error/index.ts b/yarn-project/foundation/src/error/index.ts index 2bc84be567fc..1986e2dec0b9 100644 --- a/yarn-project/foundation/src/error/index.ts +++ b/yarn-project/foundation/src/error/index.ts @@ -9,3 +9,8 @@ export class InterruptError extends Error {} * An error thrown when an action times out. */ export class TimeoutError extends Error {} + +/** + * Represents an error thrown when an operation is aborted. + */ +export class AbortedError extends Error {} diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator.ts b/yarn-project/prover-client/src/orchestrator/orchestrator.ts index d7fada2ebb52..e79ec164e445 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator.ts @@ -36,6 +36,7 @@ import { } from '@aztec/circuits.js'; import { makeTuple } from '@aztec/foundation/array'; import { padArrayEnd } from '@aztec/foundation/collection'; +import { AbortedError } from '@aztec/foundation/error'; import { createDebugLogger } from '@aztec/foundation/log'; import { promiseWithResolvers } from '@aztec/foundation/promise'; import { type Tuple } from '@aztec/foundation/serialize'; @@ -82,6 +83,8 @@ const KernelTypesWithoutFunctions: Set = new Set( provingState: ProvingState | undefined, - request: () => Promise, + request: (signal: AbortSignal) => Promise, callback: (result: T, durationMs: number) => void | Promise, ) { if (!provingState?.verifyState()) { logger.debug(`Not enqueuing job, state no longer valid`); return; } + + const controller = new AbortController(); + this.pendingProvingJobs.push(controller); + // We use a 'safeJob'. We don't want promise rejections in the proving pool, we want to capture the error here // and reject the proving job whilst keeping the event loop free of rejections const safeJob = async () => { try { + // there's a delay between enqueueing this job and it actually running + if (controller.signal.aborted) { + return; + } + const timer = new Timer(); - const result = await request(); + const result = await request(controller.signal); const duration = timer.ms(); if (!provingState?.verifyState()) { @@ -323,10 +339,27 @@ export class ProvingOrchestrator { return; } + // we could have been cancelled whilst waiting for the result + // and the prover ignored the signal. Drop the result in that case + if (controller.signal.aborted) { + return; + } + await callback(result, duration); } catch (err) { + if (err instanceof AbortedError) { + // operation was cancelled, probably because the block was cancelled + // drop this result + return; + } + logger.error(`Error thrown when proving job`); provingState!.reject(`${err}`); + } finally { + const index = this.pendingProvingJobs.indexOf(controller); + if (index > -1) { + this.pendingProvingJobs.splice(index, 1); + } } }; @@ -441,7 +474,7 @@ export class ProvingOrchestrator { this.deferredProving( provingState, - () => this.prover.getBaseRollupProof(tx.baseRollupInputs), + signal => this.prover.getBaseRollupProof(tx.baseRollupInputs, signal), (result, duration) => { this.emitCircuitSimulationStats( 'base-rollup', @@ -472,7 +505,7 @@ export class ProvingOrchestrator { this.deferredProving( provingState, - () => this.prover.getMergeRollupProof(inputs), + signal => this.prover.getMergeRollupProof(inputs, signal), (result, duration) => { this.emitCircuitSimulationStats( 'merge-rollup', @@ -508,7 +541,7 @@ export class ProvingOrchestrator { this.deferredProving( provingState, - () => this.prover.getRootRollupProof(inputs), + signal => this.prover.getRootRollupProof(inputs, signal), (result, duration) => { this.emitCircuitSimulationStats( 'root-rollup', @@ -533,7 +566,7 @@ export class ProvingOrchestrator { private enqueueBaseParityCircuit(provingState: ProvingState, inputs: BaseParityInputs, index: number) { this.deferredProving( provingState, - () => this.prover.getBaseParityProof(inputs), + signal => this.prover.getBaseParityProof(inputs, signal), (rootInput, duration) => { this.emitCircuitSimulationStats( 'base-parity', @@ -560,7 +593,7 @@ export class ProvingOrchestrator { private enqueueRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) { this.deferredProving( provingState, - () => this.prover.getRootParityProof(inputs), + signal => this.prover.getRootParityProof(inputs, signal), async (rootInput, duration) => { this.emitCircuitSimulationStats( 'root-parity', @@ -674,11 +707,11 @@ export class ProvingOrchestrator { this.deferredProving( provingState, - (): Promise> => { + (signal): Promise> => { if (request.type === PublicKernelType.TAIL) { - return this.prover.getPublicTailProof(request); + return this.prover.getPublicTailProof(request, signal); } else { - return this.prover.getPublicKernelProof(request); + return this.prover.getPublicKernelProof(request, signal); } }, (result, duration) => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts index 898e3aab9bb5..3b172f6cf316 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts @@ -1,11 +1,22 @@ import { PROVING_STATUS, type ProvingFailure } from '@aztec/circuit-types'; -import { type GlobalVariables, NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP } from '@aztec/circuits.js'; -import { fr } from '@aztec/circuits.js/testing'; +import { + type GlobalVariables, + NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP, + NUM_BASE_PARITY_PER_ROOT_PARITY, +} from '@aztec/circuits.js'; +import { fr, makeGlobalVariables } from '@aztec/circuits.js/testing'; import { range } from '@aztec/foundation/array'; import { createDebugLogger } from '@aztec/foundation/log'; +import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise'; +import { sleep } from '@aztec/foundation/sleep'; + +import { jest } from '@jest/globals'; import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; +import { type CircuitProver } from '../prover/interface.js'; +import { TestCircuitProver } from '../prover/test_circuit_prover.js'; +import { ProvingOrchestrator } from './orchestrator.js'; const logger = createDebugLogger('aztec:orchestrator-lifecycle'); @@ -124,5 +135,27 @@ describe('prover/orchestrator/lifecycle', () => { expect(finalisedBlock.block.number).toEqual(101); }, 60000); + + it('cancels proving requests', async () => { + const prover: CircuitProver = new TestCircuitProver(); + const orchestrator = new ProvingOrchestrator(context.actualDb, prover); + + const spy = jest.spyOn(prover, 'getBaseParityProof'); + const deferredPromises: PromiseWithResolvers[] = []; + spy.mockImplementation(() => { + const deferred = promiseWithResolvers(); + deferredPromises.push(deferred); + return deferred.promise; + }); + await orchestrator.startNewBlock(2, makeGlobalVariables(1), [], await makeEmptyProcessedTestTx(context.actualDb)); + + await sleep(1); + + expect(spy).toHaveBeenCalledTimes(NUM_BASE_PARITY_PER_ROOT_PARITY); + expect(spy.mock.calls.every(([_, signal]) => !signal?.aborted)).toBeTruthy(); + + orchestrator.cancelBlock(); + expect(spy.mock.calls.every(([_, signal]) => signal?.aborted)).toBeTruthy(); + }); }); }); diff --git a/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts index 773af0cfd64d..68b5598bd7b9 100644 --- a/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts +++ b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts @@ -22,7 +22,7 @@ import type { RootRollupInputs, RootRollupPublicInputs, } from '@aztec/circuits.js'; -import { TimeoutError } from '@aztec/foundation/error'; +import { AbortedError, TimeoutError } from '@aztec/foundation/error'; import { MemoryFifo } from '@aztec/foundation/fifo'; import { createDebugLogger } from '@aztec/foundation/log'; import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise'; @@ -32,6 +32,7 @@ import { type CircuitProver } from '../prover/interface.js'; type ProvingJobWithResolvers = { id: string; request: T; + signal?: AbortSignal; } & PromiseWithResolvers>; export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { @@ -43,7 +44,7 @@ export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { async getProvingJob({ timeoutSec = 1 } = {}): Promise | null> { try { const job = await this.queue.get(timeoutSec); - if (!job) { + if (!job || job.signal?.aborted) { return null; } @@ -68,6 +69,11 @@ export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { } this.jobsInProgress.delete(jobId); + + if (job.signal?.aborted) { + return Promise.resolve(); + } + job.resolve(result); return Promise.resolve(); } @@ -79,20 +85,33 @@ export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { } this.jobsInProgress.delete(jobId); + + if (job.signal?.aborted) { + return Promise.resolve(); + } + job.reject(err); return Promise.resolve(); } - private enqueue(request: T): Promise> { + private enqueue( + request: T, + signal?: AbortSignal, + ): Promise> { const { promise, resolve, reject } = promiseWithResolvers>(); const item: ProvingJobWithResolvers = { id: String(this.jobId++), request, + signal, promise, resolve, reject, }; + if (signal) { + signal.addEventListener('abort', () => reject(new AbortedError('Operation has been aborted'))); + } + this.log.info(`Adding ${ProvingRequestType[request.type]} proving job to queue`); // TODO (alexg) remove the `any` if (!this.queue.put(item as any)) { @@ -106,55 +125,85 @@ export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { * Creates a proof for the given input. * @param input - Input to the circuit. */ - getBaseParityProof(inputs: BaseParityInputs): Promise> { - return this.enqueue({ - type: ProvingRequestType.BASE_PARITY, - inputs, - }); + getBaseParityProof( + inputs: BaseParityInputs, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.BASE_PARITY, + inputs, + }, + signal, + ); } /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getRootParityProof(inputs: RootParityInputs): Promise> { - return this.enqueue({ - type: ProvingRequestType.ROOT_PARITY, - inputs, - }); + getRootParityProof( + inputs: RootParityInputs, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.ROOT_PARITY, + inputs, + }, + signal, + ); } /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getBaseRollupProof(input: BaseRollupInputs): Promise> { - return this.enqueue({ - type: ProvingRequestType.BASE_ROLLUP, - inputs: input, - }); + getBaseRollupProof( + input: BaseRollupInputs, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.BASE_ROLLUP, + inputs: input, + }, + signal, + ); } /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getMergeRollupProof(input: MergeRollupInputs): Promise> { - return this.enqueue({ - type: ProvingRequestType.MERGE_ROLLUP, - inputs: input, - }); + getMergeRollupProof( + input: MergeRollupInputs, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.MERGE_ROLLUP, + inputs: input, + }, + signal, + ); } /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getRootRollupProof(input: RootRollupInputs): Promise> { - return this.enqueue({ - type: ProvingRequestType.ROOT_ROLLUP, - inputs: input, - }); + getRootRollupProof( + input: RootRollupInputs, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.ROOT_ROLLUP, + inputs: input, + }, + signal, + ); } /** @@ -163,31 +212,40 @@ export class MemoryProvingQueue implements CircuitProver, ProvingJobSource { */ getPublicKernelProof( kernelRequest: PublicKernelNonTailRequest, + signal?: AbortSignal, ): Promise> { - return this.enqueue({ - type: ProvingRequestType.PUBLIC_KERNEL_NON_TAIL, - kernelType: kernelRequest.type, - inputs: kernelRequest.inputs, - }); + return this.enqueue( + { + type: ProvingRequestType.PUBLIC_KERNEL_NON_TAIL, + kernelType: kernelRequest.type, + inputs: kernelRequest.inputs, + }, + signal, + ); } /** * Create a public kernel tail proof. * @param kernelRequest - Object containing the details of the proof required */ - getPublicTailProof(kernelRequest: PublicKernelTailRequest): Promise> { - return this.enqueue({ - type: ProvingRequestType.PUBLIC_KERNEL_TAIL, - kernelType: kernelRequest.type, - inputs: kernelRequest.inputs, - }); + getPublicTailProof( + kernelRequest: PublicKernelTailRequest, + signal?: AbortSignal, + ): Promise> { + return this.enqueue( + { + type: ProvingRequestType.PUBLIC_KERNEL_TAIL, + kernelType: kernelRequest.type, + inputs: kernelRequest.inputs, + }, + signal, + ); } /** * Verifies a circuit proof */ verifyProof(): Promise { - // no-op - return Promise.resolve(); + return Promise.reject('not implemented'); } } diff --git a/yarn-project/prover-client/src/prover/interface.ts b/yarn-project/prover-client/src/prover/interface.ts index aa24cf6eda84..bed9add0d50f 100644 --- a/yarn-project/prover-client/src/prover/interface.ts +++ b/yarn-project/prover-client/src/prover/interface.ts @@ -69,31 +69,46 @@ export interface CircuitProver { * Creates a proof for the given input. * @param input - Input to the circuit. */ - getBaseParityProof(inputs: BaseParityInputs): Promise>; + getBaseParityProof( + inputs: BaseParityInputs, + signal?: AbortSignal, + ): Promise>; /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getRootParityProof(inputs: RootParityInputs): Promise>; + getRootParityProof( + inputs: RootParityInputs, + signal?: AbortSignal, + ): Promise>; /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getBaseRollupProof(input: BaseRollupInputs): Promise>; + getBaseRollupProof( + input: BaseRollupInputs, + signal?: AbortSignal, + ): Promise>; /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getMergeRollupProof(input: MergeRollupInputs): Promise>; + getMergeRollupProof( + input: MergeRollupInputs, + signal?: AbortSignal, + ): Promise>; /** * Creates a proof for the given input. * @param input - Input to the circuit. */ - getRootRollupProof(input: RootRollupInputs): Promise>; + getRootRollupProof( + input: RootRollupInputs, + signal?: AbortSignal, + ): Promise>; /** * Create a public kernel proof. @@ -101,18 +116,22 @@ export interface CircuitProver { */ getPublicKernelProof( kernelRequest: PublicKernelNonTailRequest, + signal?: AbortSignal, ): Promise>; /** * Create a public kernel tail proof. * @param kernelRequest - Object containing the details of the proof required */ - getPublicTailProof(kernelRequest: PublicKernelTailRequest): Promise>; + getPublicTailProof( + kernelRequest: PublicKernelTailRequest, + signal?: AbortSignal, + ): Promise>; /** * Verifies a circuit proof */ - verifyProof(artifact: ServerProtocolArtifact, proof: Proof): Promise; + verifyProof(artifact: ServerProtocolArtifact, proof: Proof, signal?: AbortSignal): Promise; } /**