Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: abort ongoing proving jobs #6049

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export type ProvingRequestPublicInputs = {
export type ProvingRequestResult<T extends ProvingRequestType> = ProvingRequestPublicInputs[T];

export interface ProvingJobSource {
getProvingJob(): Promise<ProvingJob<ProvingRequest> | null>;
getProvingJob(): Promise<ProvingJob<ProvingRequest> | undefined>;

resolveProvingJob<T extends ProvingRequestType>(jobId: string, result: ProvingRequestResult<T>): Promise<void>;

Expand Down
5 changes: 5 additions & 0 deletions yarn-project/foundation/src/error/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ export class InterruptError extends Error {}
* An error thrown when an action times out.
*/
export class TimeoutError extends Error {}

/**
* Represents an error thrown when an operation is aborted.
*/
export class AbortedError extends Error {}
4 changes: 2 additions & 2 deletions yarn-project/prover-client/src/dummy-prover.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ export class DummyProver implements ProverClient {
}

class DummyProvingJobSource implements ProvingJobSource {
getProvingJob(): Promise<ProvingJob<ProvingRequest> | null> {
return Promise.resolve(null);
getProvingJob(): Promise<ProvingJob<ProvingRequest> | undefined> {
return Promise.resolve(undefined);
}

rejectProvingJob(): Promise<void> {
Expand Down
53 changes: 43 additions & 10 deletions yarn-project/prover-client/src/orchestrator/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
} from '@aztec/circuits.js';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { AbortedError } from '@aztec/foundation/error';
import { createDebugLogger } from '@aztec/foundation/log';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { type Tuple } from '@aztec/foundation/serialize';
Expand Down Expand Up @@ -82,6 +83,8 @@ const KernelTypesWithoutFunctions: Set<PublicKernelType> = new Set<PublicKernelT
*/
export class ProvingOrchestrator {
private provingState: ProvingState | undefined = undefined;
private pendingProvingJobs: AbortController[] = [];

constructor(private db: MerkleTreeOperations, private prover: CircuitProver) {}

/**
Expand Down Expand Up @@ -211,6 +214,10 @@ export class ProvingOrchestrator {
* Cancel any further proving of the block
*/
public cancelBlock() {
for (const controller of this.pendingProvingJobs) {
controller.abort();
}

this.provingState?.cancel();
}

Expand Down Expand Up @@ -303,30 +310,56 @@ export class ProvingOrchestrator {
*/
private deferredProving<T>(
provingState: ProvingState | undefined,
request: () => Promise<T>,
request: (signal: AbortSignal) => Promise<T>,
callback: (result: T, durationMs: number) => void | Promise<void>,
) {
if (!provingState?.verifyState()) {
logger.debug(`Not enqueuing job, state no longer valid`);
return;
}

const controller = new AbortController();
this.pendingProvingJobs.push(controller);

// We use a 'safeJob'. We don't want promise rejections in the proving pool, we want to capture the error here
// and reject the proving job whilst keeping the event loop free of rejections
const safeJob = async () => {
try {
// there's a delay between enqueueing this job and it actually running
if (controller.signal.aborted) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps some debug logs when we drop jobs due to an abort signal.

return;
}

const timer = new Timer();
const result = await request();
const result = await request(controller.signal);
const duration = timer.ms();

if (!provingState?.verifyState()) {
logger.debug(`State no longer valid, discarding result`);
return;
}

// we could have been cancelled whilst waiting for the result
// and the prover ignored the signal. Drop the result in that case
if (controller.signal.aborted) {
return;
}

await callback(result, duration);
} catch (err) {
if (err instanceof AbortedError) {
// operation was cancelled, probably because the block was cancelled
// drop this result
return;
}

logger.error(`Error thrown when proving job`);
provingState!.reject(`${err}`);
} finally {
const index = this.pendingProvingJobs.indexOf(controller);
if (index > -1) {
this.pendingProvingJobs.splice(index, 1);
}
}
};

Expand Down Expand Up @@ -441,7 +474,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getBaseRollupProof(tx.baseRollupInputs),
signal => this.prover.getBaseRollupProof(tx.baseRollupInputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'base-rollup',
Expand Down Expand Up @@ -472,7 +505,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getMergeRollupProof(inputs),
signal => this.prover.getMergeRollupProof(inputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'merge-rollup',
Expand Down Expand Up @@ -508,7 +541,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getRootRollupProof(inputs),
signal => this.prover.getRootRollupProof(inputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'root-rollup',
Expand All @@ -533,7 +566,7 @@ export class ProvingOrchestrator {
private enqueueBaseParityCircuit(provingState: ProvingState, inputs: BaseParityInputs, index: number) {
this.deferredProving(
provingState,
() => this.prover.getBaseParityProof(inputs),
signal => this.prover.getBaseParityProof(inputs, signal),
(rootInput, duration) => {
this.emitCircuitSimulationStats(
'base-parity',
Expand All @@ -560,7 +593,7 @@ export class ProvingOrchestrator {
private enqueueRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) {
this.deferredProving(
provingState,
() => this.prover.getRootParityProof(inputs),
signal => this.prover.getRootParityProof(inputs, signal),
async (rootInput, duration) => {
this.emitCircuitSimulationStats(
'root-parity',
Expand Down Expand Up @@ -674,11 +707,11 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
(): Promise<PublicInputsAndProof<KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs>> => {
(signal): Promise<PublicInputsAndProof<KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs>> => {
if (request.type === PublicKernelType.TAIL) {
return this.prover.getPublicTailProof(request);
return this.prover.getPublicTailProof(request, signal);
} else {
return this.prover.getPublicKernelProof(request);
return this.prover.getPublicKernelProof(request, signal);
}
},
(result, duration) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import { PROVING_STATUS, type ProvingFailure } from '@aztec/circuit-types';
import { type GlobalVariables, NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP } from '@aztec/circuits.js';
import { fr } from '@aztec/circuits.js/testing';
import {
type GlobalVariables,
NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP,
NUM_BASE_PARITY_PER_ROOT_PARITY,
} from '@aztec/circuits.js';
import { fr, makeGlobalVariables } from '@aztec/circuits.js/testing';
import { range } from '@aztec/foundation/array';
import { createDebugLogger } from '@aztec/foundation/log';
import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';

import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js';
import { TestContext } from '../mocks/test_context.js';
import { type CircuitProver } from '../prover/interface.js';
import { TestCircuitProver } from '../prover/test_circuit_prover.js';
import { ProvingOrchestrator } from './orchestrator.js';

const logger = createDebugLogger('aztec:orchestrator-lifecycle');

Expand Down Expand Up @@ -123,6 +134,28 @@ describe('prover/orchestrator/lifecycle', () => {
const finalisedBlock = await context.orchestrator.finaliseBlock();

expect(finalisedBlock.block.number).toEqual(101);
}, 60000);

it('cancels proving requests', async () => {
const prover: CircuitProver = new TestCircuitProver();
const orchestrator = new ProvingOrchestrator(context.actualDb, prover);

const spy = jest.spyOn(prover, 'getBaseParityProof');
const deferredPromises: PromiseWithResolvers<any>[] = [];
spy.mockImplementation(() => {
const deferred = promiseWithResolvers<any>();
deferredPromises.push(deferred);
return deferred.promise;
});
await orchestrator.startNewBlock(2, makeGlobalVariables(1), [], await makeEmptyProcessedTestTx(context.actualDb));

await sleep(1);

expect(spy).toHaveBeenCalledTimes(NUM_BASE_PARITY_PER_ROOT_PARITY);
expect(spy.mock.calls.every(([_, signal]) => !signal?.aborted)).toBeTruthy();

orchestrator.cancelBlock();
expect(spy.mock.calls.every(([_, signal]) => signal?.aborted)).toBeTruthy();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ describe('MemoryProvingQueue', () => {
expect(job2?.request.type).toEqual(ProvingRequestType.BASE_ROLLUP);
});

it('returns null when no jobs are available', async () => {
await expect(queue.getProvingJob({ timeoutSec: 0 })).resolves.toBeNull();
it('returns undefined when no jobs are available', async () => {
await expect(queue.getProvingJob({ timeoutSec: 0 })).resolves.toBeUndefined();
});

it('notifies of completion', async () => {
Expand Down
Loading
Loading