Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up after 'WIP: integration test (#368)' #603

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ jobs:
env:
RAYON_NUM_THREADS: 8
RUST_LOG: debug
run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }} --
run: cargo run --release --package ceno_zkvm --example fibonacci_elf --target ${{ matrix.target }}
10 changes: 2 additions & 8 deletions ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, ops::Not};

use super::rv32im::EmuContext;
use crate::{
Expand Down Expand Up @@ -79,13 +79,7 @@ impl VMState {

pub fn iter_until_halt(&mut self) -> impl Iterator<Item = Result<StepRecord>> + '_ {
let emu = Emulator::new();
from_fn(move || {
if self.halted() {
None
} else {
Some(self.step(&emu))
}
})
from_fn(move || self.halted().not().then(|| self.step(&emu)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not something based on take_while?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly what I thought as well! But both halted and step want to borrow self, so the borrow checker didn't like that.

I was very tempted to take that as a suggestion that we should refactor. But that would have been way beyond the narrow scope of what I was trying to do here.

If you make a PR to that effect, I'm happy to review.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I suspect I had tried that too back then 🤣

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should leave a comment to save the next curious person some head scratching.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There would be a lot of comments if we’d start documenting hypothetical things that don’t work.

}

fn step(&mut self, emu: &Emulator) -> Result<StepRecord> {
Expand Down
67 changes: 46 additions & 21 deletions ceno_zkvm/examples/fibonacci_elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use itertools::{Itertools, MinMaxResult, chain, enumerate};
use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme};
use std::{
collections::{HashMap, HashSet},
panic,
panic::{self, PanicHookInfo},
time::Instant,
};
use tracing_flame::FlameLayer;
Expand All @@ -35,6 +35,27 @@ struct Args {
max_steps: Option<usize>,
}

/// Temporarily override the panic hook
///
/// We restore the original hook after we are done.
fn with_panic_hook<F, R>(hook: Box<dyn Fn(&PanicHookInfo<'_>) + Sync + Send + 'static>, f: F) -> R
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this is a bit cleaner for the reader to extract, then inlining it into the code further down.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well go to the root cause and change the assert to a returned error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting suggestion. Can you make a PR for that, please?

I didn't want to change the behaviour here, just conservatively make the code that we have easier to understand.

where
F: FnOnce() -> R,
{
// Save the current panic hook
let original_hook = panic::take_hook();

// Set the new panic hook
panic::set_hook(hook);

let result = f();

// Restore the original panic hook
panic::set_hook(original_hook);

result
}

fn main() {
let args = Args::parse();

Expand Down Expand Up @@ -125,7 +146,7 @@ fn main() {

let pk = zkvm_cs
.clone()
.key_gen::<Pcs>(pp.clone(), vp.clone(), zkvm_fixed_traces.clone())
.key_gen::<Pcs>(pp, vp, zkvm_fixed_traces.clone())
.expect("keygen failed");
let vk = pk.get_vk();

Expand Down Expand Up @@ -153,14 +174,14 @@ fn main() {
record.insn().codes().kind == EANY
&& record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt()
})
.and_then(|halt_record| halt_record.rs2())
.and_then(StepRecord::rs2)
.map(|rs2| rs2.value);

let final_access = vm.tracer().final_accesses();
let end_cycle: u32 = vm.tracer().cycle().try_into().unwrap();

let pi = PublicValues::new(
exit_code.unwrap_or(0),
exit_code.unwrap_or_default(),
vm.program().entry,
Tracer::SUBCYCLES_PER_INSN as u32,
vm.get_pc().into(),
Expand Down Expand Up @@ -188,7 +209,7 @@ fn main() {
MemFinalRecord {
addr: rec.addr,
value: vm.peek_register(index),
cycle: *final_access.get(&vma).unwrap_or(&0),
cycle: final_access.get(&vma).copied().unwrap_or_default(),
}
} else {
// The table is padded beyond the number of registers.
Expand All @@ -209,7 +230,7 @@ fn main() {
MemFinalRecord {
addr: rec.addr,
value: vm.peek_memory(vma),
cycle: *final_access.get(&vma).unwrap_or(&0),
cycle: final_access.get(&vma).copied().unwrap_or_default(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 is the intended value. No reason to hide it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine either way here.

}
})
.collect_vec();
Expand All @@ -218,7 +239,12 @@ fn main() {
// Find the final public IO cycles.
let io_final = io_init
.iter()
.map(|rec| *final_access.get(&rec.addr.into()).unwrap_or(&0))
.map(|rec| {
final_access
.get(&rec.addr.into())
.copied()
.unwrap_or_default()
})
.collect_vec();

// assign table circuits
Expand Down Expand Up @@ -269,18 +295,16 @@ fn main() {
}

let transcript = Transcript::new(b"riscv");
// change public input maliciously should cause verifier to reject proof
// Maliciously changing the public input should cause the verifier to reject the proof.
zkvm_proof.raw_pi[0] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];
zkvm_proof.raw_pi[1] = vec![<GoldilocksExt2 as ff_ext::ExtensionField>::BaseField::ONE];

// capture panic message, if have
let default_hook = panic::take_hook();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lispc Here's an example where I had some (minor) trouble understanding what the original code tried to do. So I cleaned it up into its own helper function, in order to make the job of the next reader down the line a bit easier.

panic::set_hook(Box::new(|_info| {
// by default it will print msg to stdout/stderr
// we override it to avoid print msg since we will capture the msg by our own
}));
let result = panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript));
panic::set_hook(default_hook);
// capture panic message, if any
// by default it will print msg to stdout/stderr
// we override it to avoid print msg since we will capture the msg by ourselves
let result = with_panic_hook(Box::new(|_info| ()), || {
panic::catch_unwind(|| verifier.verify_proof(zkvm_proof, transcript))
});
match result {
Ok(res) => {
res.expect_err("verify proof should return with error");
Expand Down Expand Up @@ -322,23 +346,24 @@ fn debug_memory_ranges(vm: &VMState, mem_final: &[MemFinalRecord]) {

tracing::debug!(
"Memory range (accessed): {:?}",
format_segments(vm.platform(), accessed_addrs.iter().copied())
format_segments(vm.platform(), &accessed_addrs)
);
tracing::debug!(
"Memory range (handled): {:?}",
format_segments(vm.platform(), handled_addrs.iter().copied())
format_segments(vm.platform(), &handled_addrs)
);

for addr in &accessed_addrs {
assert!(handled_addrs.contains(addr), "unhandled addr: {:?}", addr);
}
}

fn format_segments(
fn format_segments<'a>(
platform: &Platform,
addrs: impl Iterator<Item = ByteAddr>,
) -> HashMap<String, MinMaxResult<ByteAddr>> {
addrs: impl IntoIterator<Item = &'a ByteAddr>,
) -> HashMap<String, MinMaxResult<&'a ByteAddr>> {
addrs
.into_iter()
.into_grouping_map_by(|addr| format_segment(platform, addr.0))
.minmax()
}
Expand Down
17 changes: 4 additions & 13 deletions ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ impl<E: ExtensionField> MmuConfig<E> {
io_addrs: &[Addr],
) {
assert!(
chain(
static_mem_init.iter().map(|record| record.addr),
io_addrs.iter().copied(),
)
.all_unique(),
chain(static_mem_init.iter().map(|record| &record.addr), io_addrs,).all_unique(),
"memory addresses must be unique"
);

Expand Down Expand Up @@ -142,14 +138,9 @@ impl MemPadder {
new_len: usize,
records: Vec<MemInitRecord>,
) -> Vec<MemInitRecord> {
if records.is_empty() {
self.padded(new_len, records)
} else {
self.padded(new_len, records)
.into_iter()
.sorted_by_key(|record| record.addr)
.collect()
}
let mut padded = self.padded(new_len, records);
padded.sort_by_key(|record| record.addr);
padded
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does extra work without being clearer when we just want the default content. If you really want to remove that if then it's more meaningful to split it into two functions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I didn't actually care that much about the if. It was more about converting from vector to iterator (then internally in sorted_by_key convert back to vector and sort then convert back to iterator) and then back to vector again. I'm not at all against letting the computer do some extra work, but here it didn't even make the code easier to read.

But while I was at it, I also simplified the structured by removing the if. The fewer branching code paths you have, the easier it is to achieve good test coverage of all the code paths. Rust's sorting is very efficient for already sorted data: it finishes in O(n). And we need O(n) anyway just to produce the data in the first place.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ho ok if that was about speed, I think that is zero-cost, it does the same thing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

/// Pad `records` to `new_len` using unused addresses.
Expand Down
7 changes: 3 additions & 4 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use ff_ext::ExtensionField;
use itertools::Itertools;
use mpcs::PolynomialCommitmentScheme;
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, fmt::Debug};
Expand Down Expand Up @@ -133,15 +132,15 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
.iter()
.map(|pv| {
if pv.len() == 1 {
// this is constant poly, and always evaluate to same constant value
// this is constant poly, and always evaluates to same constant value
E::from(pv[0])
} else {
// set 0 as placeholder. will be evaluate lazily
// set 0 as placeholder. will be evaluated lazily
// Or the vector is empty, i.e. the constant 0 polynomial.
E::ZERO
}
})
.collect_vec();
.collect();
Self {
raw_pi,
pi_evals,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ impl TowerProver {
virtual_polys.add_mle_list(vec![&eq, &q1, &q2], *alpha_denominator);
}
}
tracing::debug!("generated tower proof at round {}/{}", round, max_round_index);
tracing::debug!("generated tower proof at round {round}/{max_round_index}");

let wrap_batch_span = entered_span!("wrap_batch");
// NOTE: at the time of adding this span, visualizing it with the flamegraph layer
Expand Down
8 changes: 4 additions & 4 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
transcript: Transcript<E>,
does_halt: bool,
) -> Result<bool, ZKVMError> {
// require ecall/halt proof to exist, depending whether we expect a halt.
// require ecall/halt proof to exist, depending on whether we expect a halt.
let num_instances = vm_proof
.opcode_proofs
.get(&HaltInstruction::<E>::name())
.map(|(_, p)| p.num_instances)
.unwrap_or(0);
.unwrap_or_default();
if num_instances != (does_halt as usize) {
return Err(ZKVMError::VerifyError(format!(
"ecall/halt num_instances={}, expected={}",
Expand Down Expand Up @@ -117,12 +117,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
}
}

for (name, (_, proof)) in vm_proof.opcode_proofs.iter() {
for (name, (_, proof)) in &vm_proof.opcode_proofs {
tracing::debug!("read {}'s commit", name);
PCS::write_commitment(&proof.wits_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
}
for (name, (_, proof)) in vm_proof.table_proofs.iter() {
for (name, (_, proof)) in &vm_proof.table_proofs {
tracing::debug!("read {}'s commit", name);
PCS::write_commitment(&proof.wits_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
Expand Down