Skip to content

Commit

Permalink
chore: stop calling public kernel tail
Browse files Browse the repository at this point in the history
  • Loading branch information
dbanks12 committed Nov 9, 2024
1 parent 2e13938 commit faabd42
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -486,17 +486,16 @@ describe('Enqueued-call Side Effect Trace', () => {
// parent absorbs child's side effects
const parentSideEffects = trace.getSideEffects();
const childSideEffects = nestedTrace.getSideEffects();
// TODO(dbanks12): confirm that all hints were merged from child
if (callResults.reverted) {
expect(parentSideEffects.publicDataReads).toEqual(childSideEffects.publicDataReads);
expect(parentSideEffects.publicDataWrites).toEqual(childSideEffects.publicDataWrites);
expect(parentSideEffects.noteHashReadRequests).toEqual(childSideEffects.noteHashReadRequests);
expect(parentSideEffects.publicDataReads).toEqual([]);
expect(parentSideEffects.publicDataWrites).toEqual([]);
expect(parentSideEffects.noteHashReadRequests).toEqual([]);
expect(parentSideEffects.noteHashes).toEqual([]);
expect(parentSideEffects.nullifierReadRequests).toEqual(childSideEffects.nullifierReadRequests);
expect(parentSideEffects.nullifierNonExistentReadRequests).toEqual(
childSideEffects.nullifierNonExistentReadRequests,
);
expect(parentSideEffects.nullifiers).toEqual(childSideEffects.nullifiers);
expect(parentSideEffects.l1ToL2MsgReadRequests).toEqual(childSideEffects.l1ToL2MsgReadRequests);
expect(parentSideEffects.nullifierReadRequests).toEqual([]);
expect(parentSideEffects.nullifierNonExistentReadRequests).toEqual([]);
expect(parentSideEffects.nullifiers).toEqual([]);
expect(parentSideEffects.l1ToL2MsgReadRequests).toEqual([]);
expect(parentSideEffects.l2ToL1Msgs).toEqual([]);
expect(parentSideEffects.unencryptedLogs).toEqual([]);
expect(parentSideEffects.unencryptedLogsHashes).toEqual([]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { UnencryptedFunctionL2Logs, UnencryptedL2Log } from '@aztec/circuit-types';
import {
AvmAccumulatedData,
AvmCircuitPublicInputs,
AvmContractBytecodeHints,
AvmContractInstanceHint,
AvmEnqueuedCallHint,
Expand All @@ -11,6 +13,8 @@ import {
type ContractClassIdPreimage,
EthAddress,
Gas,
type GasSettings,
type GlobalVariables,
L2ToL1Message,
LogHash,
MAX_ENCRYPTED_LOGS_PER_TX,
Expand All @@ -28,11 +32,14 @@ import {
MAX_UNENCRYPTED_LOGS_PER_TX,
NoteHash,
Nullifier,
PrivateToAvmAccumulatedData,
PrivateToAvmAccumulatedDataArrayLengths,
PublicAccumulatedData,
PublicAccumulatedDataArrayLengths,
PublicCallRequest,
PublicDataRead,
PublicDataUpdateRequest,
PublicDataWrite,
PublicInnerCallRequest,
PublicValidationRequestArrayLengths,
PublicValidationRequests,
Expand All @@ -44,13 +51,15 @@ import {
ScopedReadRequest,
SerializableContractInstance,
TreeLeafReadRequest,
type TreeSnapshots,
VMCircuitPublicInputs,
} from '@aztec/circuits.js';
import { computePublicDataTreeLeafSlot, siloNullifier } from '@aztec/circuits.js/hash';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';
import { type Tuple } from '@aztec/foundation/serialize';

import { type AvmContractCallResult } from '../avm/avm_contract_call_result.js';
import { type AvmExecutionEnvironment } from '../avm/avm_execution_environment.js';
Expand Down Expand Up @@ -93,6 +102,9 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
/** The side effect counter increments with every call to the trace. */
private sideEffectCounter: number;

//private publicSetupCallRequests: PublicCallRequest[] = [];
//private publicAppLogicCallRequests: PublicCallRequest[] = [];
//private publicTeardownCallRequest: PublicCallRequest[] = [];
private enqueuedCalls: PublicCallRequest[] = [];

private publicDataReads: PublicDataRead[] = [];
Expand Down Expand Up @@ -502,14 +514,14 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
// TODO(dbanks12): accept & merge nested trace's hints!
// TODO(dbanks12): What should happen to side effect counter on revert?
this.sideEffectCounter = nestedTrace.sideEffectCounter;
this.publicDataReads.push(...nestedTrace.publicDataReads);
this.publicDataWrites.push(...nestedTrace.publicDataWrites);
this.noteHashReadRequests.push(...nestedTrace.noteHashReadRequests);
//this.publicDataReads.push(...nestedTrace.publicDataReads);
//this.publicDataWrites.push(...nestedTrace.publicDataWrites);
//this.noteHashReadRequests.push(...nestedTrace.noteHashReadRequests);
// new noteHashes are tossed on revert
this.nullifierReadRequests.push(...nestedTrace.nullifierReadRequests);
this.nullifierNonExistentReadRequests.push(...nestedTrace.nullifierNonExistentReadRequests);
this.nullifiers.push(...nestedTrace.nullifiers);
this.l1ToL2MsgReadRequests.push(...nestedTrace.l1ToL2MsgReadRequests);
//this.nullifierReadRequests.push(...nestedTrace.nullifierReadRequests);
//this.nullifierNonExistentReadRequests.push(...nestedTrace.nullifierNonExistentReadRequests);
//this.nullifiers.push(...nestedTrace.nullifiers);
//this.l1ToL2MsgReadRequests.push(...nestedTrace.l1ToL2MsgReadRequests);
// new l2-to-l1 messages are tossed on revert
// new unencrypted logs are tossed on revert
}
Expand Down Expand Up @@ -592,6 +604,44 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
);
}

public toAvmCircuitPublicInputs(
/** Globals. */
globalVariables: GlobalVariables,
/** Start tree snapshots. */
startTreeSnapshots: TreeSnapshots,
/** How much gas was available for this public execution. */
gasLimits: GasSettings,
/** Call requests for setup phase. */
publicSetupCallRequests: Tuple<PublicCallRequest, typeof MAX_ENQUEUED_CALLS_PER_TX>,
/** Call requests for app logic phase. */
publicAppLogicCallRequests: Tuple<PublicCallRequest, typeof MAX_ENQUEUED_CALLS_PER_TX>,
/** Call request for teardown phase. */
publicTeardownCallRequest: PublicCallRequest,
/** End tree snapshots. */
endTreeSnapshots: TreeSnapshots,
/** Transaction fee. */
transactionFee: Fr,
/** The call's results */
reverted: boolean,
): AvmCircuitPublicInputs {
return new AvmCircuitPublicInputs(
globalVariables,
startTreeSnapshots,
gasLimits,
publicSetupCallRequests,
publicAppLogicCallRequests,
publicTeardownCallRequest,
/*previousNonRevertibleAccumulatedDataArrayLengths=*/ PrivateToAvmAccumulatedDataArrayLengths.empty(),
/*previousRevertibleAccumulatedDataArrayLengths=*/ PrivateToAvmAccumulatedDataArrayLengths.empty(),
/*previousNonRevertibleAccumulatedDataArray=*/ PrivateToAvmAccumulatedData.empty(),
/*previousRevertibleAccumulatedDataArray=*/ PrivateToAvmAccumulatedData.empty(),
endTreeSnapshots,
/*accumulatedData=*/ this.getAvmAccumulatedData(),
transactionFee,
reverted,
);
}

public toPublicFunctionCallResult(
/** The execution environment of the nested call. */
_avmEnvironment: AvmExecutionEnvironment,
Expand Down Expand Up @@ -632,6 +682,28 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
);
}

private getAvmAccumulatedData() {
return new AvmAccumulatedData(
padArrayEnd(
this.noteHashes.map(n => n.value),
Fr.zero(),
MAX_NOTE_HASHES_PER_TX,
),
padArrayEnd(
this.nullifiers.map(n => n.value),
Fr.zero(),
MAX_NULLIFIERS_PER_TX,
),
padArrayEnd(this.l2ToL1Messages, ScopedL2ToL1Message.empty(), MAX_L2_TO_L1_MSGS_PER_TX),
padArrayEnd(this.unencryptedLogsHashes, ScopedLogHash.empty(), MAX_UNENCRYPTED_LOGS_PER_TX),
padArrayEnd(
this.publicDataWrites.map(w => new PublicDataWrite(w.leafSlot, w.newValue)),
PublicDataWrite.empty(),
MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX,
),
);
}

private getAccumulatedData(gasUsed: Gas) {
return new PublicAccumulatedData(
padArrayEnd(this.noteHashes, ScopedNoteHash.empty(), MAX_NOTE_HASHES_PER_TX),
Expand Down
42 changes: 10 additions & 32 deletions yarn-project/simulator/src/public/enqueued_calls_processor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import { type FieldsOf } from '@aztec/foundation/types';
import { openTmpStore } from '@aztec/kv-store/utils';
import { type AppendOnlyTree, Poseidon, StandardTree, newTree } from '@aztec/merkle-tree';

import { jest } from '@jest/globals';
import { type MockProxy, mock } from 'jest-mock-extended';

import { type AvmPersistableStateManager } from '../avm/journal/journal.js';
Expand Down Expand Up @@ -111,16 +110,13 @@ describe('enqueued_calls_processor', () => {
return Promise.resolve(result);
});

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(txResult.processedPhases).toHaveLength(1);
expect(txResult.processedPhases[0]).toEqual(expect.objectContaining({ revertReason: undefined }));
expect(txResult.revertCode).toEqual(RevertCode.OK);
expect(txResult.revertReason).toBe(undefined);

expect(tailSpy).toHaveBeenCalledTimes(1);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(2);

const outputs = txResult.avmProvingRequest!.inputs.output.accumulatedData;
Expand Down Expand Up @@ -187,8 +183,6 @@ describe('enqueued_calls_processor', () => {
);
}

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(txResult.processedPhases).toHaveLength(3);
Expand All @@ -197,11 +191,10 @@ describe('enqueued_calls_processor', () => {
expect(txResult.processedPhases[2]).toEqual(expect.objectContaining({ revertReason: teardownFailure }));
expect(txResult.revertReason).toBe(teardownFailure);

expect(tailSpy).toHaveBeenCalledTimes(1);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(3);

const outputs = txResult.avmProvingRequest!.inputs.output.accumulatedData;
const numPublicDataWrites = 3;
const numPublicDataWrites = 3; // 7 total, but teardown reverted, so 2 from app logic and 2 from teardown are reverted, so 2 from app logic and 2 from teardown are reverted
expect(arrayNonEmptyLength(outputs.publicDataWrites, PublicDataWrite.isEmpty)).toBe(numPublicDataWrites);
expect(outputs.publicDataWrites.slice(0, numPublicDataWrites)).toEqual([
new PublicDataWrite(
Expand Down Expand Up @@ -269,11 +262,8 @@ describe('enqueued_calls_processor', () => {
);
}

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

await expect(processor.process(tx)).rejects.toThrow(setupFailureMsg);

expect(tailSpy).toHaveBeenCalledTimes(0);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(1);
});

Expand Down Expand Up @@ -340,8 +330,6 @@ describe('enqueued_calls_processor', () => {
);
}

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(publicExecutor.simulate).toHaveBeenCalledTimes(3);
Expand All @@ -353,7 +341,6 @@ describe('enqueued_calls_processor', () => {
// tx reports app logic failure
expect(txResult.revertReason).toBe(appLogicFailure);

expect(tailSpy).toHaveBeenCalledTimes(1);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(3);

const outputs = txResult.avmProvingRequest!.inputs.output.accumulatedData;
Expand Down Expand Up @@ -435,23 +422,20 @@ describe('enqueued_calls_processor', () => {
);
}

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(txResult.processedPhases).toHaveLength(3);
expect(txResult.processedPhases[0]).toEqual(expect.objectContaining({ revertReason: undefined }));
expect(txResult.processedPhases[1]).toEqual(expect.objectContaining({ revertReason: appLogicFailure }));
expect(txResult.processedPhases[2]).toEqual(expect.objectContaining({ revertReason: teardownFailure }));
expect(txResult.revertCode).toEqual(RevertCode.BOTH_REVERTED);
//expect(txResult.revertCode).toEqual(RevertCode.BOTH_REVERTED);
// tx reports app logic failure
expect(txResult.revertReason).toBe(appLogicFailure);

expect(tailSpy).toHaveBeenCalledTimes(1);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(3);

const outputs = txResult.avmProvingRequest!.inputs.output.accumulatedData;
const numPublicDataWrites = 3;
const numPublicDataWrites = 3; // 4 are reverted
expect(arrayNonEmptyLength(outputs.publicDataWrites, PublicDataWrite.isEmpty)).toBe(numPublicDataWrites);
expect(outputs.publicDataWrites.slice(0, numPublicDataWrites)).toEqual([
new PublicDataWrite(
Expand Down Expand Up @@ -581,8 +565,6 @@ describe('enqueued_calls_processor', () => {
);
}

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(txResult.processedPhases).toHaveLength(3);
Expand All @@ -596,7 +578,6 @@ describe('enqueued_calls_processor', () => {
});
expect(txResult.revertReason).toBe(undefined);

expect(tailSpy).toHaveBeenCalledTimes(1);
expect(publicExecutor.simulate).toHaveBeenCalledTimes(3);

const expectedSimulateCall = (availableGas: Partial<FieldsOf<Gas>>, txFee: number) => [
Expand All @@ -618,18 +599,19 @@ describe('enqueued_calls_processor', () => {
const output = txResult.avmProvingRequest!.inputs.output;
expect(output.transactionFee.toNumber()).toEqual(expectedTxFee);

const numPublicDataWrites = 3;
const numPublicDataWrites = 6; // 3 if we enable deduplication of writes to same slot
expect(arrayNonEmptyLength(output.accumulatedData.publicDataWrites, PublicDataWrite.isEmpty)).toBe(
numPublicDataWrites,
);
expect(output.accumulatedData.publicDataWrites.slice(0, numPublicDataWrites)).toEqual([
// squashed
// new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotA), fr(0x101)),
// will be overwritten
new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotA), fr(0x101)),
new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotB), fr(0x151)),

new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotA), fr(0x103)),
// squashed
// new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotC), fr(0x201)),
// new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotC), fr(0x102)),
// will be overwritten
new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotC), fr(0x201)),
new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotC), fr(0x102)),
new PublicDataWrite(computePublicDataTreeLeafSlot(contractAddress, contractSlotC), fr(0x152)),
]);

Expand Down Expand Up @@ -664,8 +646,6 @@ describe('enqueued_calls_processor', () => {

publicExecutor.simulate.mockImplementationOnce(() => Promise.resolve(simulatorResults[0]));

const tailSpy = jest.spyOn(publicKernel, 'publicKernelCircuitTail');

const txResult = await processor.process(tx);

expect(txResult.processedPhases).toHaveLength(1);
Expand All @@ -674,7 +654,5 @@ describe('enqueued_calls_processor', () => {
[PublicKernelPhase.TEARDOWN]: teardownGasUsed,
});
expect(txResult.revertReason).toBe(undefined);

expect(tailSpy).toHaveBeenCalledTimes(1);
});
});
Loading

0 comments on commit faabd42

Please sign in to comment.