-
Notifications
You must be signed in to change notification settings - Fork 237
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
586 additions
and
11 deletions.
There are no files selected for viewing
213 changes: 213 additions & 0 deletions
213
yarn-project/prover-client/src/proving_broker/proving_agent.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
import { | ||
ProvingRequestType, | ||
type PublicInputsAndRecursiveProof, | ||
type V2ProvingJob, | ||
type V2ProvingJobId, | ||
makePublicInputsAndRecursiveProof, | ||
} from '@aztec/circuit-types'; | ||
import { | ||
type ParityPublicInputs, | ||
RECURSIVE_PROOF_LENGTH, | ||
VerificationKeyData, | ||
makeRecursiveProof, | ||
} from '@aztec/circuits.js'; | ||
import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing'; | ||
import { randomBytes } from '@aztec/foundation/crypto'; | ||
import { AbortError } from '@aztec/foundation/error'; | ||
import { promiseWithResolvers } from '@aztec/foundation/promise'; | ||
|
||
import { jest } from '@jest/globals'; | ||
|
||
import { MockProver } from '../test/mock_prover.js'; | ||
import { ProvingAgent } from './proving_agent.js'; | ||
import { type ProvingJobConsumer } from './proving_broker_interface.js'; | ||
|
||
describe('ProvingAgent', () => { | ||
let prover: MockProver; | ||
let jobSource: jest.Mocked<ProvingJobConsumer>; | ||
let agent: ProvingAgent; | ||
const agentPollIntervalMs = 1000; | ||
|
||
beforeEach(() => { | ||
jest.useFakeTimers(); | ||
|
||
prover = new MockProver(); | ||
jobSource = { | ||
getProvingJob: jest.fn(), | ||
reportProvingJobProgress: jest.fn(), | ||
reportProvingJobError: jest.fn(), | ||
reportProvingJobSuccess: jest.fn(), | ||
}; | ||
agent = new ProvingAgent(jobSource, prover, [ProvingRequestType.BASE_PARITY]); | ||
}); | ||
|
||
afterEach(async () => { | ||
await agent.stop(); | ||
}); | ||
|
||
it('polls for jobs passing the permitted list of proofs', () => { | ||
agent.start(); | ||
expect(jobSource.getProvingJob).toHaveBeenCalledWith({ allowList: [ProvingRequestType.BASE_PARITY] }); | ||
}); | ||
|
||
it('only takes a single job from the source at a time', async () => { | ||
expect(jobSource.getProvingJob).not.toHaveBeenCalled(); | ||
|
||
// simulate the proof taking a long time | ||
const { promise, resolve } = | ||
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>(); | ||
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise); | ||
|
||
const jobResponse = makeBaseParityJob(); | ||
jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); | ||
agent.start(); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(1); | ||
|
||
// let's resolve the proof | ||
const result = makePublicInputsAndRecursiveProof( | ||
makeParityPublicInputs(), | ||
makeRecursiveProof(RECURSIVE_PROOF_LENGTH), | ||
VerificationKeyData.makeFakeHonk(), | ||
); | ||
resolve(result); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.getProvingJob).toHaveBeenCalledTimes(2); | ||
}); | ||
|
||
it('reports success to the job source', async () => { | ||
const jobResponse = makeBaseParityJob(); | ||
const result = makeBaseParityResult(); | ||
jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(result.value); | ||
|
||
jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); | ||
agent.start(); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobSuccess).toHaveBeenCalledWith(jobResponse.job.id, result); | ||
}); | ||
|
||
it('reports errors to the job source', async () => { | ||
const jobResponse = makeBaseParityJob(); | ||
jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(new Error('test error')); | ||
|
||
jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); | ||
agent.start(); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobError).toHaveBeenCalledWith(jobResponse.job.id, new Error('test error')); | ||
}); | ||
|
||
it('reports jobs in progress to the job source', async () => { | ||
const jobResponse = makeBaseParityJob(); | ||
const { promise, resolve } = | ||
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>(); | ||
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(promise); | ||
|
||
jobSource.getProvingJob.mockResolvedValueOnce(jobResponse); | ||
agent.start(); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, { | ||
allowList: [ProvingRequestType.BASE_PARITY], | ||
}); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(jobResponse.job.id, jobResponse.time, { | ||
allowList: [ProvingRequestType.BASE_PARITY], | ||
}); | ||
|
||
resolve(makeBaseParityResult().value); | ||
}); | ||
|
||
it('abandons jobs if told so by the source', async () => { | ||
const firstJobResponse = makeBaseParityJob(); | ||
let firstProofAborted = false; | ||
const firstProof = | ||
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>(); | ||
|
||
// simulate a long running proving job that can be aborted | ||
jest.spyOn(prover, 'getBaseParityProof').mockImplementationOnce((_, signal) => { | ||
signal?.addEventListener('abort', () => { | ||
firstProof.reject(new AbortError('test abort')); | ||
firstProofAborted = true; | ||
}); | ||
return firstProof.promise; | ||
}); | ||
|
||
jobSource.getProvingJob.mockResolvedValueOnce(firstJobResponse); | ||
agent.start(); | ||
|
||
// now the agent should be happily proving and reporting progress | ||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(1); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledWith(firstJobResponse.job.id, firstJobResponse.time, { | ||
allowList: [ProvingRequestType.BASE_PARITY], | ||
}); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(2); | ||
|
||
// now let's simulate the job source cancelling the job and giving the agent something else to do | ||
// this should cause the agent to abort the current job and start the new one | ||
const secondJobResponse = makeBaseParityJob(); | ||
jobSource.reportProvingJobProgress.mockResolvedValueOnce(secondJobResponse); | ||
|
||
const secondProof = | ||
promiseWithResolvers<PublicInputsAndRecursiveProof<ParityPublicInputs, typeof RECURSIVE_PROOF_LENGTH>>(); | ||
jest.spyOn(prover, 'getBaseParityProof').mockReturnValueOnce(secondProof.promise); | ||
|
||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(3); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith( | ||
firstJobResponse.job.id, | ||
firstJobResponse.time, | ||
{ | ||
allowList: [ProvingRequestType.BASE_PARITY], | ||
}, | ||
); | ||
expect(firstProofAborted).toBe(true); | ||
|
||
// agent should have switched now | ||
await jest.advanceTimersByTimeAsync(agentPollIntervalMs); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenCalledTimes(4); | ||
expect(jobSource.reportProvingJobProgress).toHaveBeenLastCalledWith( | ||
secondJobResponse.job.id, | ||
secondJobResponse.time, | ||
{ | ||
allowList: [ProvingRequestType.BASE_PARITY], | ||
}, | ||
); | ||
|
||
secondProof.resolve(makeBaseParityResult().value); | ||
}); | ||
|
||
function makeBaseParityJob(): { job: V2ProvingJob; time: number } { | ||
const time = jest.now(); | ||
const job: V2ProvingJob = { | ||
id: randomBytes(8).toString('hex') as V2ProvingJobId, | ||
blockNumber: 1, | ||
type: ProvingRequestType.BASE_PARITY, | ||
inputs: makeBaseParityInputs(), | ||
}; | ||
|
||
return { job, time }; | ||
} | ||
|
||
function makeBaseParityResult() { | ||
const value = makePublicInputsAndRecursiveProof( | ||
makeParityPublicInputs(), | ||
makeRecursiveProof(RECURSIVE_PROOF_LENGTH), | ||
VerificationKeyData.makeFakeHonk(), | ||
); | ||
return { type: ProvingRequestType.BASE_PARITY, value }; | ||
} | ||
}); |
84 changes: 84 additions & 0 deletions
84
yarn-project/prover-client/src/proving_broker/proving_agent.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import { type ProvingRequestType, type ServerCircuitProver, type V2ProvingJob } from '@aztec/circuit-types'; | ||
import { createDebugLogger } from '@aztec/foundation/log'; | ||
import { RunningPromise } from '@aztec/foundation/running-promise'; | ||
|
||
import { type ProvingJobConsumer } from './proving_broker_interface.js'; | ||
import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js'; | ||
|
||
/** | ||
* A helper class that encapsulates a circuit prover and connects it to a job source. | ||
*/ | ||
export class ProvingAgent { | ||
private currentJobController?: ProvingJobController; | ||
private runningPromise: RunningPromise; | ||
|
||
constructor( | ||
/** The source of proving jobs */ | ||
private jobSource: ProvingJobConsumer, | ||
/** The prover implementation to defer jobs to */ | ||
private circuitProver: ServerCircuitProver, | ||
/** Optional list of allowed proof types to build */ | ||
private proofAllowList?: Array<ProvingRequestType>, | ||
/** How long to wait between jobs */ | ||
private pollIntervalMs = 1000, | ||
private log = createDebugLogger('aztec:proving-broker:proving-agent'), | ||
) { | ||
this.runningPromise = new RunningPromise(this.safeWork, this.pollIntervalMs); | ||
} | ||
|
||
public setCircuitProver(circuitProver: ServerCircuitProver): void { | ||
this.circuitProver = circuitProver; | ||
} | ||
|
||
public isRunning(): boolean { | ||
return this.runningPromise?.isRunning() ?? false; | ||
} | ||
|
||
public start(): void { | ||
this.runningPromise.start(); | ||
} | ||
|
||
public async stop(): Promise<void> { | ||
this.currentJobController?.abort(); | ||
await this.runningPromise.stop(); | ||
} | ||
|
||
private safeWork = async () => { | ||
try { | ||
// every tick we need to | ||
// (1) either do a heartbeat, telling the broker that we're working | ||
// (2) get a new job | ||
// If during (1) the broker returns a new job that means we can cancel the current job and start the new one | ||
let maybeJob: { job: V2ProvingJob; time: number } | undefined; | ||
if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) { | ||
maybeJob = await this.jobSource.reportProvingJobProgress( | ||
this.currentJobController.getJobId(), | ||
this.currentJobController.getStartedAt(), | ||
{ allowList: this.proofAllowList }, | ||
); | ||
} else { | ||
maybeJob = await this.jobSource.getProvingJob({ allowList: this.proofAllowList }); | ||
} | ||
|
||
if (!maybeJob) { | ||
return; | ||
} | ||
|
||
if (this.currentJobController?.getStatus() === ProvingJobStatus.PROVING) { | ||
this.currentJobController?.abort(); | ||
} | ||
|
||
const { job, time } = maybeJob; | ||
this.currentJobController = new ProvingJobController(job, time, this.circuitProver, (err, result) => { | ||
if (err) { | ||
return this.jobSource.reportProvingJobError(job.id, err); | ||
} else if (result) { | ||
return this.jobSource.reportProvingJobSuccess(job.id, result); | ||
} | ||
}); | ||
this.currentJobController.start(); | ||
} catch (err) { | ||
this.log.error(`Error in ProvingAgent: ${String(err)}`); | ||
} | ||
}; | ||
} |
91 changes: 91 additions & 0 deletions
91
yarn-project/prover-client/src/proving_broker/proving_job_controller.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import { ProvingRequestType, type V2ProvingJobId, makePublicInputsAndRecursiveProof } from '@aztec/circuit-types'; | ||
import { RECURSIVE_PROOF_LENGTH, VerificationKeyData, makeRecursiveProof } from '@aztec/circuits.js'; | ||
import { makeBaseParityInputs, makeParityPublicInputs } from '@aztec/circuits.js/testing'; | ||
import { sleep } from '@aztec/foundation/sleep'; | ||
|
||
import { jest } from '@jest/globals'; | ||
|
||
import { MockProver } from '../test/mock_prover.js'; | ||
import { ProvingJobController, ProvingJobStatus } from './proving_job_controller.js'; | ||
|
||
describe('ProvingJobController', () => { | ||
let prover: MockProver; | ||
let onComplete: jest.Mock<any>; | ||
let controller: ProvingJobController; | ||
|
||
beforeEach(() => { | ||
prover = new MockProver(); | ||
onComplete = jest.fn(); | ||
controller = new ProvingJobController( | ||
{ | ||
type: ProvingRequestType.BASE_PARITY, | ||
blockNumber: 1, | ||
id: '1' as V2ProvingJobId, | ||
inputs: makeBaseParityInputs(), | ||
}, | ||
0, | ||
prover, | ||
onComplete, | ||
); | ||
}); | ||
|
||
it('reports IDLE status initially', () => { | ||
expect(controller.getStatus()).toBe(ProvingJobStatus.IDLE); | ||
}); | ||
|
||
it('reports PROVING status while busy', () => { | ||
controller.start(); | ||
expect(controller.getStatus()).toBe(ProvingJobStatus.PROVING); | ||
}); | ||
|
||
it('reports DONE status after job is done', async () => { | ||
controller.start(); | ||
await sleep(1); // give promises a chance to complete | ||
expect(controller.getStatus()).toBe(ProvingJobStatus.DONE); | ||
}); | ||
|
||
it('calls onComplete with the proof', async () => { | ||
const resp = makePublicInputsAndRecursiveProof( | ||
makeParityPublicInputs(), | ||
makeRecursiveProof(RECURSIVE_PROOF_LENGTH), | ||
VerificationKeyData.makeFakeHonk(), | ||
); | ||
jest.spyOn(prover, 'getBaseParityProof').mockResolvedValueOnce(resp); | ||
|
||
controller.start(); | ||
await sleep(1); // give promises a chance to complete | ||
expect(onComplete).toHaveBeenCalledWith(undefined, { | ||
type: ProvingRequestType.BASE_PARITY, | ||
value: resp, | ||
}); | ||
}); | ||
|
||
it('calls onComplete with the error', async () => { | ||
const err = new Error('test error'); | ||
jest.spyOn(prover, 'getBaseParityProof').mockRejectedValueOnce(err); | ||
|
||
controller.start(); | ||
await sleep(1); | ||
expect(onComplete).toHaveBeenCalledWith(err, undefined); | ||
}); | ||
|
||
it('does not crash if onComplete throws', async () => { | ||
const err = new Error('test error'); | ||
onComplete.mockImplementationOnce(() => { | ||
throw err; | ||
}); | ||
|
||
controller.start(); | ||
await sleep(1); | ||
expect(onComplete).toHaveBeenCalled(); | ||
}); | ||
|
||
it('does not crash if onComplete rejects', async () => { | ||
const err = new Error('test error'); | ||
onComplete.mockRejectedValueOnce(err); | ||
|
||
controller.start(); | ||
await sleep(1); | ||
expect(onComplete).toHaveBeenCalled(); | ||
}); | ||
}); |
Oops, something went wrong.