Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify risc0 tracer util to print function stack to enable better debugging #711

Merged
merged 6 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/demo-prover/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Directories and paths.
# TRACER_DIR is relative to this Makefile (examples/demo-prover). The ELF and
# trace paths are relative to TRACER_DIR, because `run-tracer` changes into
# that directory before invoking cargo.
TRACER_DIR = ../../utils/zk-cycle-utils/tracer
ELF_PATH_TRACER = ../../../examples/demo-prover/target/riscv-guest/riscv32im-risc0-zkvm-elf/release/rollup
TRACE_PATH_TRACER = ../../../examples/demo-prover/host/rollup.trace

# This allows you to pass additional flags when you call `make run-tracer`.
# For example: `make run-tracer ADDITIONAL_FLAGS="--some-flag"`
ADDITIONAL_FLAGS ?=

# None of these targets produce a file named after themselves, so all of them
# (including the default `all`) are declared phony.
.PHONY: all generate-files run-tracer

# Default target: produce the trace, then analyse it.
all: generate-files run-tracer

# Runs the prover benchmark with ROLLUP_TRACE set.
# NOTE(review): presumably the benchmark writes the trace to host/rollup.trace,
# which is where TRACE_PATH_TRACER points -- confirm against prover_bench.
generate-files:
	ROLLUP_TRACE=rollup.trace cargo bench --bench prover_bench --features bench

# Runs the tracer utility against the guest ELF and the recorded trace.
run-tracer:
	@cd $(TRACER_DIR) && \
	cargo run --release -- --no-raw-counts --rollup-elf $(ELF_PATH_TRACER) --rollup-trace $(TRACE_PATH_TRACER) $(ADDITIONAL_FLAGS)
13 changes: 9 additions & 4 deletions examples/demo-prover/host/benches/prover_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,17 @@ impl RegexAppender {
impl log::Log for RegexAppender {
fn log(&self, record: &log::Record) {
if let Some(captures) = self.regex.captures(record.args().to_string().as_str()) {
let mut file_guard = self.file.lock().unwrap();
if let Some(matched_pc) = captures.get(1) {
let pc_value_num = u64::from_str_radix(&matched_pc.as_str()[2..], 16).unwrap();
let pc_value = format!("{}\n", pc_value_num);
let mut file_guard = self.file.lock().unwrap();
let pc_value = format!("{}\t", pc_value_num);
file_guard.write_all(pc_value.as_bytes()).unwrap();
}
if let Some(matched_iname) = captures.get(2) {
let iname = matched_iname.as_str().to_uppercase();
let iname_value = format!("{}\n", iname);
file_guard.write_all(iname_value.as_bytes()).unwrap();
}
}
}

Expand All @@ -69,8 +74,8 @@ impl log::Log for RegexAppender {
}

fn get_config(rollup_trace: &str) -> Config {
let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn.*";
// let log_file = "/Users/dubbelosix/sovereign/examples/demo-prover/matched_pattern.log";
// [942786] pc: 0x0008e564, insn: 0xffc67613 => andi x12, x12, -4
let regex_pattern = r".*?pc: (0x[0-9a-fA-F]+), insn: .*?=> ([a-z]*?) ";

let custom_appender = RegexAppender::new(regex_pattern, rollup_trace);

Expand Down
2 changes: 1 addition & 1 deletion examples/demo-prover/methods/guest/src/bin/rollup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use celestia::{BlobWithSender, CelestiaHeader};
use const_rollup_config::{ROLLUP_NAMESPACE_RAW, SEQUENCER_DA_ADDRESS};
use demo_stf::app::create_zk_app_template;
use demo_stf::ArrayWitness;

use risc0_adapter::guest::Risc0Guest;
use risc0_zkvm::guest::env;
use sov_rollup_interface::crypto::NoOpHasher;
Expand Down Expand Up @@ -103,7 +104,6 @@ pub fn main() {
let metrics_syscall_name = unsafe {
risc0_zkvm_platform::syscall::SyscallName::from_bytes_with_nul(cycle_string.as_ptr())
};

risc0_zkvm::guest::env::send_recv_slice::<u8, u8>(metrics_syscall_name, &serialized);
}
}
2 changes: 1 addition & 1 deletion examples/demo-rollup/rollup_config.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[da]
# The JWT used to authenticate with the celestia light client. Instructions for generating this token can be found in the README
celestia_rpc_auth_token = "MY.SECRET.TOKEN"
celestia_rpc_auth_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJBbGxvdyI6WyJwdWJsaWMiLCJyZWFkIiwid3JpdGUiLCJhZG1pbiJdfQ.MhcEa3-fyoZvrf5bJ-sqUnrJi1cHhfKBq0W4lGT-oso"
dubbelosix marked this conversation as resolved.
Show resolved Hide resolved
# The address of the *trusted* Celestia light client to interact with
celestia_rpc_address = "http://127.0.0.1:26658"
# The largest response the rollup will accept from the Celestia node. Defaults to 100 MB
Expand Down
140 changes: 118 additions & 22 deletions utils/zk-cycle-utils/tracer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,67 @@ struct Args {
#[arg(short, long)]
/// Strip the hashes from the function name while printing
strip_hashes: bool,

#[arg(short, long)]
/// Function name to target for getting stack counts
function_name: Option<String>,

#[arg(short, long)]
/// Exclude functions matching these patterns from display
/// usage: -e func1 -e func2 -e func3
exclude_view: Vec<String>,
}

/// Removes a trailing `::h<hex digits>` suffix from `name_with_hash`.
///
/// Only a suffix anchored at the very end of the string is removed; a name
/// without such a suffix is returned unchanged (as an owned `String`).
fn strip_hash(name_with_hash: &str) -> String {
    let hash_suffix = Regex::new(r"::h[0-9a-fA-F]+$").unwrap();
    let without_hash = hash_suffix.replace(name_with_hash, "");
    without_hash.into_owned()
}

fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, strip_hashes: bool) {
/// Returns the cycle cost assigned to a single instruction.
///
/// `insn` is the instruction mnemonic in upper case (for example `"ADDI"`).
/// The match is exact and case-sensitive.
///
/// # Errors
///
/// Returns `Err("Decode error")` for any mnemonic that is not in the table.
fn get_cycle_count(insn: &str) -> Result<usize, &'static str> {
    match insn {
        "LB" | "LH" | "LW" | "LBU" | "LHU" | "ADDI" | "SLLI" | "SLTI" | "SLTIU" |
        "AUIPC" | "SB" | "SH" | "SW" | "ADD" | "SUB" | "SLL" | "SLT" | "SLTU" |
        "XOR" | "SRL" | "SRA" | "OR" | "AND" | "MUL" | "MULH" | "MULSU" | "MULU" |
        "LUI" | "BEQ" | "BNE" | "BLT" | "BGE" | "BLTU" | "BGEU" | "JALR" | "JAL" |
        "ECALL" | "EBREAK" => Ok(1),

        // Don't see these in the risc0 code base, but MUL, MULH, MULSU, and MULU all take
        // 1 cycle, so going with that for the standard RISC-V spellings as well.
        // MULHSU was previously missing and fell through to "Decode error".
        // NOTE(review): assumes MULHSU costs the same as risc0's MULSU -- confirm.
        "MULHU" | "MULHSU" => Ok(1),

        "XORI" | "ORI" | "ANDI" | "SRLI" | "SRAI" | "DIV" | "DIVU" | "REM" | "REMU" => Ok(2),

        _ => Err("Decode error"),
    }
}
dubbelosix marked this conversation as resolved.
Show resolved Hide resolved

fn print_intruction_counts(first_header: &str,
count_vec: Vec<(String, usize)>,
top_n: usize,
strip_hashes: bool,
exclude_list: Option<&[String]>) {
let mut table = Table::new();
table.set_format(*format::consts::FORMAT_DEFAULT);
table.set_titles(Row::new(vec![
Cell::new("Function Name"),
Cell::new(first_header),
Cell::new("Instruction Count"),
]));

let wrap_width = 90;
let mut row_count = 0;
for (key, value) in count_vec {
let mut cont = false;
if let Some(ev) = exclude_list {
for e in ev {
if key.contains(e) {
cont = true;
break
}
}
if cont {
continue
}
}
let mut stripped_key = key.clone();
if strip_hashes {
stripped_key = strip_hash(&key);
Expand All @@ -75,7 +118,18 @@ fn print_intruction_counts(count_vec: Vec<(&String, &usize)>, top_n: usize, stri
table.printstd();
}

fn _build_lookups_radare_2(
/// Attributes the cycle cost of `instruction` to the call-stack prefix that
/// ends at `function_name`.
///
/// If `function_name` occurs in `function_stack`, the stack is cut just after
/// its first occurrence and the resulting prefix is used as the key into
/// `filtered_stack_counts`, whose tally is increased by the instruction's
/// cycle count. If the function is not on the stack, nothing happens.
///
/// # Panics
///
/// Panics if `get_cycle_count` returns an error for `instruction`.
fn focused_stack_counts(function_stack: &[String],
                        filtered_stack_counts: &mut HashMap<Vec<String>, usize>,
                        function_name: &str,
                        instruction: &str) {
    let depth = match function_stack.iter().position(|frame| frame == function_name) {
        Some(depth) => depth,
        // The targeted function is not on the current stack: nothing to record.
        None => return,
    };

    let stack_prefix = function_stack[..=depth].to_vec();
    let tally = filtered_stack_counts.entry(stack_prefix).or_insert(0);
    *tally += get_cycle_count(instruction).unwrap();
}

fn _build_radare2_lookups(
start_lookup: &mut HashMap<u64, String>,
end_lookup: &mut HashMap<u64, String>,
func_range_lookup: &mut HashMap<String, (u64, u64)>,
Expand Down Expand Up @@ -130,10 +184,18 @@ fn build_goblin_lookups(
Ok(())
}

fn increment_stack_counts(instruction_counts: &mut HashMap<String, usize>, function_stack: &[String]) {
for function_name in function_stack {
*instruction_counts.entry(function_name.clone()).or_insert(0) += 1;
/// Charges the cycle cost of `instruction` to every function currently on the
/// call stack.
///
/// Each entry of `function_stack` has its tally in `instruction_counts`
/// increased by the instruction's cycle count. When `function_name` is
/// `Some`, the same instruction is also passed to `focused_stack_counts` so
/// that it can be recorded in `filtered_stack_counts`.
///
/// # Panics
///
/// Panics if `function_stack` is non-empty and `get_cycle_count` returns an
/// error for `instruction`.
fn increment_stack_counts(instruction_counts: &mut HashMap<String, usize>,
                          function_stack: &[String],
                          filtered_stack_counts: &mut HashMap<Vec<String>, usize>,
                          function_name: &Option<String>,
                          instruction: &str) {
    function_stack.iter().for_each(|frame| {
        // Looked up per frame (not hoisted) so an empty stack never triggers the lookup.
        let cycles = get_cycle_count(instruction).unwrap();
        *instruction_counts.entry(frame.clone()).or_insert(0) += cycles;
    });

    if let Some(target) = function_name.as_deref() {
        focused_stack_counts(function_stack, filtered_stack_counts, target, instruction);
    }
}

fn main() -> std::io::Result<()> {
Expand All @@ -145,6 +207,8 @@ fn main() -> std::io::Result<()> {
let no_stack_counts = args.no_stack_counts;
let no_raw_counts = args.no_raw_counts;
let strip_hashes = args.strip_hashes;
let function_name = args.function_name;
let exclude_view = args.exclude_view;

let mut start_lookup = HashMap::new();
let mut end_lookup = HashMap::new();
Expand All @@ -153,7 +217,7 @@ fn main() -> std::io::Result<()> {

let mut function_ranges: Vec<(u64, u64, String)> = func_range_lookup
.iter()
.map(|(function_name, &(start, end))| (start, end, function_name.clone()))
.map(|(f, &(start, end))| (start, end, f.clone()))
.collect();

function_ranges.sort_by_key(|&(start, _, _)| start);
Expand All @@ -162,6 +226,7 @@ fn main() -> std::io::Result<()> {
let mut function_stack: Vec<String> = Vec::new();
let mut instruction_counts: HashMap<String, usize> = HashMap::new();
let mut counts_without_callgraph: HashMap<String, usize> = HashMap::new();
let mut filtered_stack_counts: HashMap<Vec<String>, usize> = HashMap::new();
let total_lines = file_content.lines().count() as u64;
let mut current_function_range : (u64,u64) = (0,0);

Expand All @@ -176,7 +241,9 @@ fn main() -> std::io::Result<()> {
if c % &update_interval == 0 {
pb.inc(update_interval as u64);
}
let pc = line.parse().unwrap();
let mut parts = line.split("\t");
let pc = parts.next().unwrap_or_default().parse().unwrap();
let instruction = parts.next().unwrap_or_default();

// Raw counts without considering the callgraph at all
// we're just checking if the PC belongs to a function
Expand All @@ -193,9 +260,9 @@ fn main() -> std::io::Result<()> {
} })
{
let (_, _, fname) = &function_ranges[index];
*counts_without_callgraph.entry(fname.clone()).or_insert(0) += 1;
*counts_without_callgraph.entry(fname.clone()).or_insert(0) += get_cycle_count(instruction).unwrap();
} else {
*counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += 1;
*counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += get_cycle_count(instruction).unwrap();
}

// The next section considers the callstack
Expand All @@ -204,17 +271,17 @@ fn main() -> std::io::Result<()> {

// we are still in the current function
if pc > current_function_range.0 && pc <= current_function_range.1 {
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
continue;
}

// jump to a new function (or the same one)
if let Some(function_name) = start_lookup.get(&pc) {
increment_stack_counts(&mut instruction_counts, &function_stack);
if let Some(f) = start_lookup.get(&pc) {
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
// jump to a new function (not recursive)
if !function_stack.contains(&function_name) {
function_stack.push(function_name.clone());
current_function_range = *func_range_lookup.get(function_name).unwrap();
if !function_stack.contains(&f) {
function_stack.push(f.clone());
current_function_range = *func_range_lookup.get(f).unwrap();
}
} else {
// this means pc now points to an instruction that is
Expand All @@ -237,33 +304,62 @@ fn main() -> std::io::Result<()> {
if unwind_found {

function_stack.truncate(unwind_point + 1);
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
continue;
}

// if no unwind point has been found, that means we jumped to some random location
// so we'll just increment the counts for everything in the stack
increment_stack_counts(&mut instruction_counts, &function_stack);
increment_stack_counts(&mut instruction_counts, &function_stack, &mut filtered_stack_counts, &function_name,instruction);
}

}

pb.finish_with_message("done");

let mut raw_counts: Vec<(&String, &usize)> = instruction_counts.iter().collect();
let mut raw_counts: Vec<(String, usize)> = instruction_counts
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
raw_counts.sort_by(|a, b| b.1.cmp(&a.1));

println!("\n\nTotal instructions in trace: {}", total_lines);
if !no_stack_counts {
println!("\n\n Instruction counts considering call graph");
print_intruction_counts(raw_counts, top_n, strip_hashes);
print_intruction_counts("Function Name", raw_counts, top_n, strip_hashes,Some(&exclude_view));
}

let mut raw_counts: Vec<(&String, &usize)> = counts_without_callgraph.iter().collect();
let mut raw_counts: Vec<(String, usize)> = counts_without_callgraph
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect();
raw_counts.sort_by(|a, b| b.1.cmp(&a.1));
if !no_raw_counts {
println!("\n\n Instruction counts ignoring call graph");
print_intruction_counts(raw_counts, top_n, strip_hashes);
print_intruction_counts("Function Name",raw_counts, top_n, strip_hashes,Some(&exclude_view));
}

let mut raw_counts: Vec<(String, usize)> = filtered_stack_counts
.iter()
.map(|(stack, count)| {
let numbered_stack = stack
.iter()
.rev()
.enumerate()
.map(|(index, line)| {
let modified_line = if strip_hashes { strip_hash(line) } else { line.clone() };
format!("({}) {}", index + 1, modified_line)
})
.collect::<Vec<_>>()
.join("\n");
(numbered_stack, *count)
})
.collect();

raw_counts.sort_by(|a, b| b.1.cmp(&a.1));
if let Some(f) = function_name {
println!("\n\n Stack patterns for function '{f}' ");
print_intruction_counts("Function Stack",raw_counts, top_n, strip_hashes,None);
}
Ok(())
}