Skip to content

Commit

Permalink
Fix spartan (and include its tests in CI) (#184)
Browse files Browse the repository at this point in the history
Spartan bit-rotted because CI was misconfigured.
  • Loading branch information
alex-ozdemir authored Feb 1, 2024
1 parent d0b529b commit 2ebd0a1
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 67 deletions.
17 changes: 2 additions & 15 deletions driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def check(features):
cmd = ["cargo", "check", "--tests", "--examples", "--benches", "--bins"]
if features:
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
log_run_check(cmd)


Expand Down Expand Up @@ -82,8 +80,6 @@ def build(features):

if features:
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]

log_run_check(cmd)

Expand Down Expand Up @@ -114,9 +110,6 @@ def test(features, extra_args):
if features:
test_cmd += ["--features"] + [",".join(features)]
test_cmd_release += ["--features"] + [",".join(features)]
if "ristretto255" in features:
test_cmd += ["--no-default-features"]
test_cmd_release += ["--no-default-features"]
if len(extra_args) > 0:
test_cmd += [a for a in extra_args if a != "--"]
test_cmd_release += [a for a in extra_args if a != "--"]
Expand All @@ -135,7 +128,7 @@ def test(features, extra_args):
if "lp" in features:
log_run_check(["./scripts/test_zok_to_ilp.zsh"])
if "r1cs" in features:
if "ristretto255" in features: # spartan field
if "spartan" in features: # spartan field
log_run_check(["./scripts/spartan_zok_test.zsh"])
else: # bellman field
log_run_check(["./scripts/zokrates_test.zsh"])
Expand Down Expand Up @@ -168,8 +161,6 @@ def benchmark(features):

if features:
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
log_run_check(cmd)


Expand All @@ -192,17 +183,13 @@ def lint():
cmd = ["cargo", "clippy", "--tests", "--examples", "--benches", "--bins"]
if features:
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
log_run_check(cmd)


def flamegraph(features, extra):
cmd = ["cargo", "flamegraph"]
if features:
cmd = cmd + ["--features"] + [",".join(features)]
if "ristretto255" in features:
cmd = cmd + ["--no-default-features"]
cmd += extra
print("running:", " ".join(cmd))
log_run_check(cmd)
Expand Down Expand Up @@ -243,7 +230,7 @@ def set_features(features):
features = set()

def verify_feature(f):
if f in cargo_features | {"ristretto255"}:
if f in cargo_features:
return True
return False

Expand Down
19 changes: 11 additions & 8 deletions scripts/spartan_zok_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ case "$OSTYPE" in
;;
esac

modulus=7237005577332262213973186563042994240857116359379907606001950938285454250989

function r1cs_test {
zpath=$1
measure_time $BIN $zpath r1cs --action count
measure_time $BIN --field-custom-modulus $modulus $zpath r1cs --action count
}

function r1cs_test_count {
zpath=$1
threshold=$2
o=$($BIN $zpath r1cs --action count)
o=$($BIN --field-custom-modulus $modulus $zpath r1cs --action count)
n_constraints=$(echo $o | grep 'Final R1cs size:' | grep -Eo '\b[0-9]+\b')
[[ $n_constraints -lt $threshold ]] || (echo "Got $n_constraints, expected < $threshold" && exit 1)
}
Expand All @@ -34,27 +36,29 @@ function r1cs_test_count {
# examples that don't need modulus change
function pf_test {
ex_name=$1
$BIN examples/ZoKrates/pf/$ex_name.zok r1cs --action spartansetup
$BIN --field-custom-modulus $modulus examples/ZoKrates/pf/$ex_name.zok r1cs --action spartan-setup
$ZK_BIN --pin examples/ZoKrates/pf/$ex_name.zok.pin --vin examples/ZoKrates/pf/$ex_name.zok.vin --action spartan
rm -rf P V pi
}

# Test prove workflow with --z-isolate-asserts, given an example name
# Test prove workflow with --zsharp-isolate-asserts, given an example name
function spartan_test_isolate {
ex_name=$1
$BIN --z-isolate-asserts examples/ZoKrates/spartan/$ex_name.zok r1cs --action spartansetup
$BIN --field-custom-modulus $modulus --zsharp-isolate-asserts true examples/ZoKrates/spartan/$ex_name.zok r1cs --action spartan-setup
$ZK_BIN --pin examples/ZoKrates/spartan/$ex_name.zok.pin --vin examples/ZoKrates/spartan/$ex_name.zok.vin --action spartan
rm -rf P V pi
}

# Test prove workflow, given an example name
function spartan_test {
ex_name=$1
$BIN examples/ZoKrates/spartan/$ex_name.zok r1cs --action spartansetup
$ZK_BIN --pin examples/ZoKrates/spartan/$ex_name.zok.pin --vin examples/ZoKrates/spartan/$ex_name.zok.vin --action spartan
$BIN --field-custom-modulus $modulus examples/ZoKrates/spartan/$ex_name.zok r1cs --action spartan-setup
$ZK_BIN --field-custom-modulus $modulus --pin examples/ZoKrates/spartan/$ex_name.zok.pin --vin examples/ZoKrates/spartan/$ex_name.zok.vin --action spartan
rm -rf P V pi
}

spartan_test assert

r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
Expand All @@ -70,7 +74,6 @@ r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zo
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/hashes/pedersen/512bit.zok

spartan_test assert
spartan_test_isolate isolate_assert
pf_test 3_plus
pf_test xor
Expand Down
2 changes: 1 addition & 1 deletion scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function pf_test_only_pf {
done
}

# Test prove workflow with --z-isolate-asserts, given an example name
# Test prove workflow with --zsharp-isolate-asserts, given an example name
function pf_test_isolate {
for proof_impl in groth16 mirage
do
Expand Down
10 changes: 7 additions & 3 deletions src/target/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ impl R1csFinal {
}

impl ProverData {
/// Check all assertions. Puts in 1 for challenges.
pub fn check_all(&self, values: &HashMap<String, Value>) {
/// Compute an R1CS witness (setting any challenges to 1s)
pub fn extend_r1cs_witness(&self, values: &HashMap<String, Value>) -> HashMap<Var, FieldV> {
// we need to evaluate all R1CS variables
let mut var_values: HashMap<Var, FieldV> = Default::default();
let mut eval = wit_comp::StagedWitCompEvaluator::new(&self.precompute);
Expand Down Expand Up @@ -504,7 +504,11 @@ impl ProverData {
}
}
}
self.r1cs.check_all(&var_values);
var_values
}
/// Check all assertions. Puts in 1 for challenges.
pub fn check_all(&self, values: &HashMap<String, Value>) {
self.r1cs.check_all(&self.extend_r1cs_witness(values));
}

/// How many commitments?
Expand Down
65 changes: 25 additions & 40 deletions src/target/r1cs/spartan.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Export circ R1cs to Spartan
use crate::target::r1cs::wit_comp::StagedWitCompEvaluator;
use crate::target::r1cs::*;
use bincode::{deserialize_from, serialize_into};
use curve25519_dalek::scalar::Scalar;
Expand Down Expand Up @@ -95,31 +94,42 @@ pub fn r1cs_to_spartan(
"\nR1CS has modulus \n{s_mod},\n but Spartan CS expects \n{f_mod}",
);

let values = eval_inputs(inputs_map, prover_data);
let values = prover_data.extend_r1cs_witness(inputs_map);
prover_data.r1cs.check_all(&values);

assert_eq!(values.len(), prover_data.r1cs.vars.len());

for (var, val) in prover_data.r1cs.vars.iter().zip(&values) {
let scalar = val_to_scalar(val);
for var in prover_data.r1cs.vars.iter() {
assert!(matches!(var.ty(), VarType::Inst | VarType::FinalWit));
if let VarType::FinalWit = var.ty() {
// witness
let id = wit.len();
itrans.insert(id, *var);
trans.insert(*var, id);
let val = values.get(var).expect("missing R1CS value");
wit.push(int_to_scalar(&val.i()).to_bytes());
}
}

// input
itrans.insert(inp.len(), *var);
trans.insert(*var, inp.len());
let const_id = wit.len();

for var in prover_data.r1cs.vars.iter() {
assert!(matches!(var.ty(), VarType::Inst | VarType::FinalWit));
if let VarType::Inst = var.ty() {
inp.push(scalar.to_bytes());
} else {
wit.push(scalar.to_bytes());
// input
let id = wit.len() + 1 + inp.len();
itrans.insert(id, *var);
trans.insert(*var, id);
let val = values.get(var).expect("missing R1CS value");
inp.push(int_to_scalar(&val.i()).to_bytes());
}
}

assert_eq!(wit.len() + inp.len(), prover_data.r1cs.vars.len());

let num_vars = wit.len();
let const_id = wit.len();
let num_inputs = inp.len();
assert_eq!(wit.len() + inp.len(), prover_data.r1cs.vars.len());

let assn_witness = VarsAssignment::new(&wit).unwrap();

let num_inputs = inp.len();
let assn_inputs = InputsAssignment::new(&inp).unwrap();

// circuit
Expand Down Expand Up @@ -166,31 +176,6 @@ pub fn r1cs_to_spartan(
)
}

fn eval_inputs(inputs_map: &HashMap<String, Value>, prover_data: &ProverData) -> Vec<Value> {
let mut evaluator = StagedWitCompEvaluator::new(&prover_data.precompute);
let mut ffs = Vec::new();
ffs.extend(
evaluator
.eval_stage(inputs_map.clone())
.into_iter()
.cloned(),
);
ffs.extend(
evaluator
.eval_stage(Default::default())
.into_iter()
.cloned(),
);
ffs
}

fn val_to_scalar(v: &Value) -> Scalar {
match v.sort() {
Sort::Field(_) => return int_to_scalar(&v.as_pf().i()),
_ => panic!("Value should be a field"),
};
}

fn int_to_scalar(i: &Integer) -> Scalar {
let mut accumulator = Scalar::zero();
let limb_bits = (std::mem::size_of::<limb_t>() as u64) << 3;
Expand Down

0 comments on commit 2ebd0a1

Please sign in to comment.