Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] GPU spmvm #935

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ clap = "4.3.17"
ff = "0.13"
metrics = "0.21.1"
neptune = { git = "https://github.com/lurk-lab/neptune", branch = "dev", features = ["abomonation"] }
nova = { git = "https://github.com/lurk-lab/arecibo", branch = "dev", package = "nova-snark" }
nova = { git = "https://github.com/lurk-lab/arecibo", branch = "gpu-spmvm", package = "nova-snark" }
once_cell = "1.18.0"
pairing = { version = "0.23" }
pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev" }
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev" }
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "pushing-the-limit" }
proptest = "1.2.0"
proptest-derive = "0.3.0"
rand = "0.8"
Expand All @@ -141,7 +141,7 @@ tempfile = "3.6.0"
camino = "1.1.6"
thiserror = "1.0.44"
tracing = "0.1.37"
tracing-texray = "0.2.0"
tracing-texray = { git = "https://github.com/winston-h-zhang/tracing-texray", branch = "shim" }
tracing-subscriber = "0.3.17"

[[bin]]
Expand Down
1,691 changes: 1,691 additions & 0 deletions benches/dev/600.txt

Large diffs are not rendered by default.

3,132 changes: 3,132 additions & 0 deletions benches/dev/900.txt

Large diffs are not rendered by default.

33 changes: 17 additions & 16 deletions benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ use lurk::{
field::LurkField,
lem::{eval::evaluate, multiframe::MultiFrame, pointers::Ptr, store::Store},
proof::nova::NovaProver,
proof::Prover,
public_parameters::{
instance::{Instance, Kind},
public_params,
},
proof::{nova::public_params, Prover},
state::State,
};

use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};
use tracing_texray::TeXRayLayer;

mod common;
use common::set_bench_config;

Expand Down Expand Up @@ -111,14 +110,8 @@ fn fibonacci_prove<M: measurement::Measurement>(
let lang_pallas = Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new();
let lang_rc = Arc::new(lang_pallas.clone());

// use cached public params
let instance = Instance::new(
prove_params.reduction_count,
lang_rc.clone(),
true,
Kind::NovaPublicParams,
);
let pp = public_params::<_, _, MultiFrame<'_, _, _>>(&instance).unwrap();
let pp =
public_params::<_, _, MultiFrame<'_, _, _>>(prove_params.reduction_count, lang_rc.clone());

// Track the number of `Lurk frames / sec`
let rc = prove_params.reduction_count as u64;
Expand Down Expand Up @@ -148,7 +141,8 @@ fn fibonacci_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove(&pp, frames, &store);
let result = tracing_texray::examine(tracing::info_span!("bang!"))
.in_scope(|| prover.prove(&pp, frames, &store));
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand All @@ -159,12 +153,17 @@ fn fibonacci_prove<M: measurement::Measurement>(

fn fibonacci_benchmark(c: &mut Criterion) {
// Uncomment to record the logs. May negatively impact performance
//tracing_subscriber::fmt::init();
let subscriber = Registry::default()
.with(fmt::layer().pretty())
.with(EnvFilter::from_default_env())
.with(TeXRayLayer::new().width(120));
tracing::subscriber::set_global_default(subscriber).unwrap();

set_bench_config();
tracing::debug!("{:?}", lurk::config::LURK_CONFIG);

let reduction_counts = rc_env().unwrap_or_else(|_| vec![100]);
let batch_sizes = [100, 200];
let batch_sizes = [249, 374, 499];

let state = State::init_lurk_state().rccell();

Expand All @@ -187,6 +186,8 @@ fn fibonacci_benchmark(c: &mut Criterion) {
}
}

// RUST_LOG=info LURK_RC=600 LURK_PERF=max-parallel-simple cargo criterion --bench fibonacci --features "cuda" 2> ./benches/gpu-spmvm/600.txt
// RUST_LOG=info LURK_RC=900 LURK_PERF=max-parallel-simple cargo criterion --bench fibonacci --features "cuda" 2> ./benches/dev/900.txt
cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
criterion_group! {
Expand Down
2,540 changes: 2,540 additions & 0 deletions benches/gpu-spmvm/1200.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions benches/gpu-spmvm/600.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Compiling sppark v0.1.5
Compiling pasta-msm v0.1.4 (https://github.com/lurk-lab/pasta-msm?branch=dev#182b971d)
423 changes: 423 additions & 0 deletions benches/gpu-spmvm/900.txt

Large diffs are not rendered by default.

126 changes: 91 additions & 35 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,120 @@
use std::{cell::RefCell, rc::Rc, sync::Arc, time::Duration};

use anyhow::anyhow;

use pasta_curves::pallas;

use lurk::{
eval::lang::{Coproc, Lang},
field::LurkField,
lem::{eval::evaluate_simple, pointers::Ptr, store::Store},
{eval::lang::Coproc, state::State},
lem::{eval::evaluate, multiframe::MultiFrame, pointers::Ptr, store::Store},
proof::Prover,
proof::{nova::NovaProver, RecursiveSNARKTrait},
public_parameters::{
instance::{Instance, Kind},
public_params,
},
state::State,
};
use pasta_curves::Fq;

fn fib_expr<F: LurkField>(store: &Store<F>) -> Ptr {
use tracing_subscriber::{fmt, prelude::*, EnvFilter, Registry};
use tracing_texray::TeXRayLayer;

fn fib<F: LurkField>(store: &Store<F>, state: Rc<RefCell<State>>, _a: u64) -> Ptr {
let program = r#"
(letrec ((next (lambda (a b) (next b (+ a b))))
(fib (next 0 1)))
(fib))
"#;

store.read_with_default_state(program).unwrap()
store.read(state, program).unwrap()
}

// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation contains a binding of the
// The env output in the `fib_frame`th frame of the above, infinite Fibonacci computation will contain a binding of the
// nth Fibonacci number to `a`.
// [`fib_frame` below is the means of computing that frame index.]
/// Returns the index of the evaluation frame whose output env binds the `n`th Fibonacci number
/// to `a`, computed as `11 + 16 * n`.
fn fib_frame(n: usize) -> usize {
    // Frame index for `n == 0`.
    const FIRST_FRAME: usize = 11;
    // Number of frames by which the index advances for each increment of `n`.
    const FRAME_STRIDE: usize = 16;

    FRAME_STRIDE * n + FIRST_FRAME
}

/// Returns the frame limit for computing the `n`th Fibonacci number with reduction count `rc`.
///
/// The limit is `fib_frame(n)` rounded up to the next multiple of `rc`, so the last step will be
/// filled exactly, since Lurk currently only pads terminal/error continuations.
///
/// # Panics
///
/// Panics if `rc` is zero (division by zero).
#[allow(dead_code)]
fn fib_limit(n: usize, rc: usize) -> usize {
    let target_frame = fib_frame(n);
    let full_steps = target_frame / rc;
    let has_partial_step = target_frame % rc != 0;

    // A trailing, partially filled step still costs one whole step.
    let steps = if has_partial_step {
        full_steps + 1
    } else {
        full_steps
    };
    rc * steps
}

fn lurk_fib(store: &Store<Fq>, n: usize, _rc: usize) -> Ptr {
let frame_idx = fib_frame(n);
// let limit = fib_limit(n, rc);
let limit = frame_idx;
let fib_expr = fib_expr(store);

let (output, ..) = evaluate_simple::<Fq, Coproc<Fq>>(None, fib_expr, store, limit).unwrap();

let target_env = &output[1];

// The result is the value of the second binding (of `A`), in the target env.
// See relevant excerpt of execution trace below:
//
// INFO lurk::eval > Frame: 11
// Expr: (NEXT B (+ A B))
// Env: ((B . 1) (A . 0) ((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>)))
// Cont: Tail{ saved_env: (((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>))), continuation: LetRec{var: FIB,
// saved_env: (((NEXT . <FUNCTION (A) (LAMBDA (B) (NEXT B (+ A B)))>))), body: (FIB), continuation: Tail{ saved_env:
// NIL, continuation: Outermost } } }

let (_, rest_bindings) = store.car_cdr(target_env).unwrap();
let (second_binding, _) = store.car_cdr(&rest_bindings).unwrap();
store.car_cdr(&second_binding).unwrap().1
/// Parameters for a single `fibonacci_prove` run.
#[derive(Clone, Debug, Copy)]
struct ProveParams {
    /// Fibonacci index; together with `rc` it determines the evaluation limit via `fib_limit`.
    fib_n: usize,
    /// Reduction count passed to the public-parameter instance and to the prover.
    rc: usize,
}

/// Reads the reduction counts to run from the `LURK_RC` environment variable.
///
/// The variable is expected to hold a comma-separated list of `usize` values, e.g. `100,600,900`.
///
/// # Errors
///
/// Returns an error if `LURK_RC` cannot be read, or if any comma-separated entry fails to parse
/// as a `usize`.
fn rc_env() -> anyhow::Result<Vec<usize>> {
    let raw = std::env::var("LURK_RC")
        .map_err(|e| anyhow!("Reduction count env var isn't set: {e}"))?;

    // Collecting into a `Result` stops at the first entry that fails to parse.
    raw.split(',')
        .map(|entry| {
            entry
                .parse::<usize>()
                .map_err(|e| anyhow!("Failed to parse RC: {e}"))
        })
        .collect()
}

/// Proves and then verifies one Fibonacci evaluation for the given parameters.
///
/// Steps, in order:
/// 1. compute the frame limit with `fib_limit(prove_params.fib_n, prove_params.rc)`;
/// 2. obtain the Nova public parameters for `prove_params.rc`;
/// 3. read the Fibonacci program into a fresh store and evaluate it up to the limit;
/// 4. prove the resulting frames inside a `tracing_texray`-examined span;
/// 5. verify the proof.
///
/// # Panics
///
/// Panics if obtaining the public parameters, evaluation, proving or verification returns an
/// error, or if verification reports `false`.
fn fibonacci_prove(prove_params: ProveParams, state: &Rc<RefCell<State>>) {
    let limit = fib_limit(prove_params.fib_n, prove_params.rc);
    // The `Lang` is only ever used through the `Arc`, so move it in rather than cloning it first.
    let lang_rc = Arc::new(Lang::<pallas::Scalar, Coproc<pallas::Scalar>>::new());

    // use cached public params
    let instance = Instance::new(
        prove_params.rc,
        lang_rc.clone(),
        true,
        Kind::NovaPublicParams,
    );
    let pp = public_params::<_, _, MultiFrame<'_, _, _>>(&instance).unwrap();

    let store = Store::default();

    let ptr = fib::<pasta_curves::Fq>(&store, state.clone(), prove_params.fib_n as u64);
    // Last use of `lang_rc`: hand over ownership instead of bumping the refcount once more.
    let prover = NovaProver::new(prove_params.rc, lang_rc);

    // Only the frames (first element of the evaluation result) are needed for proving.
    let frames = &evaluate::<pasta_curves::Fq, Coproc<pasta_curves::Fq>>(None, ptr, &store, limit)
        .unwrap()
        .0;
    let (proof, z0, zi, num_steps) = tracing_texray::examine(tracing::info_span!("bang!"))
        .in_scope(|| prover.prove(&pp, frames, &store).unwrap());

    let res = proof.verify(&pp, &z0, &zi, num_steps).unwrap();
    assert!(res);
}

/// RUST_LOG=info LURK_RC=900 LURK_PERF=max-parallel-simple cargo run --release --example fibonacci --features "cuda"
fn main() {
let store = &Store::<Fq>::default();
let n: usize = std::env::args().collect::<Vec<_>>()[1].parse().unwrap();
let state = State::init_lurk_state();
let subscriber = Registry::default()
.with(fmt::layer().pretty())
.with(EnvFilter::from_default_env())
.with(TeXRayLayer::new().width(120));
tracing::subscriber::set_global_default(subscriber).unwrap();

let rcs = rc_env().unwrap_or_else(|_| vec![100]);
let batch_sizes = [249];

let state = State::init_lurk_state().rccell();

let fib = lurk_fib(store, n, 100);
for rc in rcs.iter() {
for fib_n in batch_sizes.iter() {
let prove_params = ProveParams {
fib_n: *fib_n,
rc: *rc,
};
fibonacci_prove(prove_params, &state);
}
}

println!("Fib({n}) = {}", fib.fmt_to_string(store, &state));
println!("success");
}
4 changes: 2 additions & 2 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ fn build_frames<
let mut pc = 0;
let mut frames = vec![];
let mut iterations = 0;
tracing::info!("{}", &log_fmt(0, &input, &[], store));
tracing::debug!("{}", &log_fmt(0, &input, &[], store));
for _ in 0..limit {
let mut emitted = vec![];
let (frame, must_break) =
compute_frame(lurk_step, cprocs_run, &input, store, lang, &mut emitted, pc)?;

iterations += 1;
input = frame.output.clone();
tracing::info!("{}", &log_fmt(iterations, &input, &emitted, store));
tracing::debug!("{}", &log_fmt(iterations, &input, &emitted, store));
let expr = frame.output[0];
frames.push(frame);

Expand Down
1 change: 1 addition & 0 deletions src/lem/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ impl<'a, F: LurkField, C: Coprocessor<F>> nova::traits::circuit::StepCircuit<F>
2 * self.lurk_step.input_params.len()
}

#[tracing::instrument(skip_all, name = "synthesize")]
fn synthesize<CS>(
&self,
cs: &mut CS,
Expand Down
3 changes: 1 addition & 2 deletions src/lem/var_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::hash_map::Entry;

use anyhow::{bail, Result};
use fxhash::FxHashMap;
use tracing::info;

use super::Var;

Expand All @@ -29,7 +28,7 @@ impl<V> VarMap<V> {
}
Entry::Occupied(mut o) => {
let v = o.insert(v);
info!("Variable {} has been overwritten", o.key());
tracing::debug!("Variable {} has been overwritten", o.key());
Some(v)
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,16 @@ where
assert_eq!(reduction_count, circuit_primary.frames().unwrap().len());

let mut r_snark = recursive_snark.unwrap_or_else(|| {
RecursiveSNARK::new(
let recursive_snark = RecursiveSNARK::new(
&pp.pp,
&circuit_primary,
&circuit_secondary,
z0_primary,
&z0_secondary,
)
.expect("Failed to construct initial recursive snark")
.expect("Failed to construct initial recursive snark");
recursive_snark.write_abomonated(&pp.pp).unwrap();
recursive_snark
});
r_snark
.prove_step(&pp.pp, &circuit_primary, &circuit_secondary)
Expand Down
Loading