Skip to content

Commit

Permalink
feat: add limit to unique contract call
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Dec 11, 2024
1 parent 9b26651 commit 8ed7693
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 17 deletions.
8 changes: 8 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ void AvmTraceBuilder::rollback_to_non_revertible_checkpoint()

std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership)
{
// The cache contains all the unique contract class ids we have seen so far
if (contract_class_id_cache.size() >= AVM_MAX_UNIQUE_CONTRACT_CALLS) {
// Right now we have no way of communicating this to the circuit since we don't currently lay down rows
// for these operations
// error = AvmError::SIDE_EFFECT_LIMIT_REACHED;
throw std::runtime_error("Max unique contract call limit reached");
}
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;

// Find the bytecode based on contract address of the public call request
Expand Down Expand Up @@ -192,6 +199,7 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
// Assert that the hint's exists flag matches. The flag isn't really necessary...
ASSERT(bytecode_hint.contract_instance.exists);
bytecode_membership_cache.insert(contract_address);
contract_class_id_cache.insert(bytecode_hint.contract_instance.contract_class_id);
} else {
// This was a non-membership proof!
// Enforce that the tree access membership checked a low-leaf that skips the contract address nullifier.
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ class AvmTraceBuilder {
void checkpoint_non_revertible_state();
void rollback_to_non_revertible_checkpoint();
std::vector<uint8_t> get_bytecode(const FF contract_address, bool check_membership = false);
// Used to track the unique class ids, could also be used to cache membership checks of class ids
std::unordered_set<FF> contract_class_id_cache;
std::unordered_set<FF> bytecode_membership_cache;
void insert_private_state(const std::vector<FF>& siloed_nullifiers, const std::vector<FF>& siloed_note_hashes);

Expand Down
1 change: 1 addition & 0 deletions barretenberg/cpp/src/barretenberg/vm/aztec_constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#define AVM_ACCUMULATED_DATA_LENGTH 318
#define AVM_CIRCUIT_PUBLIC_INPUTS_LENGTH 1006
#define AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS 86
#define AVM_MAX_UNIQUE_CONTRACT_CALLS 21
#define AVM_PROOF_LENGTH_IN_FIELDS 4155
#define AVM_PUBLIC_COLUMN_MAX_SIZE 1024
#define AVM_PUBLIC_INPUTS_FLATTENED_SIZE 2915
Expand Down
1 change: 1 addition & 0 deletions l1-contracts/src/core/libraries/ConstantsGen.sol
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ library Constants {
uint256 internal constant TUBE_PROOF_LENGTH = 463;
uint256 internal constant HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128;
uint256 internal constant CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143;
uint256 internal constant MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000;
uint256 internal constant MEM_TAG_FF = 0;
uint256 internal constant MEM_TAG_U1 = 1;
uint256 internal constant MEM_TAG_U8 = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,13 @@ pub global CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 143; // size of a
// 21 above refers to the constant AvmFlavor::NUM_PRECOMPUTED_ENTITIES
pub global AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS: u32 = 2 + 21 * 4;

// Setting limits for AVM_MAX_UNIQUE_CONTRACT_CALLS
// This value is determined by the length of the AVM trace and the MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES
// (i.e. 2^21 / MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES ==> 2^21 / 96,000 = 21
pub global MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES: u32 = MAX_PACKED_PUBLIC_BYTECODE_SIZE_IN_FIELDS * 32;
pub global AVM_MAX_UNIQUE_CONTRACT_CALLS: u32 = 21;


// `AVM_PROOF_LENGTH_IN_FIELDS` must be updated when AVM circuit changes.
// To determine latest value, hover `COMPUTED_AVM_PROOF_LENGTH_IN_FIELDS`
// in barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.hpp
Expand Down
2 changes: 2 additions & 0 deletions yarn-project/circuits.js/src/constants.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ export const TUBE_PROOF_LENGTH = 463;
export const HONK_VERIFICATION_KEY_LENGTH_IN_FIELDS = 128;
export const CLIENT_IVC_VERIFICATION_KEY_LENGTH_IN_FIELDS = 143;
export const AVM_VERIFICATION_KEY_LENGTH_IN_FIELDS = 86;
export const MAX_PUBLIC_BYTECODE_SIZE_IN_BYTES = 96000;
export const AVM_MAX_UNIQUE_CONTRACT_CALLS = 21;
export const AVM_PROOF_LENGTH_IN_FIELDS = 4155;
export const AVM_PUBLIC_COLUMN_MAX_SIZE = 1024;
export const AVM_PUBLIC_INPUTS_FLATTENED_SIZE = 2915;
Expand Down
1 change: 1 addition & 0 deletions yarn-project/circuits.js/src/scripts/constants.in.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ const CPP_CONSTANTS = [
'MULTI_CALL_ENTRYPOINT_ADDRESS',
'FEE_JUICE_ADDRESS',
'ROUTER_ADDRESS',
'AVM_MAX_UNIQUE_CONTRACT_CALLS',
];

const CPP_GENERATORS: string[] = [
Expand Down
23 changes: 15 additions & 8 deletions yarn-project/circuits.js/src/structs/avm/avm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ export class AvmExecutionHints {
public readonly enqueuedCalls: Vector<AvmEnqueuedCallHint>;

public readonly contractInstances: Vector<AvmContractInstanceHint>;
public readonly contractBytecodeHints: Vector<AvmContractBytecodeHints>;

public readonly publicDataReads: Vector<AvmPublicDataReadTreeHint>;
public readonly publicDataWrites: Vector<AvmPublicDataWriteTreeHint>;
Expand All @@ -696,7 +695,8 @@ export class AvmExecutionHints {
constructor(
enqueuedCalls: AvmEnqueuedCallHint[],
contractInstances: AvmContractInstanceHint[],
contractBytecodeHints: AvmContractBytecodeHints[],
// string here is the contract class id
public contractBytecodeHints: Map<string, AvmContractBytecodeHints>,
publicDataReads: AvmPublicDataReadTreeHint[],
publicDataWrites: AvmPublicDataWriteTreeHint[],
nullifierReads: AvmNullifierReadTreeHint[],
Expand All @@ -707,7 +707,6 @@ export class AvmExecutionHints {
) {
this.enqueuedCalls = new Vector(enqueuedCalls);
this.contractInstances = new Vector(contractInstances);
this.contractBytecodeHints = new Vector(contractBytecodeHints);
this.publicDataReads = new Vector(publicDataReads);
this.publicDataWrites = new Vector(publicDataWrites);
this.nullifierReads = new Vector(nullifierReads);
Expand All @@ -722,7 +721,7 @@ export class AvmExecutionHints {
* @returns an empty instance.
*/
static empty() {
return new AvmExecutionHints([], [], [], [], [], [], [], [], [], []);
return new AvmExecutionHints([], [], new Map(), [], [], [], [], [], [], []);
}

/**
Expand All @@ -749,7 +748,7 @@ export class AvmExecutionHints {
return (
this.enqueuedCalls.items.length == 0 &&
this.contractInstances.items.length == 0 &&
this.contractBytecodeHints.items.length == 0 &&
this.contractBytecodeHints.size == 0 &&
this.publicDataReads.items.length == 0 &&
this.publicDataWrites.items.length == 0 &&
this.nullifierReads.items.length == 0 &&
Expand All @@ -769,7 +768,7 @@ export class AvmExecutionHints {
return new AvmExecutionHints(
fields.enqueuedCalls.items,
fields.contractInstances.items,
fields.contractBytecodeHints.items,
fields.contractBytecodeHints,
fields.publicDataReads.items,
fields.publicDataWrites.items,
fields.nullifierReads.items,
Expand All @@ -789,7 +788,7 @@ export class AvmExecutionHints {
return [
fields.enqueuedCalls,
fields.contractInstances,
fields.contractBytecodeHints,
new Vector(Array.from(fields.contractBytecodeHints.values())),
fields.publicDataReads,
fields.publicDataWrites,
fields.nullifierReads,
Expand All @@ -807,10 +806,18 @@ export class AvmExecutionHints {
*/
static fromBuffer(buff: Buffer | BufferReader): AvmExecutionHints {
const reader = BufferReader.asReader(buff);
const readMap = (reader: BufferReader) => {
const map = new Map();
const values = reader.readVector(AvmContractBytecodeHints);
for (const value of values) {
map.set(value.contractInstanceHint.address.toString(), value);
}
return map;
};
return new AvmExecutionHints(
reader.readVector(AvmEnqueuedCallHint),
reader.readVector(AvmContractInstanceHint),
reader.readVector(AvmContractBytecodeHints),
readMap(reader),
reader.readVector(AvmPublicDataReadTreeHint),
reader.readVector(AvmPublicDataWriteTreeHint),
reader.readVector(AvmNullifierReadTreeHint),
Expand Down
6 changes: 5 additions & 1 deletion yarn-project/circuits.js/src/tests/factories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,10 @@ export function makeVector<T extends Bufferable>(length: number, fn: (i: number)
return new Vector(makeArray(length, fn, offset));
}

export function makeMap<T extends Bufferable>(size: number, fn: (i: number) => [string, T], offset = 0) {
return new Map(makeArray(size, i => fn(i + offset)));
}

export function makeContractInstanceFromClassId(classId: Fr, seed = 0): ContractInstanceWithAddress {
const salt = new Fr(seed);
const initializationHash = new Fr(seed + 1);
Expand Down Expand Up @@ -1368,7 +1372,7 @@ export function makeAvmExecutionHints(
return AvmExecutionHints.from({
enqueuedCalls: makeVector(baseLength, makeAvmEnqueuedCallHint, seed + 0x4100),
contractInstances: makeVector(baseLength + 5, makeAvmContractInstanceHint, seed + 0x4700),
contractBytecodeHints: makeVector(baseLength + 6, makeAvmBytecodeHints, seed + 0x4800),
contractBytecodeHints: makeMap(baseLength + 7, i => [i.toString(), makeAvmBytecodeHints(i)], seed + 0x4900),
publicDataReads: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900),
publicDataWrites: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00),
nullifierReads: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00),
Expand Down
4 changes: 1 addition & 3 deletions yarn-project/package.common.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
"testTimeout": 30000,
"testRegex": "./src/.*\\.test\\.(js|mjs|ts)$",
"rootDir": "./src",
"setupFiles": [
"../../foundation/src/jest/setup.mjs"
]
"setupFiles": ["../../foundation/src/jest/setup.mjs"]
}
}
13 changes: 13 additions & 0 deletions yarn-project/simulator/src/public/bytecode_errors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
export class UniqueContractCallsLimitReachedError extends Error {
constructor(limit: number) {
super(`Reached the limit on number of unique contract calss per tx: ${limit}`);
this.name = 'UniqueContractCallsLimitReachedError';
}
}

export class ContractClassBytecodeError extends Error {
constructor(contractAddress: string) {
super(`Failed to get bytecode for contract at address ${contractAddress}`);
this.name = 'ContractClassBytecodeError';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ describe('Enqueued-call Side Effect Trace', () => {
);

const membershipHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath);
expect(trace.getAvmCircuitHints().contractBytecodeHints.items).toEqual([
expect(trace.getAvmCircuitHints().contractBytecodeHints.size).toEqual([
{
bytecode,
contractInstanceHint: { address, exists, ...instanceWithoutVersion, membershipHint: { ...membershipHint } },
Expand Down Expand Up @@ -352,7 +352,7 @@ describe('Enqueued-call Side Effect Trace', () => {
const childHints = nestedTrace.getAvmCircuitHints();
expect(parentHints.enqueuedCalls.items).toEqual(childHints.enqueuedCalls.items);
expect(parentHints.contractInstances.items).toEqual(childHints.contractInstances.items);
expect(parentHints.contractBytecodeHints.items).toEqual(childHints.contractBytecodeHints.items);
expect(parentHints.contractBytecodeHints).toEqual(childHints.contractBytecodeHints);
expect(parentHints.publicDataReads.items).toEqual(childHints.publicDataReads.items);
expect(parentHints.publicDataWrites.items).toEqual(childHints.publicDataWrites.items);
expect(parentHints.nullifierReads.items).toEqual(childHints.nullifierReads.items);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { UnencryptedFunctionL2Logs, UnencryptedL2Log } from '@aztec/circuit-types';
import {
AVM_MAX_UNIQUE_CONTRACT_CALLS,
AvmAccumulatedData,
AvmAppendTreeHint,
AvmCircuitPublicInputs,
Expand Down Expand Up @@ -54,6 +55,7 @@ import { strict as assert } from 'assert';

import { type AvmFinalizedCallResult } from '../avm/avm_contract_call_result.js';
import { type AvmExecutionEnvironment } from '../avm/avm_execution_environment.js';
import { UniqueContractCallsLimitReachedError } from './bytecode_errors.js';
import { type EnqueuedPublicCallExecutionResultWithSideEffects, type PublicFunctionCallResult } from './execution.js';
import { SideEffectLimitReachedError } from './side_effect_errors.js';
import { type PublicSideEffectTraceInterface } from './side_effect_trace_interface.js';
Expand Down Expand Up @@ -124,6 +126,8 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
* otherwise the public kernel can fail to prove because TX limits are breached.
*/
private readonly previousSideEffectArrayLengths: SideEffectArrayLengths = SideEffectArrayLengths.empty(),
/** We need to thread through the previous bytecode hints maps */
private readonly previousBytecodeHints: Map<string, AvmContractBytecodeHints> = new Map(),
) {
this.log.debug(`Creating trace instance with startSideEffectCounter: ${startSideEffectCounter}`);
this.sideEffectCounter = startSideEffectCounter;
Expand All @@ -140,6 +144,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
this.previousSideEffectArrayLengths.l2ToL1Msgs + this.l2ToL1Messages.length,
this.previousSideEffectArrayLengths.unencryptedLogs + this.unencryptedLogs.length,
),
this.avmCircuitHints.contractBytecodeHints,
);
}

Expand Down Expand Up @@ -169,7 +174,13 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
this.avmCircuitHints.enqueuedCalls.items.push(...forkedTrace.avmCircuitHints.enqueuedCalls.items);

this.avmCircuitHints.contractInstances.items.push(...forkedTrace.avmCircuitHints.contractInstances.items);
this.avmCircuitHints.contractBytecodeHints.items.push(...forkedTrace.avmCircuitHints.contractBytecodeHints.items);
// We want to merge the bytecode hints from forked, but we dont want to overwrite the existing ones (as they
// contain already existing membership checks against earlier state roots)
for (let [contractClassId, bytecodeHint] of forkedTrace.avmCircuitHints.contractBytecodeHints) {
if (!this.avmCircuitHints.contractBytecodeHints.has(contractClassId)) {
this.avmCircuitHints.contractBytecodeHints.set(contractClassId, bytecodeHint);
}
}

this.avmCircuitHints.publicDataReads.items.push(...forkedTrace.avmCircuitHints.publicDataReads.items);
this.avmCircuitHints.publicDataWrites.items.push(...forkedTrace.avmCircuitHints.publicDataWrites.items);
Expand Down Expand Up @@ -388,6 +399,25 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
lowLeafIndex: Fr = Fr.zero(),
lowLeafPath: Fr[] = emptyNullifierPath(),
) {
// We have hinted this bytecode, we do nothing.
if (
this.previousBytecodeHints.has(contractInstance.contractClassId.toString()) ||
this.avmCircuitHints.contractBytecodeHints.has(contractInstance.contractClassId.toString())
) {
this.log.debug(
`Contract class id ${contractInstance.contractClassId.toString()} already exists in previous hints`,
);
return;
}

// Before adding this hint, check that we won't be exceeding the MAX_UNIQUE_CONTRACTS_PER_TX
if (
this.previousBytecodeHints.size + this.avmCircuitHints.contractBytecodeHints.size >=
AVM_MAX_UNIQUE_CONTRACT_CALLS
) {
throw new UniqueContractCallsLimitReachedError(AVM_MAX_UNIQUE_CONTRACT_CALLS);
}

const membershipHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath);
const instance = new AvmContractInstanceHint(
contractAddress,
Expand All @@ -399,8 +429,9 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI
contractInstance.publicKeys,
membershipHint,
);
// We need to deduplicate the contract instances based on addresses
this.avmCircuitHints.contractBytecodeHints.items.push(
// We need to deduplicate the contract instances based on contract class id
this.avmCircuitHints.contractBytecodeHints.set(
contractInstance.contractClassId.toString(),
new AvmContractBytecodeHints(bytecode, instance, contractClass),
);
this.log.debug(
Expand Down

0 comments on commit 8ed7693

Please sign in to comment.