Skip to content

Commit

Permalink
feat: validate counters (AztecProtocol#6365)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
LeilaWang authored May 15, 2024
1 parent 3d78751 commit 1f28b3a
Show file tree
Hide file tree
Showing 30 changed files with 1,756 additions and 924 deletions.
9 changes: 2 additions & 7 deletions noir-projects/aztec-nr/aztec/src/context/private_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,10 @@ impl ContextInterface for PrivateContext {

impl PrivateContext {
pub fn new(inputs: PrivateContextInputs, args_hash: Field) -> PrivateContext {
let side_effect_counter = inputs.start_side_effect_counter;
let mut min_revertible_side_effect_counter = 0;
if is_empty(inputs.call_context.msg_sender) {
min_revertible_side_effect_counter = side_effect_counter;
}
PrivateContext {
inputs,
side_effect_counter,
min_revertible_side_effect_counter,
side_effect_counter: inputs.start_side_effect_counter + 1,
min_revertible_side_effect_counter: 0,
is_fee_payer: false,
args_hash,
return_hash: 0,
Expand Down
2 changes: 1 addition & 1 deletion noir-projects/aztec-nr/aztec/src/context/public_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl PublicContext {
pub fn new(inputs: PublicContextInputs, args_hash: Field) -> PublicContext {
PublicContext {
inputs,
side_effect_counter: inputs.start_side_effect_counter,
side_effect_counter: inputs.start_side_effect_counter + 1,
args_hash,
return_hash: 0,
nullifier_read_requests: BoundedVec::new(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use dep::types::{
abis::{
call_context::CallContext, call_request::CallRequest, private_call_stack_item::PrivateCallStackItem,
private_kernel::private_call_data::PrivateCallData
private_kernel::private_call_data::PrivateCallData, side_effect::Ordered
},
address::{AztecAddress, PartialAddress}, contract_class_id::ContractClassId,
hash::{private_functions_root_from_siblings, stdlib_recursion_verification_key_compress_native_vk},
Expand Down Expand Up @@ -50,9 +50,92 @@ fn validate_call_request(request: CallRequest, hash: Field, caller: PrivateCallS
}
}

fn validate_call_requests<N>(call_requests: [CallRequest; N], hashes: [Field; N], caller: PrivateCallStackItem) {
fn validate_incrementing_counters_within_range<T, N>(
counter_start: u32,
counter_end: u32,
items: [T; N],
num_items: u64
) where T: Ordered {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
validate_call_request(call_requests[i], hashes[i], caller);
should_check &= i != num_items;
if should_check {
let item = items[i];
assert(
item.counter() > prev_counter, "counter must be larger than the counter of the previous item"
);
prev_counter = item.counter();
}
}
assert(prev_counter < counter_end, "counter must be smaller than the end counter of the call");
}

fn validate_incrementing_counter_ranges_within_range<N>(
counter_start: u32,
counter_end: u32,
items: [CallRequest; N],
num_items: u64
) {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
should_check &= i != num_items;
if should_check {
let item = items[i];
assert(
item.start_side_effect_counter > prev_counter, "start counter must be larger than the end counter of the previous call"
);
assert(
item.end_side_effect_counter > item.start_side_effect_counter, "nested call has incorrect counter range"
);
prev_counter = item.end_side_effect_counter;
}
}
assert(
prev_counter < counter_end, "end counter must be smaller than the end counter of the parent call"
);
}

fn validate_split_private_call_requests<N>(
min_revertible_side_effect_counter: u32,
first_revertible_call_request_index: u64,
call_requests: [CallRequest; N],
num_call_requests: u64
) {
if first_revertible_call_request_index != 0 {
let last_non_revertible_call_request_index = first_revertible_call_request_index - 1;
let call_request = call_requests[last_non_revertible_call_request_index];
assert(
min_revertible_side_effect_counter > call_request.end_side_effect_counter, "min_revertible_side_effect_counter must be greater than the end counter of the last non revertible call"
);
}
if first_revertible_call_request_index != num_call_requests {
let call_request = call_requests[first_revertible_call_request_index];
assert(
min_revertible_side_effect_counter <= call_request.start_side_effect_counter, "min_revertible_side_effect_counter must be less than or equal to the start counter of the first revertible call"
);
}
}

fn validate_split_public_call_requests<N>(
min_revertible_side_effect_counter: u32,
first_revertible_call_request_index: u64,
call_requests: [CallRequest; N],
num_call_requests: u64
) {
if first_revertible_call_request_index != 0 {
let last_non_revertible_call_request_index = first_revertible_call_request_index - 1;
let call_request = call_requests[last_non_revertible_call_request_index];
assert(
min_revertible_side_effect_counter > call_request.counter(), "min_revertible_side_effect_counter must be greater than the counter of the last non revertible call"
);
}
if first_revertible_call_request_index != num_call_requests {
let call_request = call_requests[first_revertible_call_request_index];
assert(
min_revertible_side_effect_counter <= call_request.counter(), "min_revertible_side_effect_counter must be less than or equal to the counter of the first revertible call"
);
}
}

Expand Down Expand Up @@ -86,6 +169,34 @@ impl PrivateCallDataValidator {
self.validate_private_call_requests();
self.validate_public_call_requests();
self.validate_teardown_call_request();
self.validate_counters();
}

pub fn validate_as_first_call(
self,
first_revertible_private_call_request_index: u64,
first_revertible_public_call_request_index: u64
) {
let public_inputs = self.data.call_stack_item.public_inputs;
let call_context = public_inputs.call_context;
assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall");
assert(call_context.is_static_call == false, "Users cannot make a static call");

let min_revertible_side_effect_counter = public_inputs.min_revertible_side_effect_counter;
// No need to check that the min_revertible_side_effect_counter falls in the counter range of the private call.
// It is valid as long as it does not fall in the middle of any nested call.
validate_split_private_call_requests(
min_revertible_side_effect_counter,
first_revertible_private_call_request_index,
self.data.private_call_stack,
self.array_lengths.private_call_stack_hashes
);
validate_split_public_call_requests(
min_revertible_side_effect_counter,
first_revertible_public_call_request_index,
self.data.public_call_stack,
self.array_lengths.public_call_stack_hashes
);
}

// Confirm that the TxRequest (user's intent) matches the private call being executed.
Expand All @@ -103,11 +214,6 @@ impl PrivateCallDataValidator {
assert_eq(
tx_request.tx_context, call_stack_item.public_inputs.tx_context, "tx_context in tx_request must match tx_context in call_stack_item"
);

// If checking against TxRequest, it must be the first call, which has the following restrictions.
let call_context = call_stack_item.public_inputs.call_context;
assert(call_context.is_delegate_call == false, "Users cannot make a delegatecall");
assert(call_context.is_static_call == false, "Users cannot make a static call");
}

pub fn validate_against_call_request(self, request: CallRequest) {
Expand Down Expand Up @@ -205,19 +311,19 @@ impl PrivateCallDataValidator {
}

fn validate_private_call_requests(self) {
validate_call_requests(
self.data.private_call_stack,
self.data.call_stack_item.public_inputs.private_call_stack_hashes,
self.data.call_stack_item
);
let call_requests = self.data.private_call_stack;
let hashes = self.data.call_stack_item.public_inputs.private_call_stack_hashes;
for i in 0..call_requests.len() {
validate_call_request(call_requests[i], hashes[i], self.data.call_stack_item);
}
}

fn validate_public_call_requests(self) {
validate_call_requests(
self.data.public_call_stack,
self.data.call_stack_item.public_inputs.public_call_stack_hashes,
self.data.call_stack_item
);
let call_requests = self.data.public_call_stack;
let hashes = self.data.call_stack_item.public_inputs.public_call_stack_hashes;
for i in 0..call_requests.len() {
validate_call_request(call_requests[i], hashes[i], self.data.call_stack_item);
}
}

fn validate_teardown_call_request(self) {
Expand All @@ -227,4 +333,81 @@ impl PrivateCallDataValidator {
self.data.call_stack_item
);
}

fn validate_counters(self) {
let public_inputs = self.data.call_stack_item.public_inputs;
let counter_start = public_inputs.start_side_effect_counter;
let counter_end = public_inputs.end_side_effect_counter;

assert(counter_start < counter_end, "private call has incorrect counter range");

validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.note_hash_read_requests,
self.array_lengths.note_hash_read_requests
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.nullifier_read_requests,
self.array_lengths.nullifier_read_requests
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_note_hashes,
self.array_lengths.new_note_hashes
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_nullifiers,
self.array_lengths.new_nullifiers
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.new_l2_to_l1_msgs,
self.array_lengths.new_l2_to_l1_msgs
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.encrypted_logs_hashes,
self.array_lengths.encrypted_logs_hashes
);
validate_incrementing_counters_within_range(
counter_start,
counter_end,
public_inputs.unencrypted_logs_hashes,
self.array_lengths.unencrypted_logs_hashes
);
validate_incrementing_counter_ranges_within_range(
counter_start,
counter_end,
self.data.private_call_stack,
self.array_lengths.private_call_stack_hashes
);

// Validate the public call requests by checking their start counters only, as their end counters are unknown.
validate_incrementing_counters_within_range(
counter_start,
counter_end,
self.data.public_call_stack,
self.array_lengths.public_call_stack_hashes
);

let teardown_call_request_count = if self.data.public_teardown_call_request.hash == 0 {
0
} else {
1
};
validate_incrementing_counters_within_range(
counter_start,
counter_end,
[self.data.public_teardown_call_request],
teardown_call_request_count
);
}
}
Loading

0 comments on commit 1f28b3a

Please sign in to comment.