Skip to content

Commit

Permalink
refactor: private reset kernel (#7984)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
LeilaWang authored Aug 15, 2024
1 parent cee9d9a commit 0d82c79
Show file tree
Hide file tree
Showing 64 changed files with 1,603 additions and 1,361 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use dep::protocol_types::{
hash::sha256_to_field, address::AztecAddress, point::Point, abis::note_hash::NoteHash,
constants::MAX_NOTE_HASHES_PER_CALL, utils::arrays::find_index
constants::MAX_NOTE_HASHES_PER_CALL
};

unconstrained fn compute_unconstrained<Note, let N: u32, let NB: u32, let M: u32>(
Expand Down Expand Up @@ -61,13 +61,8 @@ fn emit_with_keys<Note, let N: u32, let NB: u32, let M: u32>(
let note_hash_counter = note_header.note_hash_counter;
let storage_slot = note_header.storage_slot;

let note_exists_index = find_index(
context.note_hashes.storage,
|n: NoteHash| n.counter == note_hash_counter
);
assert(
note_exists_index as u32 != MAX_NOTE_HASHES_PER_CALL, "Can only emit a note log for an existing note."
);
let note_exists = context.note_hashes.storage.any(|n: NoteHash| n.counter == note_hash_counter);
assert(note_exists, "Can only emit a note log for an existing note.");

let contract_address: AztecAddress = context.this_address();
let ovsk_app: Field = context.request_ovsk_app(ovpk.hash());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use dep::types::{
abis::{kernel_circuit_public_inputs::PrivateKernelCircuitPublicInputs, note_hash::ScopedNoteHash},
constants::{MAX_NOTE_HASHES_PER_TX, MAX_NULLIFIERS_PER_TX}, utils::arrays::find_index
constants::{MAX_NOTE_HASHES_PER_TX, MAX_NULLIFIERS_PER_TX}, utils::arrays::find_index_hint
};

struct PreviousKernelValidatorHints {
Expand All @@ -13,7 +13,7 @@ unconstrained pub fn generate_previous_kernel_validator_hints(previous_kernel: P
let nullifiers = previous_kernel.end.nullifiers;
for i in 0..nullifiers.len() {
let nullified_note_hash = nullifiers[i].nullifier.note_hash;
let note_hash_index = find_index(
let note_hash_index = find_index_hint(
note_hashes,
|n: ScopedNoteHash| n.value() == nullified_note_hash
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,9 @@ use dep::types::{
address::{AztecAddress, PartialAddress}, contract_class_id::ContractClassId,
constants::MAX_FIELD_VALUE,
hash::{private_functions_root_from_siblings, stdlib_recursion_verification_key_compress_native_vk},
traits::is_empty, transaction::tx_request::TxRequest, utils::arrays::find_index
traits::is_empty, transaction::tx_request::TxRequest, utils::arrays::find_index_hint
};

unconstrained fn match_log_to_note<N>(
note_log: NoteLogHash,
accumulated_note_hashes: [ScopedNoteHash; N]
) -> u32 {
find_index(
accumulated_note_hashes,
|n: ScopedNoteHash| n.counter() == note_log.note_hash_counter
)
}

unconstrained fn find_first_revertible_private_call_request_index(public_inputs: PrivateCircuitPublicInputs) -> u32 {
find_first_revertible_item_index(
public_inputs.min_revertible_side_effect_counter,
Expand Down Expand Up @@ -60,7 +50,12 @@ fn validate_call_context(
}
}

fn validate_incrementing_counters_within_range<T, N>(counter_start: u32, counter_end: u32, items: [T; N], num_items: u32) where T: Ordered {
fn validate_incrementing_counters_within_range<T, let N: u32>(
counter_start: u32,
counter_end: u32,
items: [T; N],
num_items: u32
) where T: Ordered {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
Expand All @@ -76,12 +71,7 @@ fn validate_incrementing_counters_within_range<T, N>(counter_start: u32, counter
assert(prev_counter < counter_end, "counter must be smaller than the end counter of the call");
}

fn validate_incrementing_counter_ranges_within_range<T, N>(
counter_start: u32,
counter_end: u32,
items: [T; N],
num_items: u32
) where T: RangeOrdered {
fn validate_incrementing_counter_ranges_within_range<T, let N: u32>(counter_start: u32, counter_end: u32, items: [T; N], num_items: u32) where T: RangeOrdered {
let mut prev_counter = counter_start;
let mut should_check = true;
for i in 0..N {
Expand Down Expand Up @@ -111,7 +101,7 @@ impl PrivateCallDataValidator {
PrivateCallDataValidator { data, array_lengths }
}

pub fn validate<N>(self, accumulated_note_hashes: [ScopedNoteHash; N]) {
pub fn validate<let N: u32>(self, accumulated_note_hashes: [ScopedNoteHash; N]) {
self.validate_contract_address();
self.validate_call();
self.validate_private_call_requests();
Expand All @@ -138,7 +128,7 @@ impl PrivateCallDataValidator {
tx_request.origin, call_stack_item.contract_address, "origin address does not match call stack items contract address"
);
assert_eq(
tx_request.function_data.hash(), call_stack_item.function_data.hash(), "tx_request function_data must match call_stack_item function_data"
tx_request.function_data, call_stack_item.function_data, "tx_request function_data must match call_stack_item function_data"
);
assert_eq(
tx_request.args_hash, call_stack_item.public_inputs.args_hash, "noir function args passed to tx_request must match args in the call_stack_item"
Expand Down Expand Up @@ -394,18 +384,22 @@ impl PrivateCallDataValidator {
);
}

fn validate_note_logs<N>(self, accumulated_note_hashes: [ScopedNoteHash; N]) {
fn validate_note_logs<let N: u32>(self, accumulated_note_hashes: [ScopedNoteHash; N]) {
let note_logs = self.data.call_stack_item.public_inputs.note_encrypted_logs_hashes;
let num_logs = self.array_lengths.note_encrypted_logs_hashes;
let storage_contract_address = self.data.call_stack_item.public_inputs.call_context.storage_contract_address;
let mut should_check = true;
for i in 0..note_logs.len() {
should_check &= i != num_logs;
if should_check {
let note_index = match_log_to_note(note_logs[i], accumulated_note_hashes);
let note_log = note_logs[i];
let note_index = find_index_hint(
accumulated_note_hashes,
|n: ScopedNoteHash| n.counter() == note_log.note_hash_counter
);
assert(note_index != N, "could not find note hash linked to note log");
assert_eq(
note_logs[i].note_hash_counter, accumulated_note_hashes[note_index].counter(), "could not find note hash linked to note log"
note_log.note_hash_counter, accumulated_note_hashes[note_index].counter(), "could not find note hash linked to note log"
);
// If the note_index points to an empty note hash, the following check will fail.
assert_eq(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
mod reset_output_hints;

use crate::components::reset_output_composer::{reset_output_hints::{generate_reset_output_hints, ResetOutputHints}};
use dep::reset_kernel_lib::{
KeyValidationHint, NoteHashReadRequestHints, NullifierReadRequestHints, TransientDataIndexHint,
PrivateValidationRequestProcessor
};
use dep::types::{
abis::{
kernel_circuit_public_inputs::PrivateKernelCircuitPublicInputs,
log_hash::{NoteLogHash, ScopedEncryptedLogHash}, note_hash::ScopedNoteHash,
nullifier::ScopedNullifier
nullifier::ScopedNullifier, validation_requests::PrivateValidationRequests
},
address::AztecAddress,
constants::{
Expand All @@ -15,64 +19,83 @@ use dep::types::{
hash::{mask_encrypted_log_hash, silo_note_hash, silo_nullifier}, utils::arrays::sort_by_counters_asc
};

struct PrivateKernelResetOutputs {
note_hashes: [ScopedNoteHash; MAX_NOTE_HASHES_PER_TX],
nullifiers: [ScopedNullifier; MAX_NULLIFIERS_PER_TX],
note_encrypted_log_hashes: [NoteLogHash; MAX_NOTE_ENCRYPTED_LOGS_PER_TX],
encrypted_log_hashes: [ScopedEncryptedLogHash; MAX_ENCRYPTED_LOGS_PER_TX],
}

struct ResetOutputComposer {
struct ResetOutputComposer<
let NH_RR_PENDING: u32,
let NH_RR_SETTLED: u32,
let NLL_RR_PENDING: u32,
let NLL_RR_SETTLED: u32,
let KEY_VALIDATION_REQUESTS: u32,
> {
previous_kernel: PrivateKernelCircuitPublicInputs,
validation_request_processor: PrivateValidationRequestProcessor<NH_RR_PENDING, NH_RR_SETTLED, NLL_RR_PENDING, NLL_RR_SETTLED, KEY_VALIDATION_REQUESTS>,
note_hash_siloing_amount: u32,
nullifier_siloing_amount: u32,
encrypted_log_siloing_amount: u32,
hints: ResetOutputHints,
}

impl ResetOutputComposer {
pub fn new(
impl<
let NH_RR_PENDING: u32,
let NH_RR_SETTLED: u32,
let NLL_RR_PENDING: u32,
let NLL_RR_SETTLED: u32,
let KEY_VALIDATION_REQUESTS: u32,
> ResetOutputComposer<
NH_RR_PENDING,
NH_RR_SETTLED,
NLL_RR_PENDING,
NLL_RR_SETTLED,
KEY_VALIDATION_REQUESTS,
> {
pub fn new<let TRANSIENT_DATA_AMOUNT: u32>(
previous_kernel: PrivateKernelCircuitPublicInputs,
transient_nullifier_indexes_for_note_hashes: [u32; MAX_NOTE_HASHES_PER_TX],
transient_note_hash_indexes_for_nullifiers: [u32; MAX_NULLIFIERS_PER_TX],
validation_request_processor: PrivateValidationRequestProcessor<NH_RR_PENDING, NH_RR_SETTLED, NLL_RR_PENDING, NLL_RR_SETTLED, KEY_VALIDATION_REQUESTS>,
transient_data_index_hints: [TransientDataIndexHint; TRANSIENT_DATA_AMOUNT],
note_hash_siloing_amount: u32,
nullifier_siloing_amount: u32,
encrypted_log_siloing_amount: u32
) -> Self {
let hints = generate_reset_output_hints(
let hints = generate_reset_output_hints(previous_kernel, transient_data_index_hints);
ResetOutputComposer {
previous_kernel,
transient_nullifier_indexes_for_note_hashes,
transient_note_hash_indexes_for_nullifiers
);
ResetOutputComposer { previous_kernel, note_hash_siloing_amount, nullifier_siloing_amount, encrypted_log_siloing_amount, hints }
validation_request_processor,
note_hash_siloing_amount,
nullifier_siloing_amount,
encrypted_log_siloing_amount,
hints
}
}

pub fn finish(self) -> PrivateKernelResetOutputs {
let note_hashes = if self.note_hash_siloing_amount == 0 {
pub fn finish(self) -> PrivateKernelCircuitPublicInputs {
let mut output = self.previous_kernel;

output.validation_requests = self.validation_request_processor.compose();

output.end.note_hashes = if self.note_hash_siloing_amount == 0 {
self.hints.kept_note_hashes
} else {
self.get_sorted_siloed_note_hashes()
};

let nullifiers = if self.nullifier_siloing_amount == 0 {
output.end.nullifiers = if self.nullifier_siloing_amount == 0 {
self.hints.kept_nullifiers
} else {
self.get_sorted_siloed_nullifiers()
};

let note_encrypted_log_hashes = if self.note_hash_siloing_amount == 0 {
output.end.note_encrypted_logs_hashes = if self.note_hash_siloing_amount == 0 {
self.hints.kept_note_encrypted_log_hashes
} else {
self.get_sorted_note_encrypted_log_hashes()
};

let encrypted_log_hashes = if self.encrypted_log_siloing_amount == 0 {
output.end.encrypted_logs_hashes = if self.encrypted_log_siloing_amount == 0 {
self.previous_kernel.end.encrypted_logs_hashes
} else {
self.get_sorted_masked_encrypted_log_hashes()
};

PrivateKernelResetOutputs { note_hashes, nullifiers, note_encrypted_log_hashes, encrypted_log_hashes }
output
}

fn get_sorted_siloed_note_hashes(self) -> [ScopedNoteHash; MAX_NOTE_HASHES_PER_TX] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::components::reset_output_composer::reset_output_hints::{
get_transient_or_propagated_note_hash_indexes_for_logs::get_transient_or_propagated_note_hash_indexes_for_logs,
squash_transient_data::squash_transient_data
};
use dep::reset_kernel_lib::TransientDataIndexHint;
use dep::types::{
abis::{
kernel_circuit_public_inputs::PrivateKernelCircuitPublicInputs, log_hash::NoteLogHash,
Expand Down Expand Up @@ -32,17 +33,15 @@ struct ResetOutputHints {
sorted_encrypted_log_hash_indexes: [u32; MAX_ENCRYPTED_LOGS_PER_TX],
}

pub fn generate_reset_output_hints(
unconstrained pub fn generate_reset_output_hints<let NUM_TRANSIENT_DATA_INDEX_HINTS: u32>(
previous_kernel: PrivateKernelCircuitPublicInputs,
transient_nullifier_indexes_for_note_hashes: [u32; MAX_NOTE_HASHES_PER_TX],
transient_note_hash_indexes_for_nullifiers: [u32; MAX_NULLIFIERS_PER_TX]
transient_data_index_hints: [TransientDataIndexHint; NUM_TRANSIENT_DATA_INDEX_HINTS]
) -> ResetOutputHints {
let (kept_note_hashes, kept_nullifiers, kept_note_encrypted_log_hashes) = squash_transient_data(
previous_kernel.end.note_hashes,
previous_kernel.end.nullifiers,
previous_kernel.end.note_encrypted_logs_hashes,
transient_nullifier_indexes_for_note_hashes,
transient_note_hash_indexes_for_nullifiers
transient_data_index_hints
);

// note_hashes
Expand All @@ -56,7 +55,8 @@ pub fn generate_reset_output_hints(
let transient_or_propagated_note_hash_indexes_for_logs = get_transient_or_propagated_note_hash_indexes_for_logs(
previous_kernel.end.note_encrypted_logs_hashes,
previous_kernel.end.note_hashes,
kept_note_hashes
kept_note_hashes,
transient_data_index_hints
);

// encrypted_log_hashes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use dep::types::{abis::{log_hash::NoteLogHash, note_hash::ScopedNoteHash}};
use dep::reset_kernel_lib::TransientDataIndexHint;
use dep::types::{abis::{log_hash::NoteLogHash, note_hash::ScopedNoteHash}, utils::arrays::find_index_hint};

pub fn get_transient_or_propagated_note_hash_indexes_for_logs<let NUM_LOGS: u32, let NUM_NOTE_HASHES: u32>(
unconstrained pub fn get_transient_or_propagated_note_hash_indexes_for_logs<let NUM_LOGS: u32, let NUM_NOTE_HASHES: u32, let NUM_INDEX_HINTS: u32>(
note_logs: [NoteLogHash; NUM_LOGS],
note_hashes: [ScopedNoteHash; NUM_NOTE_HASHES],
expected_note_hashes: [ScopedNoteHash; NUM_NOTE_HASHES]
expected_note_hashes: [ScopedNoteHash; NUM_NOTE_HASHES],
transient_data_index_hints: [TransientDataIndexHint; NUM_INDEX_HINTS]
) -> [u32; NUM_LOGS] {
let mut indexes = [0; NUM_LOGS];
for i in 0..note_logs.len() {
Expand All @@ -18,7 +20,7 @@ pub fn get_transient_or_propagated_note_hash_indexes_for_logs<let NUM_LOGS: u32,
if !propagated {
for j in 0..note_hashes.len() {
if note_hashes[j].counter() == log_note_hash_counter {
indexes[i] = j;
indexes[i] = find_index_hint(transient_data_index_hints, |hint: TransientDataIndexHint| hint.note_hash_index == j);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use dep::reset_kernel_lib::TransientDataIndexHint;
use dep::types::abis::{note_hash::ScopedNoteHash, nullifier::ScopedNullifier, log_hash::NoteLogHash};

unconstrained pub fn squash_transient_data<let M: u32, let N: u32, let P: u32>(
unconstrained pub fn squash_transient_data<let M: u32, let N: u32, let P: u32, let NUM_TRANSIENT_DATA_INDEX_HINTS: u32>(
note_hashes: [ScopedNoteHash; M],
nullifiers: [ScopedNullifier; N],
logs: [NoteLogHash; P],
transient_nullifier_indexes_for_note_hashes: [u32; M],
transient_note_hash_indexes_for_nullifiers: [u32; N]
transient_data_index_hints: [TransientDataIndexHint; NUM_TRANSIENT_DATA_INDEX_HINTS]
) -> ([ScopedNoteHash; M], [ScopedNullifier; N], [NoteLogHash; P]) {
let mut transient_nullifier_indexes_for_note_hashes = [N; M];
let mut transient_note_hash_indexes_for_nullifiers = [M; N];
for i in 0..transient_data_index_hints.len() {
let hint = transient_data_index_hints[i];
if hint.note_hash_index != M {
transient_nullifier_indexes_for_note_hashes[hint.note_hash_index] = hint.nullifier_index;
transient_note_hash_indexes_for_nullifiers[hint.nullifier_index] = hint.note_hash_index;
}
}

let mut propagated_note_hashes = BoundedVec::new();
for i in 0..note_hashes.len() {
if transient_nullifier_indexes_for_note_hashes[i] == N {
Expand Down
Loading

0 comments on commit 0d82c79

Please sign in to comment.