Skip to content

Commit

Permalink
perf: optimize ExecutionReport and create_alu_lookup_id iff mode …
Browse files Browse the repository at this point in the history
…is trace (#1480)
  • Loading branch information
tqn authored Sep 6, 2024
1 parent 770bf0d commit b98a30f
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 120 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/core/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ hex = "0.4.3"
bytemuck = "1.16.3"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
vec_map = { version = "0.8.2", features = ["serde"] }
enum-map = { version = "2.7.3", features = ["serde"] }

[dev-dependencies]
sp1-zkvm = { workspace = true }
Expand Down
20 changes: 8 additions & 12 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,23 +696,19 @@ impl<'a> Executor<'a> {
if self.executor_mode == ExecutorMode::Trace {
self.memory_accesses = MemoryAccessRecord::default();
}
let lookup_id = if self.executor_mode == ExecutorMode::Simple {
LookupId::default()
} else {
let lookup_id = if self.executor_mode == ExecutorMode::Trace {
create_alu_lookup_id()
};
let syscall_lookup_id = if self.executor_mode == ExecutorMode::Simple {
LookupId::default()
} else {
LookupId::default()
};
let syscall_lookup_id = if self.executor_mode == ExecutorMode::Trace {
create_alu_lookup_id()
} else {
LookupId::default()
};

if self.print_report && !self.unconstrained {
self.report
.opcode_counts
.entry(instruction.opcode)
.and_modify(|c| *c += 1)
.or_insert(1);
self.report.opcode_counts[instruction.opcode] += 1;
}

match instruction.opcode {
Expand Down Expand Up @@ -930,7 +926,7 @@ impl<'a> Executor<'a> {
let syscall = SyscallCode::from_u32(syscall_id);

if self.print_report && !self.unconstrained {
self.report.syscall_counts.entry(syscall).and_modify(|c| *c += 1).or_insert(1);
self.report.syscall_counts[syscall] += 1;
}

// `hint_slice` is allowed in unconstrained mode since it is used to write the hint.
Expand Down
5 changes: 4 additions & 1 deletion crates/core/executor/src/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::fmt::Display;

use enum_map::Enum;
use p3_field::Field;
use serde::{Deserialize, Serialize};

Expand All @@ -20,7 +21,9 @@ use serde::{Deserialize, Serialize};
/// Refer to the "RV32I Reference Card" [here](https://github.com/johnwinans/rvalp/releases) for
/// more details.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Enum,
)]
pub enum Opcode {
/// rd ← rs1 + rs2, pc ← pc + 4
ADD = 0,
Expand Down
27 changes: 12 additions & 15 deletions crates/core/executor/src/report.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use std::{
collections::{hash_map::Entry, HashMap},
fmt::{Display, Formatter, Result as FmtResult},
hash::Hash,
ops::{Add, AddAssign},
};

use enum_map::{EnumArray, EnumMap};
use hashbrown::HashMap;

use crate::{events::sorted_table_lines, syscalls::SyscallCode, Opcode};

/// An execution report.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ExecutionReport {
/// The opcode counts.
pub opcode_counts: HashMap<Opcode, u64>,
pub opcode_counts: Box<EnumMap<Opcode, u64>>,
/// The syscall counts.
pub syscall_counts: HashMap<SyscallCode, u64>,
pub syscall_counts: Box<EnumMap<SyscallCode, u64>>,
/// The cycle tracker counts.
pub cycle_tracker: HashMap<String, u64>,
/// The unique memory address counts.
Expand All @@ -35,24 +36,20 @@ impl ExecutionReport {
}

/// Combines two count maps together by adding each value in `rhs` to the value stored under the same key in `lhs`.
fn hashmap_add_assign<K, V>(lhs: &mut HashMap<K, V>, rhs: HashMap<K, V>)
fn counts_add_assign<K, V>(lhs: &mut EnumMap<K, V>, rhs: EnumMap<K, V>)
where
K: Eq + Hash,
K: EnumArray<V>,
V: AddAssign,
{
for (k, v) in rhs {
// Can't use `.and_modify(...).or_insert(...)` because we want to use `v` in both places.
match lhs.entry(k) {
Entry::Occupied(e) => *e.into_mut() += v,
Entry::Vacant(e) => drop(e.insert(v)),
}
lhs[k] += v;
}
}

impl AddAssign for ExecutionReport {
fn add_assign(&mut self, rhs: Self) {
hashmap_add_assign(&mut self.opcode_counts, rhs.opcode_counts);
hashmap_add_assign(&mut self.syscall_counts, rhs.syscall_counts);
counts_add_assign(&mut self.opcode_counts, *rhs.opcode_counts);
counts_add_assign(&mut self.syscall_counts, *rhs.syscall_counts);
self.touched_memory_addresses += rhs.touched_memory_addresses;
}
}
Expand All @@ -69,12 +66,12 @@ impl Add for ExecutionReport {
impl Display for ExecutionReport {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
writeln!(f, "opcode counts ({} total instructions):", self.total_instruction_count())?;
for line in sorted_table_lines(&self.opcode_counts) {
for line in sorted_table_lines(self.opcode_counts.as_ref()) {
writeln!(f, " {line}")?;
}

writeln!(f, "syscall counts ({} total syscall instructions):", self.total_syscall_count())?;
for line in sorted_table_lines(&self.syscall_counts) {
for line in sorted_table_lines(self.syscall_counts.as_ref()) {
writeln!(f, " {line}")?;
}

Expand Down
3 changes: 2 additions & 1 deletion crates/core/executor/src/syscalls/code.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use enum_map::Enum;
use serde::{Deserialize, Serialize};
use strum_macros::EnumIter;

Expand All @@ -18,7 +19,7 @@ use strum_macros::EnumIter;
/// memory accesses is bounded.
/// - Byte 3: Currently unused.
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize,
Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize, Enum,
)]
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
Expand Down
100 changes: 41 additions & 59 deletions crates/core/machine/src/riscv/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,151 +33,133 @@ impl CostEstimator for ExecutionReport {
total_area += (cpu_events as u64) * costs[&RiscvAirDiscriminants::Cpu];
total_chips += 1;

let sha_extend_events = *self.syscall_counts.get(&SyscallCode::SHA_EXTEND).unwrap_or(&0);
let sha_extend_events = self.syscall_counts[SyscallCode::SHA_EXTEND];
total_area += (sha_extend_events as u64) * costs[&RiscvAirDiscriminants::Sha256Extend];
total_chips += 1;

let sha_compress_events =
*self.syscall_counts.get(&SyscallCode::SHA_COMPRESS).unwrap_or(&0);
let sha_compress_events = self.syscall_counts[SyscallCode::SHA_COMPRESS];
total_area += (sha_compress_events as u64) * costs[&RiscvAirDiscriminants::Sha256Compress];
total_chips += 1;

let ed_add_events = *self.syscall_counts.get(&SyscallCode::ED_ADD).unwrap_or(&0);
let ed_add_events = self.syscall_counts[SyscallCode::ED_ADD];
total_area += (ed_add_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Add];
total_chips += 1;

let ed_decompress_events =
*self.syscall_counts.get(&SyscallCode::ED_DECOMPRESS).unwrap_or(&0);
let ed_decompress_events = self.syscall_counts[SyscallCode::ED_DECOMPRESS];
total_area +=
(ed_decompress_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Decompress];
total_chips += 1;

let k256_decompress_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_DECOMPRESS).unwrap_or(&0);
let k256_decompress_events = self.syscall_counts[SyscallCode::SECP256K1_DECOMPRESS];
total_area +=
(k256_decompress_events as u64) * costs[&RiscvAirDiscriminants::K256Decompress];
total_chips += 1;

let secp256k1_add_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_ADD).unwrap_or(&0);
let secp256k1_add_events = self.syscall_counts[SyscallCode::SECP256K1_ADD];
total_area += (secp256k1_add_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Add];
total_chips += 1;

let secp256k1_double_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_DOUBLE).unwrap_or(&0);
let secp256k1_double_events = self.syscall_counts[SyscallCode::SECP256K1_DOUBLE];
total_area +=
(secp256k1_double_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Double];
total_chips += 1;

let keccak256_permute_events =
*self.syscall_counts.get(&SyscallCode::KECCAK_PERMUTE).unwrap_or(&0);
let keccak256_permute_events = self.syscall_counts[SyscallCode::KECCAK_PERMUTE];
total_area += (keccak256_permute_events as u64) * costs[&RiscvAirDiscriminants::KeccakP];
total_chips += 1;

let bn254_add_events = *self.syscall_counts.get(&SyscallCode::BN254_ADD).unwrap_or(&0);
let bn254_add_events = self.syscall_counts[SyscallCode::BN254_ADD];
total_area += (bn254_add_events as u64) * costs[&RiscvAirDiscriminants::Bn254Add];
total_chips += 1;

let bn254_double_events =
*self.syscall_counts.get(&SyscallCode::BN254_DOUBLE).unwrap_or(&0);
let bn254_double_events = self.syscall_counts[SyscallCode::BN254_DOUBLE];
total_area += (bn254_double_events as u64) * costs[&RiscvAirDiscriminants::Bn254Double];
total_chips += 1;

let bls12381_add_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_ADD).unwrap_or(&0);
let bls12381_add_events = self.syscall_counts[SyscallCode::BLS12381_ADD];
total_area += (bls12381_add_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Add];
total_chips += 1;

let bls12381_double_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_DOUBLE).unwrap_or(&0);
let bls12381_double_events = self.syscall_counts[SyscallCode::BLS12381_DOUBLE];
total_area +=
(bls12381_double_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Double];
total_chips += 1;

let uint256_mul_events = *self.syscall_counts.get(&SyscallCode::UINT256_MUL).unwrap_or(&0);
let uint256_mul_events = self.syscall_counts[SyscallCode::UINT256_MUL];
total_area += (uint256_mul_events as u64) * costs[&RiscvAirDiscriminants::Uint256Mul];
total_chips += 1;

let bls12381_fp_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP_SUB).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP_MUL).unwrap_or(&0);
let bls12381_fp_events = self.syscall_counts[SyscallCode::BLS12381_FP_ADD]
+ self.syscall_counts[SyscallCode::BLS12381_FP_SUB]
+ self.syscall_counts[SyscallCode::BLS12381_FP_MUL];
total_area += (bls12381_fp_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp];
total_chips += 1;

let bls12381_fp2_addsub_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP2_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP2_SUB).unwrap_or(&0);
let bls12381_fp2_addsub_events = self.syscall_counts[SyscallCode::BLS12381_FP2_ADD]
+ self.syscall_counts[SyscallCode::BLS12381_FP2_SUB];
total_area +=
(bls12381_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2AddSub];
total_chips += 1;

let bls12381_fp2_mul_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP2_MUL).unwrap_or(&0);
let bls12381_fp2_mul_events = self.syscall_counts[SyscallCode::BLS12381_FP2_MUL];
total_area +=
(bls12381_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2Mul];
total_chips += 1;

let bn254_fp_events = *self.syscall_counts.get(&SyscallCode::BN254_FP_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP_SUB).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP_MUL).unwrap_or(&0);
let bn254_fp_events = self.syscall_counts[SyscallCode::BN254_FP_ADD]
+ self.syscall_counts[SyscallCode::BN254_FP_SUB]
+ self.syscall_counts[SyscallCode::BN254_FP_MUL];
total_area += (bn254_fp_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp];
total_chips += 1;

let bn254_fp2_addsub_events =
*self.syscall_counts.get(&SyscallCode::BN254_FP2_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP2_SUB).unwrap_or(&0);
let bn254_fp2_addsub_events = self.syscall_counts[SyscallCode::BN254_FP2_ADD]
+ self.syscall_counts[SyscallCode::BN254_FP2_SUB];
total_area +=
(bn254_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2AddSub];
total_chips += 1;

let bn254_fp2_mul_events =
*self.syscall_counts.get(&SyscallCode::BN254_FP2_MUL).unwrap_or(&0);
let bn254_fp2_mul_events = self.syscall_counts[SyscallCode::BN254_FP2_MUL];
total_area += (bn254_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2Mul];
total_chips += 1;

let bls12381_decompress_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_DECOMPRESS).unwrap_or(&0);
let bls12381_decompress_events = self.syscall_counts[SyscallCode::BLS12381_DECOMPRESS];
total_area +=
(bls12381_decompress_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Decompress];
total_chips += 1;

let divrem_events = *self.opcode_counts.get(&Opcode::DIV).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::REM).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::DIVU).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::REMU).unwrap_or(&0);
let divrem_events = self.opcode_counts[Opcode::DIV]
+ self.opcode_counts[Opcode::REM]
+ self.opcode_counts[Opcode::DIVU]
+ self.opcode_counts[Opcode::REMU];
total_area += (divrem_events as u64) * costs[&RiscvAirDiscriminants::DivRem];
total_chips += 1;

let addsub_events = *self.opcode_counts.get(&Opcode::ADD).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SUB).unwrap_or(&0);
let addsub_events = self.opcode_counts[Opcode::ADD] + self.opcode_counts[Opcode::SUB];
total_area += (addsub_events as u64) * costs[&RiscvAirDiscriminants::Add];
total_chips += 1;

let bitwise_events = *self.opcode_counts.get(&Opcode::AND).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::OR).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::XOR).unwrap_or(&0);
let bitwise_events = self.opcode_counts[Opcode::AND]
+ self.opcode_counts[Opcode::OR]
+ self.opcode_counts[Opcode::XOR];
total_area += (bitwise_events as u64) * costs[&RiscvAirDiscriminants::Bitwise];
total_chips += 1;

let mul_events = *self.opcode_counts.get(&Opcode::MUL).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULH).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULHU).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULHSU).unwrap_or(&0);
let mul_events = self.opcode_counts[Opcode::MUL]
+ self.opcode_counts[Opcode::MULH]
+ self.opcode_counts[Opcode::MULHU]
+ self.opcode_counts[Opcode::MULHSU];
total_area += (mul_events as u64) * costs[&RiscvAirDiscriminants::Mul];
total_chips += 1;

let shift_right_events = *self.opcode_counts.get(&Opcode::SRL).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SRA).unwrap_or(&0);
let shift_right_events = self.opcode_counts[Opcode::SRL] + self.opcode_counts[Opcode::SRA];
total_area += (shift_right_events as u64) * costs[&RiscvAirDiscriminants::ShiftRight];
total_chips += 1;

let shift_left_events = *self.opcode_counts.get(&Opcode::SLL).unwrap_or(&0);
let shift_left_events = self.opcode_counts[Opcode::SLL];
total_area += (shift_left_events as u64) * costs[&RiscvAirDiscriminants::ShiftLeft];
total_chips += 1;

let lt_events = *self.opcode_counts.get(&Opcode::SLT).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SLTU).unwrap_or(&0);
let lt_events = self.opcode_counts[Opcode::SLT] + self.opcode_counts[Opcode::SLTU];
total_area += (lt_events as u64) * costs[&RiscvAirDiscriminants::Lt];
total_chips += 1;

Expand Down
4 changes: 2 additions & 2 deletions crates/core/machine/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,11 @@ where
// Print the opcode and syscall count tables like `du`: sorted by count (descending) and
// with the count in the first column.
tracing::info!("execution report (opcode counts):");
for line in sorted_table_lines(&report_aggregate.opcode_counts) {
for line in sorted_table_lines(report_aggregate.opcode_counts.as_ref()) {
tracing::info!(" {line}");
}
tracing::info!("execution report (syscall counts):");
for line in sorted_table_lines(&report_aggregate.syscall_counts) {
for line in sorted_table_lines(report_aggregate.syscall_counts.as_ref()) {
tracing::info!(" {line}");
}

Expand Down
Loading

0 comments on commit b98a30f

Please sign in to comment.