Skip to content

Commit

Permalink
feat: track timeout status of proving jobs (#6868)
Browse files Browse the repository at this point in the history
This PR modifies the proving queue to expect regular heartbeats from its
consumers. If a job becomes stale because the agent processing it goes
away, the job is re-queued so that another agent can pick it up.
  • Loading branch information
alexghr authored Jun 4, 2024
1 parent dfea1c7 commit 7306176
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 39 deletions.
4 changes: 4 additions & 0 deletions yarn-project/circuit-types/src/interfaces/prover-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export type ProverConfig = {
proverAgentPollInterval: number;
/** The maximum number of proving jobs to be run in parallel */
proverAgentConcurrency: number;
/** Jobs are retried if not kept alive for this long */
proverJobTimeoutMs: number;
/** The interval to check job health status */
proverJobPollIntervalMs: number;
};

/**
Expand Down
21 changes: 21 additions & 0 deletions yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,30 @@ export type ProvingRequestPublicInputs = {
export type ProvingRequestResult<T extends ProvingRequestType> = ProvingRequestPublicInputs[T];

/**
* A source of proving jobs for consumers. A consumer takes a job with `getProvingJob`,
* keeps it alive with `heartbeat` while it works on it, and finishes it with either
* `resolveProvingJob` or `rejectProvingJob`.
*/
export interface ProvingJobSource {
/**
* Gets the next proving job. `heartbeat` must be called periodically to keep the job alive.
* @returns The proving job, or undefined if there are no jobs available.
*/
getProvingJob(): Promise<ProvingJob<ProvingRequest> | undefined>;

/**
* Keeps the job alive. If this isn't called regularly then the job will be
* considered abandoned and re-queued for another consumer to pick up.
* @param jobId - The ID of the job to heartbeat.
*/
heartbeat(jobId: string): Promise<void>;

/**
* Resolves a proving job.
* @param jobId - The ID of the job to resolve.
* @param result - The result of the proving job.
*/
resolveProvingJob<T extends ProvingRequestType>(jobId: string, result: ProvingRequestResult<T>): Promise<void>;

/**
* Rejects a proving job.
* @param jobId - The ID of the job to reject.
* @param reason - The reason for rejecting the job.
*/
rejectProvingJob(jobId: string, reason: Error): Promise<void>;
}
2 changes: 1 addition & 1 deletion yarn-project/foundation/src/error/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ export class TimeoutError extends Error {}
/**
* Represents an error thrown when an operation is aborted.
*/
export class AbortedError extends Error {}
export class AbortError extends Error {}
2 changes: 1 addition & 1 deletion yarn-project/foundation/src/promise/running-promise.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export class RunningPromise {
private runningPromise = Promise.resolve();
private interruptibleSleep = new InterruptibleSleep();

constructor(private fn: () => Promise<void>, private pollingIntervalMS = 10000) {}
constructor(private fn: () => void | Promise<void>, private pollingIntervalMS = 10000) {}

/**
* Starts the running promise.
Expand Down
19 changes: 13 additions & 6 deletions yarn-project/prover-client/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ export function getProverEnvVars(): ProverClientConfig {
PROVER_AGENT_CONCURRENCY = PROVER_AGENTS,
PROVER_AGENT_POLL_INTERVAL_MS = '100',
PROVER_REAL_PROOFS = '',
PROVER_JOB_TIMEOUT_MS = '60000',
PROVER_JOB_POLL_INTERVAL_MS = '1000',
} = process.env;

const realProofs = ['1', 'true'].includes(PROVER_REAL_PROOFS);
const proverAgentEnabled = ['1', 'true'].includes(PROVER_AGENT_ENABLED);
const parsedProverConcurrency = parseInt(PROVER_AGENT_CONCURRENCY, 10);
const proverAgentConcurrency = Number.isSafeInteger(parsedProverConcurrency) ? parsedProverConcurrency : 1;
const parsedProverAgentPollInterval = parseInt(PROVER_AGENT_POLL_INTERVAL_MS, 10);
const proverAgentPollInterval = Number.isSafeInteger(parsedProverAgentPollInterval)
? parsedProverAgentPollInterval
: 100;
const proverAgentConcurrency = safeParseNumber(PROVER_AGENT_CONCURRENCY, 1);
const proverAgentPollInterval = safeParseNumber(PROVER_AGENT_POLL_INTERVAL_MS, 100);
const proverJobTimeoutMs = safeParseNumber(PROVER_JOB_TIMEOUT_MS, 60000);
const proverJobPollIntervalMs = safeParseNumber(PROVER_JOB_POLL_INTERVAL_MS, 1000);

return {
acvmWorkingDirectory: ACVM_WORKING_DIRECTORY,
Expand All @@ -55,5 +55,12 @@ export function getProverEnvVars(): ProverClientConfig {
proverAgentPollInterval,
proverAgentConcurrency,
nodeUrl: AZTEC_NODE_URL,
proverJobPollIntervalMs,
proverJobTimeoutMs,
};
}

/**
 * Parses a base-10 integer out of a string, falling back to a default.
 * @param value - The raw string to parse.
 * @param defaultValue - Returned when `value` does not parse to a safe integer.
 * @returns The parsed integer, or `defaultValue` if parsing did not yield a safe integer.
 */
function safeParseNumber(value: string, defaultValue: number): number {
  const candidate = Number.parseInt(value, 10);
  if (Number.isSafeInteger(candidate)) {
    return candidate;
  }
  return defaultValue;
}
1 change: 1 addition & 0 deletions yarn-project/prover-client/src/mocks/test_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export class TestContext {
const orchestrator = new ProvingOrchestrator(actualDb, queue);
const agent = new ProverAgent(localProver, proverCount);

queue.start();
agent.start(queue);

return new this(
Expand Down
4 changes: 2 additions & 2 deletions yarn-project/prover-client/src/orchestrator/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import {
} from '@aztec/circuits.js';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { AbortedError } from '@aztec/foundation/error';
import { AbortError } from '@aztec/foundation/error';
import { createDebugLogger } from '@aztec/foundation/log';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { BufferReader, type Tuple } from '@aztec/foundation/serialize';
Expand Down Expand Up @@ -475,7 +475,7 @@ export class ProvingOrchestrator {

await callback(result);
} catch (err) {
if (err instanceof AbortedError) {
if (err instanceof AbortError) {
// operation was cancelled, probably because the block was cancelled
// drop this result
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { type ServerCircuitProver } from '@aztec/circuit-types';
import { RECURSIVE_PROOF_LENGTH, type RootParityInput } from '@aztec/circuits.js';
import { makeBaseParityInputs, makeRootParityInput } from '@aztec/circuits.js/testing';
import { AbortError } from '@aztec/foundation/error';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { type MockProxy, mock } from 'jest-mock-extended';

import { MemoryProvingQueue } from './memory-proving-queue.js';
import { ProverAgent } from './prover-agent.js';

// Integration tests that wire a MemoryProvingQueue to a ProverAgent backed by a mocked
// ServerCircuitProver. The poll intervals and job timeout are kept small so that the
// keep-alive and timeout paths can be exercised with short sleeps.
describe('Prover agent <-> queue integration', () => {
let queue: MemoryProvingQueue;
let agent: ProverAgent;
let prover: MockProxy<ServerCircuitProver>;
// Timing values assigned in beforeEach and reused by the sleeps in the tests below
// (presumably milliseconds, since they are passed straight to sleep() — confirm).
let agentPollInterval: number;
let queuePollInterval: number;
let queueJobTimeout: number;

beforeEach(() => {
prover = mock<ServerCircuitProver>();

queueJobTimeout = 100;
queuePollInterval = 10;
queue = new MemoryProvingQueue(queueJobTimeout, queuePollInterval);

agentPollInterval = 10;
// NOTE(review): the second argument looks like the agent's concurrency limit — confirm
// against the ProverAgent constructor.
agent = new ProverAgent(prover, 1, agentPollInterval);

// Start the queue first, then point the agent at it.
queue.start();
agent.start(queue);
});

afterEach(async () => {
await agent.stop();
await queue.stop();
});

// Happy path: a proof requested on the queue settles with whatever the prover returns.
it('picks up jobs from the queue', async () => {
const { promise, resolve } = promiseWithResolvers<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>();
const output = makeRootParityInput(RECURSIVE_PROOF_LENGTH, 1);
prover.getBaseParityProof.mockResolvedValueOnce(promise);
const proofPromise = queue.getBaseParityProof(makeBaseParityInputs());

// Wait one agent poll before letting the mocked prover finish.
await sleep(agentPollInterval);
resolve(output);
await expect(proofPromise).resolves.toEqual(output);
});

// The mocked prover stays pending for twice the job timeout; the proof must still resolve
// with the first (and only) mocked result, i.e. the job must not have been lost meanwhile.
it('keeps job alive', async () => {
const { promise, resolve } = promiseWithResolvers<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>();
const output = makeRootParityInput(RECURSIVE_PROOF_LENGTH, 1);
prover.getBaseParityProof.mockResolvedValueOnce(promise);
const proofPromise = queue.getBaseParityProof(makeBaseParityInputs());

await sleep(2 * queueJobTimeout);
resolve(output);
await expect(proofPromise).resolves.toEqual(output);
});

// Aborting the caller's signal must reject the proof promise with AbortError, even though
// the mocked prover goes on to produce an output afterwards.
it('reports cancellations', async () => {
const { promise, resolve } = promiseWithResolvers<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>();
const output = makeRootParityInput(RECURSIVE_PROOF_LENGTH, 1);
prover.getBaseParityProof.mockResolvedValueOnce(promise);
const controller = new AbortController();
const proofPromise = queue.getBaseParityProof(makeBaseParityInputs(), controller.signal);
await sleep(agentPollInterval);
controller.abort();
resolve(output);
await expect(proofPromise).rejects.toThrow(AbortError);
});

// Simulates the first agent disappearing and a second agent completing the same request.
it('re-queues timed out jobs', async () => {
const firstRun = promiseWithResolvers<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>();
const output = makeRootParityInput(RECURSIVE_PROOF_LENGTH, 1);
prover.getBaseParityProof.mockResolvedValueOnce(firstRun.promise);
const proofPromise = queue.getBaseParityProof(makeBaseParityInputs());

// NOTE(review): there is no sleep between enqueueing and stopping the agent, so whether the
// first agent actually picked the job up before stopping depends on timing — confirm this
// exercises the re-queue path rather than a plain first pick-up by the new agent.
// stop the agent to simulate a machine going down
await agent.stop();

// give the queue a chance to figure out the node is timed out and re-queue the job
await sleep(queueJobTimeout);
// reset the mock
const secondRun = promiseWithResolvers<RootParityInput<typeof RECURSIVE_PROOF_LENGTH>>();
prover.getBaseParityProof.mockResolvedValueOnce(secondRun.promise);
const newAgent = new ProverAgent(prover, 1, agentPollInterval);
newAgent.start(queue);
// test that the job is re-queued and kept alive by the new agent
await sleep(queueJobTimeout * 2);
secondRun.resolve(output);
await expect(proofPromise).resolves.toEqual(output);

firstRun.reject(new Error('stop this promise otherwise it hangs jest'));

// NOTE(review): this cleanup (and the reject above) only runs if the assertion passes; a
// failing expectation leaves newAgent running. afterEach also calls stop() a second time on
// the original, already-stopped agent.
await newAgent.stop();
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@ import {
makeRecursiveProof,
} from '@aztec/circuits.js';
import { makeBaseParityInputs, makeBaseRollupInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing';
import { AbortError } from '@aztec/foundation/error';
import { sleep } from '@aztec/foundation/sleep';

import { MemoryProvingQueue } from './memory-proving-queue.js';

describe('MemoryProvingQueue', () => {
let queue: MemoryProvingQueue;
let jobTimeoutMs: number;
let pollingIntervalMs: number;

beforeEach(() => {
queue = new MemoryProvingQueue();
jobTimeoutMs = 100;
pollingIntervalMs = 10;
queue = new MemoryProvingQueue(jobTimeoutMs, pollingIntervalMs);
queue.start();
});

// Stop the queue that beforeEach started so nothing from one test carries over into the next.
afterEach(async () => {
await queue.stop();
});

it('returns jobs in order', async () => {
Expand Down Expand Up @@ -68,4 +79,39 @@ describe('MemoryProvingQueue', () => {

await expect(promise).rejects.toEqual(error);
});

// A job that has been handed out but never receives a heartbeat must stop being reported as
// running once the job timeout has elapsed.
it('reaps timed out jobs', async () => {
const controller = new AbortController();
const promise = queue.getBaseParityProof(makeBaseParityInputs(), controller.signal);
// Take the job the way a consumer would; no heartbeat is sent for it afterwards.
const job = await queue.getProvingJob();

expect(queue.isJobRunning(job!.id)).toBe(true);
// Wait past the timeout, with two polling intervals of slack for the queue to notice.
await sleep(jobTimeoutMs + 2 * pollingIntervalMs);
expect(queue.isJobRunning(job!.id)).toBe(false);

// Abort the request so the pending proof promise settles before the test ends.
controller.abort();
await expect(promise).rejects.toThrow(AbortError);
});

// Checks that a job stays reported as running before and after a heartbeat, and that it can
// still be resolved afterwards.
// NOTE(review): the two sleeps add up to 2 * pollingIntervalMs, which is well below jobTimeoutMs
// with the values set in beforeEach, so the job would not have timed out even without the
// heartbeat. Consider sleeping long enough that only the heartbeat can keep the job alive.
it('keeps jobs running while heartbeat is called', async () => {
const promise = queue.getBaseParityProof(makeBaseParityInputs());
const job = await queue.getProvingJob();

expect(queue.isJobRunning(job!.id)).toBe(true);
await sleep(pollingIntervalMs);
expect(queue.isJobRunning(job!.id)).toBe(true);

await queue.heartbeat(job!.id);
expect(queue.isJobRunning(job!.id)).toBe(true);
await sleep(pollingIntervalMs);
expect(queue.isJobRunning(job!.id)).toBe(true);

// Resolve the job through the consumer-facing API and check the requester sees the same value.
const output = new RootParityInput(
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyAsFields.makeFake(),
makeParityPublicInputs(),
);
await queue.resolveProvingJob(job!.id, output);
await expect(promise).resolves.toEqual(output);
});
});
Loading

0 comments on commit 7306176

Please sign in to comment.