Skip to content

Commit

Permalink
feat: add new proving agent
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr committed Nov 15, 2024
1 parent 05e4b27 commit 1abdd86
Show file tree
Hide file tree
Showing 5 changed files with 586 additions and 11 deletions.
213 changes: 213 additions & 0 deletions yarn-project/prover-client/src/proving_broker/proving_agent.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import {
ProvingRequestType,
type PublicInputsAndRecursiveProof,
type V2ProvingJob,
type V2ProvingJobId,
makePublicInputsAndRecursiveProof,
} from '@aztec/circuit-types';
import {
type ParityPublicInputs,
RECURSIVE_PROOF_LENGTH,
VerificationKeyData,
makeRecursiveProof,
} from '@aztec/circuits.js';
import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing';
import { randomBytes } from '@aztec/foundation/crypto';
import { AbortError } from '@aztec/foundation/error';
import { promiseWithResolvers } from '@aztec/foundation/promise';

import { jest } from '@jest/globals';

import { MockProver } from '../test/mock_prover.js';
import { ProvingAgent } from './proving_agent.js';
import { type ProvingJobConsumer } from './proving_broker_interface.js';

describe('ProvingAgent', () => {
let prover: MockProver;
let jobSource: jest.Mocked<ProvingJobConsumer>;
let agent: ProvingAgent;
const agentPollIntervalMs = 1000;

beforeEach(() => {
jest.useFakeTimers();

prover = new MockProver();
jobSource = {
getProvingJob: jest.fn(),
reportProvingJobProgress: jest.fn(),
reportProvingJobError: jest.fn(),
reportProvingJobSuccess: jest.fn(),
};
agent = new ProvingAgent(jobSource, prover, [ProvingRequestType.BASE_PARITY]);
});

afterEach(async () => {
await agent.stop();
});

it('polls for jobs passing the permitted list of proofs', () => {
agent.start();
expect(jobSource.getProvingJob).toHaveBeenCalledWith({ allowList: [ProvingRequestType.BASE_PARITY] });
});

it('only takes a single job from the source at a time', async () => {
expect(jobSource.getProvingJob).not.toHaveBeenCalled();

// simulate the proof taking a long time
const { promise, resolve } =
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>();
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise);

const jobResponse = makeBaseParityJob();
jobSource.getProvingJob.mockResolvedValueOnce(jobResponse);
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1);

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1);

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1);

// let's resolve the proof
const result = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
resolve(result);

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(2);
});

it('reports success to the job source', async () => {
const jobResponse = makeBaseParityJob();
const result = makeBaseParityResult();
jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(result.value);

jobSource.getProvingJob.mockResolvedValueOnce(jobResponse);
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(jobResponse.job.id, result);
});

it('reports errors to the job source', async () => {
const jobResponse = makeBaseParityJob();
jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(new Error('test error'));

jobSource.getProvingJob.mockResolvedValueOnce(jobResponse);
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(jobResponse.job.id, new Error('test error'));
});

it('reports jobs in progress to the job source', async () => {
const jobResponse = makeBaseParityJob();
const { promise, resolve } =
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>();
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise);

jobSource.getProvingJob.mockResolvedValueOnce(jobResponse);
agent.start();

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, {
allowList: [ProvingRequestType.BASE_PARITY],
});

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, {
allowList: [ProvingRequestType.BASE_PARITY],
});

resolve(makeBaseParityResult().value);
});

it('abandons jobs if told so by the source', async () => {
const firstJobResponse = makeBaseParityJob();
let firstProofAborted = false;
const firstProof =
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>();

// simulate a long running proving job that can be aborted
jest.spyOn(prover, 'getBaseParityProof').mockImplementationOnce((_, signal) => {
signal?.addEventListener('abort', () => {
firstProof.reject(new AbortError('test abort'));
firstProofAborted = true;
});
return firstProof.promise;
});

jobSource.getProvingJob.mockResolvedValueOnce(firstJobResponse);
agent.start();

// now the agent should be happily proving and reporting progress
await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(1);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(firstJobResponse.job.id, firstJobResponse.time, {
allowList: [ProvingRequestType.BASE_PARITY],
});

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(2);

// now let's simulate the job source cancelling the job and giving the agent something else to do
// this should cause the agent to abort the current job and start the new one
const secondJobResponse = makeBaseParityJob();
jobSource.reportProvingJobProgress.mockResolvedValueOnce(secondJobResponse);

const secondProof =
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>();
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(secondProof.promise);

await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(3);
expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith(
firstJobResponse.job.id,
firstJobResponse.time,
{
allowList: [ProvingRequestType.BASE_PARITY],
},
);
expect(firstProofAborted).toBe(true);

// agent should have switched now
await jest.advanceTimersByTimeAsync(agentPollIntervalMs);
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(4);
expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith(
secondJobResponse.job.id,
secondJobResponse.time,
{
allowList: [ProvingRequestType.BASE_PARITY],
},
);

secondProof.resolve(makeBaseParityResult().value);
});

function makeBaseParityJob(): { job: V2ProvingJob; time: number } {
const time = jest.now();
const job: V2ProvingJob = {
id: randomBytes(8).toString('hex') as V2ProvingJobId,
blockNumber: 1,
type: ProvingRequestType.BASE_PARITY,
inputs: makeBaseParityInputs(),
};

return { job, time };
}

function makeBaseParityResult() {
const value = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
return { type: ProvingRequestType.BASE_PARITY, value };
}
});
84 changes: 84 additions & 0 deletions yarn-project/prover-client/src/proving_broker/proving_agent.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { type ProvingRequestType, type ServerCircuitProver, type V2ProvingJob } from '@aztec/circuit-types';
import { createDebugLogger } from '@aztec/foundation/log';
import { RunningPromise } from '@aztec/foundation/running-promise';

import { type ProvingJobConsumer } from './proving_broker_interface.js';
import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js';

/**
* A helper class that encapsulates a circuit prover and connects it to a job source.
*/
export class ProvingAgent {
private currentJobController?: ProvingJobController;
private runningPromise: RunningPromise;

constructor(
/** The source of proving jobs */
private jobSource: ProvingJobConsumer,
/** The prover implementation to defer jobs to */
private circuitProver: ServerCircuitProver,
/** Optional list of allowed proof types to build */
private proofAllowList?: Array<ProvingRequestType>,
/** How long to wait between jobs */
private pollIntervalMs = 1000,
private log = createDebugLogger('aztec:proving-broker:proving-agent'),
) {
this.runningPromise = new RunningPromise(this.safeWork, this.pollIntervalMs);
}

public setCircuitProver(circuitProver: ServerCircuitProver): void {
this.circuitProver = circuitProver;
}

public isRunning(): boolean {
return this.runningPromise?.isRunning() ?? false;
}

public start(): void {
this.runningPromise.start();
}

public async stop(): Promise<void> {
this.currentJobController?.abort();
await this.runningPromise.stop();
}

private safeWork = async () => {
try {
// every tick we need to
// (1) either do a heartbeat, telling the broker that we're working
// (2) get a new job
// If during (1) the broker returns a new job that means we can cancel the current job and start the new one
let maybeJob: { job: V2ProvingJob; time: number } | undefined;
if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) {
maybeJob = await this.jobSource.reportProvingJobProgress(
this.currentJobController.getJobId(),
this.currentJobController.getStartedAt(),
{ allowList: this.proofAllowList },
);
} else {
maybeJob = await this.jobSource.getProvingJob({ allowList: this.proofAllowList });
}

if (!maybeJob) {
return;
}

if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) {
this.currentJobController?.abort();
}

const { job, time } = maybeJob;
this.currentJobController = new ProvingJobController(job, time, this.circuitProver, (err, result) => {
if (err) {
return this.jobSource.reportProvingJobError(job.id, err);
} else if (result) {
return this.jobSource.reportProvingJobSuccess(job.id, result);
}
});
this.currentJobController.start();
} catch (err) {
this.log.error(`Error in ProvingAgent: ${String(err)}`);
}
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import { ProvingRequestType, type V2ProvingJobId, makePublicInputsAndRecursiveProof } from '@aztec/circuit-types';
import { RECURSIVE_PROOF_LENGTH, VerificationKeyData, makeRecursiveProof } from '@aztec/circuits.js';
import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';

import { MockProver } from '../test/mock_prover.js';
import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js';

describe('ProvingJobController', () => {
let prover: MockProver;
let onComplete: jest.Mock<any>;
let controller: ProvingJobController;

beforeEach(() => {
prover = new MockProver();
onComplete = jest.fn();
controller = new ProvingJobController(
{
type: ProvingRequestType.BASE_PARITY,
blockNumber: 1,
id: '1' as V2ProvingJobId,
inputs: makeBaseParityInputs(),
},
0,
prover,
onComplete,
);
});

it('reports IDLE status initially', () => {
expect(controller.getStatus()).toBe(ProvingJobStatus.IDLE);
});

it('reports PROVING status while busy', () => {
controller.start();
expect(controller.getStatus()).toBe(ProvingJobStatus.PROVING);
});

it('reports DONE status after job is done', async () => {
controller.start();
await sleep(1); // give promises a chance to complete
expect(controller.getStatus()).toBe(ProvingJobStatus.DONE);
});

it('calls onComplete with the proof', async () => {
const resp = makePublicInputsAndRecursiveProof(
makeParityPublicInputs(),
makeRecursiveProof(RECURSIVE_PROOF_LENGTH),
VerificationKeyData.makeFakeHonk(),
);
jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(resp);

controller.start();
await sleep(1); // give promises a chance to complete
expect(onComplete).toHaveBeenCalledWith(undefined, {
type: ProvingRequestType.BASE_PARITY,
value: resp,
});
});

it('calls onComplete with the error', async () => {
const err = new Error('test error');
jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(err);

controller.start();
await sleep(1);
expect(onComplete).toHaveBeenCalledWith(err, undefined);
});

it('does not crash if onComplete throws', async () => {
const err = new Error('test error');
onComplete.mockImplementationOnce(() => {
throw err;
});

controller.start();
await sleep(1);
expect(onComplete).toHaveBeenCalled();
});

it('does not crash if onComplete rejects', async () => {
const err = new Error('test error');
onComplete.mockRejectedValueOnce(err);

controller.start();
await sleep(1);
expect(onComplete).toHaveBeenCalled();
});
});
Loading

0 comments on commit 1abdd86

Please sign in to comment.