From 50ee9c9100f3997fae0df6600121fb6f7d92fd3a Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Mon, 29 Apr 2024 12:00:02 -0700 Subject: [PATCH 01/10] fix handling of less-than instructions with second operand immediate --- alu_u32/src/lt/mod.rs | 16 ++++++++++++---- assembler/src/lib.rs | 2 +- basic/src/bin/valida.rs | 4 ++-- cpu/src/columns.rs | 11 ++++++----- cpu/src/lib.rs | 26 ++++++++++++++++++++++++-- cpu/src/stark.rs | 21 ++++++++++++++++----- 6 files changed, 61 insertions(+), 19 deletions(-) diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 3aee13b6..191ef0f2 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -148,13 +148,13 @@ where let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; - let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; let src1: Word = if ops.d() == 1 { let b = (ops.b() as u32).into(); imm = Some(b); b } else { + let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; state .mem_mut() .read(clk, read_addr_1, true, pc, opcode, 0, "") @@ -181,7 +181,11 @@ where .lt_u32_mut() .operations .push(Operation::Lt32(dst, src1, src2)); - state.cpu_mut().push_bus_op(imm, opcode, ops); + if ops.d() == 1 { + state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops); + } else { + state.cpu_mut().push_bus_op(imm, opcode, ops); + } } } @@ -197,13 +201,13 @@ where let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; - let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; let src1: Word = if ops.d() == 1 { let b = (ops.b() as u32).into(); imm = Some(b); b } else { + let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; state .mem_mut() .read(clk, read_addr_1, true, pc, opcode, 0, "") @@ -230,6 +234,10 @@ where .lt_u32_mut() .operations .push(Operation::Lte32(dst, src1, src2)); - state.cpu_mut().push_bus_op(imm, opcode, ops); + if 
ops.d() == 1 { + state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops); + } else { + state.cpu_mut().push_bus_op(imm, opcode, ops); + } } } diff --git a/assembler/src/lib.rs b/assembler/src/lib.rs index b3ed64be..64a4a06d 100644 --- a/assembler/src/lib.rs +++ b/assembler/src/lib.rs @@ -135,7 +135,7 @@ pub fn assemble(input: &str) -> Result, String> { } "ilt" | "ilte" => { // (a, b, c, 1, 0) - operands.extend(vec![0, 1]); + operands.extend(vec![1, 0]); } "advread" => { // (a, 0, 0, 0, 0) diff --git a/basic/src/bin/valida.rs b/basic/src/bin/valida.rs index a4c2992f..13948f40 100644 --- a/basic/src/bin/valida.rs +++ b/basic/src/bin/valida.rs @@ -435,9 +435,9 @@ fn main() { Ok(_) => { stdout().write("Proof verified\n".as_bytes()).unwrap(); } - Err(_) => { + Err(e) => { stdout() - .write("Proof verification failed\n".as_bytes()) + .write("Proof verification failed: \n".as_bytes()) .unwrap(); } } diff --git a/cpu/src/columns.rs b/cpu/src/columns.rs index 589cb999..d48ea5aa 100644 --- a/cpu/src/columns.rs +++ b/cpu/src/columns.rs @@ -4,7 +4,7 @@ use valida_derive::AlignedBorrow; use valida_machine::{Operands, Word, CPU_MEMORY_CHANNELS}; use valida_util::indices_arr; -#[derive(AlignedBorrow, Default)] +#[derive(AlignedBorrow, Default, Debug)] pub struct CpuCols { /// Clock cycle pub clk: T, @@ -36,17 +36,18 @@ pub struct CpuCols { pub chip_channel: ChipChannelCols, } -#[derive(Default)] +#[derive(Default, Debug)] pub struct InstructionCols { pub opcode: T, pub operands: Operands, } -#[derive(Default)] +#[derive(Default, Debug)] pub struct OpcodeFlagCols { pub is_bus_op: T, pub is_bus_op_with_mem: T, pub is_imm_op: T, + pub is_left_imm_op: T, pub is_load: T, pub is_load_u8: T, pub is_load_s8: T, @@ -62,7 +63,7 @@ pub struct OpcodeFlagCols { pub is_loadfp: T, } -#[derive(Default)] +#[derive(Default, Debug)] pub struct MemoryChannelCols { pub used: T, pub is_read: T, @@ -70,7 +71,7 @@ pub struct MemoryChannelCols { pub value: Word, } -#[derive(Default)] +#[derive(Default, 
Debug)] pub struct ChipChannelCols { pub clk_or_zero: T, } diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index 424affc0..79e32f1f 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -1,5 +1,3 @@ -#![no_std] - extern crate alloc; use crate::columns::{CpuCols, CPU_COL_MAP, NUM_CPU_COLS}; @@ -44,6 +42,7 @@ pub enum Operation { Bne(Option> /*imm*/), Imm32, Bus(Option> /*imm*/), + BusLeftImm(Option> /*imm*/), BusWithMemory(Option> /*imm*/), ReadAdvice, Stop, @@ -211,6 +210,10 @@ impl CpuChip { cols.opcode_flags.is_bus_op = SC::Val::one(); self.set_imm_value(cols, *imm); } + Operation::BusLeftImm(imm) => { + cols.opcode_flags.is_bus_op = SC::Val::one(); + self.set_left_imm_value(cols, *imm); + } Operation::BusWithMemory(imm) => { cols.opcode_flags.is_bus_op = SC::Val::one(); cols.opcode_flags.is_bus_op_with_mem = SC::Val::one(); @@ -355,6 +358,15 @@ impl CpuChip { cols.instruction.operands.0[2] = imm.reduce(); } } + + fn set_left_imm_value(&self, cols: &mut CpuCols, imm: Option>) { + if let Some(imm) = imm { + cols.opcode_flags.is_left_imm_op = F::one(); + let imm = imm.transform(F::from_canonical_u8); + cols.mem_channels[1].value = imm; + cols.instruction.operands.0[1] = imm.reduce(); + } + } } pub trait MachineWithCpuChip: MachineWithMemoryChip { @@ -881,6 +893,16 @@ impl CpuChip { self.push_op(Operation::Bus(imm), opcode, operands); } + pub fn push_left_imm_bus_op( + &mut self, + imm: Option>, + opcode: u32, + operands: Operands, + ) { + self.pc += 1; + self.push_op(Operation::BusLeftImm(imm), opcode, operands); + } + pub fn push_op(&mut self, op: Operation, opcode: u32, operands: Operands) { self.operations.push(op); self.instructions.push(InstructionWord { opcode, operands }); diff --git a/cpu/src/stark.rs b/cpu/src/stark.rs index 425267bd..0d9ae1e5 100644 --- a/cpu/src/stark.rs +++ b/cpu/src/stark.rs @@ -1,6 +1,7 @@ use crate::columns::{CpuCols, NUM_CPU_COLS}; use crate::CpuChip; use core::borrow::Borrow; +use p3_field::extension::BinomiallyExtendable; use 
valida_machine::Word; use p3_air::{Air, AirBuilder, BaseAir}; @@ -45,10 +46,16 @@ where // Immediate value constraints (TODO: we'd need to range check read_value_2 in // this case) + // this asserts that at most one of `is_imm_op` and `is_left_imm_op` is true. + builder.assert_bool(local.opcode_flags.is_imm_op + local.opcode_flags.is_left_imm_op); builder.when(local.opcode_flags.is_imm_op).assert_eq( local.instruction.operands.c(), reduce::(&base, local.read_value_2()), ); + builder.when(local.opcode_flags.is_left_imm_op).assert_eq( + local.instruction.operands.b(), + reduce::(&base, local.read_value_2()), + ); // "Stop" constraints (to check that program execution was not stopped prematurely) builder @@ -82,6 +89,7 @@ impl CpuChip { let is_loadfp = local.opcode_flags.is_loadfp; let _is_advice = local.opcode_flags.is_advice; // TODO: unused let is_imm_op = local.opcode_flags.is_imm_op; + let is_left_imm_op = local.opcode_flags.is_left_imm_op; let is_bus_op = local.opcode_flags.is_bus_op; let _is_bus_op_with_mem = local.opcode_flags.is_bus_op_with_mem; // TODO: unused @@ -94,11 +102,12 @@ impl CpuChip { builder.assert_zero(local.mem_channels[2].is_read); // Read (1) + // note that here we are using the fact that at most one of 'is_imm_op' and 'is_left_imm_op' is ever true. builder - .when(is_jalv + is_beq + is_bne + is_bus_op) + .when(is_jalv + is_beq + is_bne + is_bus_op * (AB::Expr::one() - is_left_imm_op)) .assert_eq(local.read_addr_1(), addr_b.clone()); builder - .when(is_load + is_store) + .when(is_load + is_store + is_bus_op * is_left_imm_op) .assert_eq(local.read_addr_1(), addr_c.clone()); builder .when(is_load + is_store + is_jalv + is_beq + is_bne + is_bus_op) @@ -106,6 +115,7 @@ impl CpuChip { builder.when(is_jal).assert_zero(local.read_1_used()); // Read (2) + // note that here we are again using the fact that at most one of 'is_imm_op' and 'is_left_imm_op' is ever true. 
builder.when(is_load).assert_eq( local.read_addr_2(), reduce::(base, local.read_value_1()), @@ -114,18 +124,19 @@ impl CpuChip { .when(is_store) .assert_eq(local.read_addr_2(), addr_b); builder - .when(is_jalv + (AB::Expr::one() - is_imm_op) * (is_beq + is_bne + is_bus_op)) + .when(is_jalv + (AB::Expr::one() - (is_imm_op + is_left_imm_op)) * is_bus_op) .assert_eq(local.read_addr_2(), addr_c); builder .when( is_load + is_store + is_jalv - + (AB::Expr::one() - is_imm_op) * (is_beq + is_bne + is_bus_op), + + (AB::Expr::one() - is_imm_op) * (is_beq + is_bne) + + (AB::Expr::one() - (is_imm_op + is_left_imm_op)) * is_bus_op, ) .assert_one(local.read_2_used()); builder - .when(is_jal + is_imm_op * (is_beq + is_bne + is_bus_op)) + .when(is_jal + is_imm_op * (is_beq + is_bne) + (is_imm_op + is_left_imm_op) * is_bus_op) .assert_zero(local.read_2_used()); // Write From 55fa4521c70c35346da944945044863466364a60 Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Tue, 30 Apr 2024 09:55:39 -0700 Subject: [PATCH 02/10] fix ordering of memory channel values when left operand is immediate --- cpu/src/lib.rs | 11 ++++++----- cpu/src/stark.rs | 26 +++++++++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index 79e32f1f..64a2ce34 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -170,9 +170,6 @@ impl CpuChip { cols.fp = SC::Val::from_canonical_u32(self.registers[clk].fp); cols.clk = SC::Val::from_canonical_usize(clk); - self.set_instruction_values(clk, cols); - self.set_memory_channel_values::(clk, cols, machine); - match op { Operation::Store32 => { cols.opcode_flags.is_store = SC::Val::one(); @@ -230,6 +227,9 @@ impl CpuChip { } } + self.set_instruction_values(clk, cols); + self.set_memory_channel_values::(clk, cols, machine); + row } @@ -249,13 +249,14 @@ impl CpuChip { cols.mem_channels[1].is_read = SC::Val::one(); cols.mem_channels[2].is_read = SC::Val::zero(); + let is_left_imm_op = cols.opcode_flags.is_left_imm_op == 
SC::Val::one(); let memory = machine.mem(); for ops in memory.operations.get(&(clk as u32)).iter() { let mut is_first_read = true; for op in ops.iter() { match op { MemoryOperation::Read(addr, value) => { - if is_first_read { + if is_first_read & !is_left_imm_op { cols.mem_channels[0].used = SC::Val::one(); cols.mem_channels[0].addr = SC::Val::from_canonical_u32(*addr); cols.mem_channels[0].value = @@ -363,7 +364,7 @@ impl CpuChip { if let Some(imm) = imm { cols.opcode_flags.is_left_imm_op = F::one(); let imm = imm.transform(F::from_canonical_u8); - cols.mem_channels[1].value = imm; + cols.mem_channels[0].value = imm; cols.instruction.operands.0[1] = imm.reduce(); } } diff --git a/cpu/src/stark.rs b/cpu/src/stark.rs index 0d9ae1e5..4ba2a627 100644 --- a/cpu/src/stark.rs +++ b/cpu/src/stark.rs @@ -1,7 +1,7 @@ use crate::columns::{CpuCols, NUM_CPU_COLS}; use crate::CpuChip; use core::borrow::Borrow; -use p3_field::extension::BinomiallyExtendable; + use valida_machine::Word; use p3_air::{Air, AirBuilder, BaseAir}; @@ -54,7 +54,7 @@ where ); builder.when(local.opcode_flags.is_left_imm_op).assert_eq( local.instruction.operands.b(), - reduce::(&base, local.read_value_2()), + reduce::(&base, local.read_value_1()), ); // "Stop" constraints (to check that program execution was not stopped prematurely) @@ -107,12 +107,21 @@ impl CpuChip { .when(is_jalv + is_beq + is_bne + is_bus_op * (AB::Expr::one() - is_left_imm_op)) .assert_eq(local.read_addr_1(), addr_b.clone()); builder - .when(is_load + is_store + is_bus_op * is_left_imm_op) + .when(is_load + is_store) .assert_eq(local.read_addr_1(), addr_c.clone()); builder - .when(is_load + is_store + is_jalv + is_beq + is_bne + is_bus_op) + .when( + is_load + + is_store + + is_jalv + + is_beq + + is_bne + + (AB::Expr::one() - is_left_imm_op) * is_bus_op, + ) .assert_one(local.read_1_used()); - builder.when(is_jal).assert_zero(local.read_1_used()); + builder + .when(is_jal + is_left_imm_op) + .assert_zero(local.read_1_used()); // 
Read (2) // note that here we are again using the fact that at most one of 'is_imm_op' and 'is_left_imm_op' is ever true. @@ -124,19 +133,18 @@ impl CpuChip { .when(is_store) .assert_eq(local.read_addr_2(), addr_b); builder - .when(is_jalv + (AB::Expr::one() - (is_imm_op + is_left_imm_op)) * is_bus_op) + .when(is_jalv + (AB::Expr::one() - is_imm_op) * is_bus_op) .assert_eq(local.read_addr_2(), addr_c); builder .when( is_load + is_store + is_jalv - + (AB::Expr::one() - is_imm_op) * (is_beq + is_bne) - + (AB::Expr::one() - (is_imm_op + is_left_imm_op)) * is_bus_op, + + (AB::Expr::one() - is_imm_op) * (is_beq + is_bne + is_bus_op), ) .assert_one(local.read_2_used()); builder - .when(is_jal + is_imm_op * (is_beq + is_bne) + (is_imm_op + is_left_imm_op) * is_bus_op) + .when(is_jal + is_imm_op * (is_beq + is_bne) + is_imm_op * is_bus_op) .assert_zero(local.read_2_used()); // Write From 605c959f34dfe9b8702dd2c9bda34981a28a5f1c Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Tue, 30 Apr 2024 19:27:57 -0700 Subject: [PATCH 03/10] fixes stark constraints for LT32Chip --- alu_u32/src/lt/mod.rs | 2 +- alu_u32/src/lt/stark.rs | 51 ++++++++------ basic/tests/test_prover.rs | 133 +++++++++++++++++++++++++++++++++++-- 3 files changed, 161 insertions(+), 25 deletions(-) diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 191ef0f2..9c6b73a0 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -115,7 +115,7 @@ impl Lt32Chip { .into_iter() .zip(c.into_iter()) .enumerate() - .find_map(|(n, (x, y))| if x == y { Some(n) } else { None }) + .find_map(|(n, (x, y))| if x == y { None } else { Some(n) }) { let z = 256u16 + b[n] as u16 - c[n] as u16; for i in 0..10 { diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 2180ce58..42fc1bc9 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -33,46 +33,59 @@ where // Check bit decomposition of z = 256 + input_1[n] - input_2[n], where // n is the most significant byte that differs between 
inputs - for i in 0..3 { - builder - .when_ne(local.byte_flag[i], AB::Expr::one()) - .assert_eq(local.input_1[i], local.input_2[i]); - + for i in 0..4 { builder.when(local.byte_flag[i]).assert_eq( AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i], bit_comp.clone(), ); - builder.assert_bool(local.byte_flag[i]); } - // Check final byte (if no other byte flags were set) - let flag_sum = local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2]; + // ensure at most one byte flag is set + let flag_sum = + local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2] + local.byte_flag[3]; builder.assert_bool(flag_sum.clone()); + + // case: top bytes match + builder + .when_ne(local.byte_flag[0], AB::Expr::one()) + .assert_eq(local.input_1[0], local.input_2[0]); + // case: top two bytes match + builder + .when_ne(local.byte_flag[0] + local.byte_flag[1], AB::Expr::one()) + .assert_eq(local.input_1[1], local.input_2[1]); + // case: top three bytes match + builder + .when_ne( + local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2], + AB::Expr::one(), + ) + .assert_eq(local.input_1[2], local.input_2[2]); + // case: top four bytes match; must set z = 0 builder - .when_ne(local.multiplicity, AB::Expr::zero()) .when_ne(flag_sum.clone(), AB::Expr::one()) - .assert_eq( - AB::Expr::from_canonical_u32(256) + local.input_1[3] - local.input_2[3], - bit_comp.clone(), - ); + .assert_eq(local.input_1[3], local.input_2[3]); + builder + .when_ne(flag_sum.clone(), AB::Expr::one()) + .assert_eq(bit_comp, AB::Expr::zero()); builder.assert_bool(local.is_lt); builder.assert_bool(local.is_lte); builder.assert_bool(local.is_lt + local.is_lte); // Output constraints + // local.bits[8] is 1 iff input_1 > input_2: output should be 0 builder.when(local.bits[8]).assert_zero(local.output); - builder - .when_ne(local.multiplicity, AB::Expr::zero()) - .when_ne(local.bits[8], AB::Expr::one()) - .assert_one(local.output); // output should be 1 if is_lte & input_1 == input_2 - 
let all_flag_sum = flag_sum + local.byte_flag[3]; builder .when(local.is_lte) - .when_ne(all_flag_sum, AB::Expr::one()) + .when_ne(flag_sum.clone(), AB::Expr::one()) .assert_one(local.output); + // output should be 0 if is_lt & input_1 == input_2 + builder + .when(local.is_lt) + .when_ne(flag_sum, AB::Expr::one()) + .assert_zero(local.output); // Check bit decomposition for bit in local.bits.into_iter() { diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index d04a33ed..9ad51d9c 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -3,6 +3,7 @@ extern crate core; use p3_baby_bear::BabyBear; use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; +use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction}; use valida_basic::BasicMachine; use valida_cpu::{ BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, @@ -20,7 +21,7 @@ use valida_program::MachineWithProgramChip; use p3_challenger::DuplexChallenger; use p3_dft::Radix2Bowers; use p3_field::extension::BinomialExtensionField; -use p3_field::Field; +use p3_field::{Field, PrimeField32, TwoAdicField}; use p3_fri::FriConfig; use p3_keccak::Keccak256Hash; use p3_mds::coset_mds::CosetMds; @@ -31,8 +32,7 @@ use rand::thread_rng; use valida_machine::StarkConfigImpl; use valida_machine::__internal::p3_commit::ExtensionMmcs; -#[test] -fn prove_fibonacci() { +fn fib_program() -> Vec> { let mut program = vec![]; // Label locations @@ -46,7 +46,7 @@ fn prove_fibonacci() { //main: ; @main //; %bb.0: // imm32 -4(fp), 0, 0, 0, 0 - // imm32 -8(fp), 0, 0, 0, 10 + // imm32 -8(fp), 0, 0, 0, 25 // addi -16(fp), -8(fp), 0 // imm32 -20(fp), 0, 0, 0, 28 // jal -28(fp), fib, -28 @@ -184,7 +184,74 @@ fn prove_fibonacci() { operands: Operands([-4, 0, 8, 0, 0]), }, ]); + program +} + +fn left_imm_ops_program() -> Vec> { + let mut program = vec![]; + + // imm32 -4(fp), 0, 0, 0, 3 + // lt32 -8(fp), 3, -4(fp), 1, 
0 + // lte32 -12(fp), 3, -4(fp), 1, 0 + // stop + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 0, 0, 3]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, 0, 0, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([4, 3, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([8, 3, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([12, 4, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([16, 4, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([20, 2, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([24, 2, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([28, 256, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([32, 256, -4, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([36, 3, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([40, 3, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands::default(), + }, + ]); + program +} +fn prove_program(program: Vec>) -> BasicMachine { let mut machine = BasicMachine::::default(); let rom = ProgramROM::new(program); machine.program_mut().set_program_rom(&rom); @@ -194,6 +261,7 @@ fn prove_fibonacci() { machine.run(&rom, &mut FixedAdviceProvider::empty()); type Val = BabyBear; + type Challenge = BinomialExtensionField; type PackedChallenge = BinomialExtensionField<::Packing, 5>; @@ -250,13 +318,68 @@ fn prove_fibonacci() { .verify(&config, &deserialized_proof) .expect("verification failed"); + machine +} +#[test] +fn prove_fibonacci() { + let program = fib_program::(); + + let machine = prove_program(program); + assert_eq!(machine.cpu().clock, 192); 
assert_eq!(machine.cpu().operations.len(), 192); assert_eq!(machine.mem().operations.values().flatten().count(), 401); assert_eq!(machine.add_u32().operations.len(), 105); - assert_eq!( *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), // Return value Word([0, 1, 37, 17,]) // 25th fibonacci number (75025) ); } + +#[test] +fn prove_left_imm_ops() { + let program = left_imm_ops_program::(); + + let machine = prove_program(program); + + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), + Word([0, 0, 0, 0]) // 3 < 3 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 8)).unwrap(), + Word([0, 0, 0, 1]) // 3 <= 3 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 12)).unwrap(), + Word([0, 0, 0, 0]) // 4 < 3 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 16)).unwrap(), + Word([0, 0, 0, 0]) // 4 <= 3 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 20)).unwrap(), + Word([0, 0, 0, 1]) // 2 < 3 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 24)).unwrap(), + Word([0, 0, 0, 1]) // 2 <= 3 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 28)).unwrap(), + Word([0, 0, 0, 0]) // 256 < 3 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 32)).unwrap(), + Word([0, 0, 0, 0]) // 256 <= 3 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 36)).unwrap(), + Word([0, 0, 0, 1]) // 3 < 256 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 40)).unwrap(), + Word([0, 0, 0, 1]) // 3 <= 256 (true) + ); +} From 215176aa1370ac09288096721db91b5fb04ee12d Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Thu, 2 May 2024 17:20:58 -0700 Subject: [PATCH 04/10] restore no_std --- basic/src/bin/valida.rs | 4 ++-- cpu/src/lib.rs | 2 ++ cpu/src/stark.rs | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/basic/src/bin/valida.rs b/basic/src/bin/valida.rs index 13948f40..a4c2992f 100644 --- a/basic/src/bin/valida.rs +++ b/basic/src/bin/valida.rs @@
-435,9 +435,9 @@ fn main() { Ok(_) => { stdout().write("Proof verified\n".as_bytes()).unwrap(); } - Err(e) => { + Err(_) => { stdout() - .write("Proof verification failed: \n".as_bytes()) + .write("Proof verification failed\n".as_bytes()) .unwrap(); } } diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index 64a2ce34..ed5c9fac 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -1,3 +1,5 @@ +#![no_std] + extern crate alloc; use crate::columns::{CpuCols, CPU_COL_MAP, NUM_CPU_COLS}; diff --git a/cpu/src/stark.rs b/cpu/src/stark.rs index 4ba2a627..2b9fa22f 100644 --- a/cpu/src/stark.rs +++ b/cpu/src/stark.rs @@ -1,7 +1,6 @@ use crate::columns::{CpuCols, NUM_CPU_COLS}; use crate::CpuChip; use core::borrow::Borrow; - use valida_machine::Word; use p3_air::{Air, AirBuilder, BaseAir}; From f5795638ab061e31737e52326409cc0d0a250a9b Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Fri, 3 May 2024 14:26:54 -0700 Subject: [PATCH 05/10] fixes setting instruction column values with immediate arguments in CpuChip --- alu_u32/src/lt/mod.rs | 189 +++++++++++-------------------------- alu_u32/src/lt/stark.rs | 24 ++--- basic/tests/test_prover.rs | 18 +++- cpu/src/lib.rs | 2 +- 4 files changed, 85 insertions(+), 148 deletions(-) diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 3a08d01a..1177c225 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -139,29 +139,18 @@ impl Lt32Chip { } cols.multiplicity = F::one(); } -} - -pub trait MachineWithLt32Chip: MachineWithCpuChip { - fn lt_u32(&self) -> &Lt32Chip; - fn lt_u32_mut(&mut self) -> &mut Lt32Chip; -} - -instructions!( - Lt32Instruction, - Lte32Instruction, - Slt32Instruction, - Sle32Instruction -); - -impl Instruction for Lt32Instruction -where - M: MachineWithLt32Chip, - F: Field, -{ - const OPCODE: u32 = LT32; - fn execute(state: &mut M, ops: Operands) { - let opcode = >::OPCODE; + fn execute_with_closure( + state: &mut M, + ops: Operands, + opcode: u32, + comp: F, + ) -> (Word, Word, Word) + where + M: 
MachineWithLt32Chip, + E: Field, + F: Fn(Word, Word) -> bool, + { let clk = state.cpu().clock; let pc = state.cpu().pc; let mut imm: Option> = None; @@ -187,75 +176,68 @@ where .read(clk, read_addr_2, true, pc, opcode, 1, "") }; - let dst = if src1 < src2 { + let dst = if comp(src1, src2) { Word::from(1) } else { Word::from(0) }; state.mem_mut().write(clk, write_addr, dst, true); - state - .lt_u32_mut() - .operations - .push(Operation::Lt32(dst, src1, src2)); if ops.d() == 1 { - state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops); + state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops) } else { state.cpu_mut().push_bus_op(imm, opcode, ops); } + (dst, src1, src2) } } -impl Instruction for Lte32Instruction +pub trait MachineWithLt32Chip: MachineWithCpuChip { + fn lt_u32(&self) -> &Lt32Chip; + fn lt_u32_mut(&mut self) -> &mut Lt32Chip; +} + +instructions!( + Lt32Instruction, + Lte32Instruction, + Slt32Instruction, + Sle32Instruction +); + +impl Instruction for Lt32Instruction where M: MachineWithLt32Chip, F: Field, { - const OPCODE: u32 = LTE32; + const OPCODE: u32 = LT32; fn execute(state: &mut M, ops: Operands) { let opcode = >::OPCODE; - let clk = state.cpu().clock; - let pc = state.cpu().pc; - let mut imm: Option> = None; - let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; - let src1: Word = if ops.d() == 1 { - let b = (ops.b() as u32).into(); - imm = Some(b); - b - } else { - let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; - state - .mem_mut() - .read(clk, read_addr_1, true, pc, opcode, 0, "") - }; - let src2: Word = if ops.is_imm() == 1 { - let c = (ops.c() as u32).into(); - imm = Some(c); - c - } else { - let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32; - state - .mem_mut() - .read(clk, read_addr_2, true, pc, opcode, 1, "") - }; + let comp = |a, b| a < b; + let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); - let dst = if src1 <= src2 { - Word::from(1) - } else { - Word::from(0) - }; - 
state.mem_mut().write(clk, write_addr, dst, true); + state + .lt_u32_mut() + .operations + .push(Operation::Lt32(dst, src1, src2)); + } +} +impl Instruction for Lte32Instruction +where + M: MachineWithLt32Chip, + F: Field, +{ + const OPCODE: u32 = LTE32; + + fn execute(state: &mut M, ops: Operands) { + let opcode = >::OPCODE; + let comp = |a, b| a <= b; + let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); state .lt_u32_mut() .operations .push(Operation::Lte32(dst, src1, src2)); - if ops.d() == 1 { - state.cpu_mut().push_left_imm_bus_op(imm, opcode, ops); - } else { - state.cpu_mut().push_bus_op(imm, opcode, ops); - } } } @@ -268,45 +250,16 @@ where fn execute(state: &mut M, ops: Operands) { let opcode = >::OPCODE; - let clk = state.cpu().clock; - let pc = state.cpu().pc; - let mut imm: Option> = None; - let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; - let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; - let src1: Word = if ops.d() == 1 { - let b = (ops.b() as u32).into(); - imm = Some(b); - b - } else { - state - .mem_mut() - .read(clk, read_addr_1, true, pc, opcode, 0, "") - }; - let src2: Word = if ops.is_imm() == 1 { - let c = (ops.c() as u32).into(); - imm = Some(c); - c - } else { - let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32; - state - .mem_mut() - .read(clk, read_addr_2, true, pc, opcode, 1, "") - }; - - let src1_i: i32 = src1.into(); - let src2_i: i32 = src2.into(); - let dst = if src1_i < src2_i { - Word::from(1) - } else { - Word::from(0) + let comp = |a: Word, b: Word| { + let a_i: i32 = a.into(); + let b_i: i32 = b.into(); + a_i < b_i }; - state.mem_mut().write(clk, write_addr, dst, true); - + let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); state .lt_u32_mut() .operations .push(Operation::Slt32(dst, src1, src2)); - state.cpu_mut().push_bus_op(imm, opcode, ops); } } @@ -319,44 +272,16 @@ where fn execute(state: &mut M, ops: Operands) { let opcode = 
>::OPCODE; - let clk = state.cpu().clock; - let pc = state.cpu().pc; - let mut imm: Option> = None; - let read_addr_1 = (state.cpu().fp as i32 + ops.b()) as u32; - let write_addr = (state.cpu().fp as i32 + ops.a()) as u32; - let src1: Word = if ops.d() == 1 { - let b = (ops.b() as u32).into(); - imm = Some(b); - b - } else { - state - .mem_mut() - .read(clk, read_addr_1, true, pc, opcode, 0, "") - }; - let src2: Word = if ops.is_imm() == 1 { - let c = (ops.c() as u32).into(); - imm = Some(c); - c - } else { - let read_addr_2 = (state.cpu().fp as i32 + ops.c()) as u32; - state - .mem_mut() - .read(clk, read_addr_2, true, pc, opcode, 1, "") + let comp = |a: Word, b: Word| { + let a_i: i32 = a.into(); + let b_i: i32 = b.into(); + a_i <= b_i }; - - let src1_i: i32 = src1.into(); - let src2_i: i32 = src2.into(); - let dst = if src1_i <= src2_i { - Word::from(1) - } else { - Word::from(0) - }; - state.mem_mut().write(clk, write_addr, dst, true); + let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); state .lt_u32_mut() .operations .push(Operation::Sle32(dst, src1, src2)); - state.cpu_mut().push_bus_op(imm, opcode, ops); } } diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 42fc1bc9..0402b890 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -31,21 +31,13 @@ where .map(|(bit, base)| bit * base) .sum(); - // Check bit decomposition of z = 256 + input_1[n] - input_2[n], where - // n is the most significant byte that differs between inputs - for i in 0..4 { - builder.when(local.byte_flag[i]).assert_eq( - AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i], - bit_comp.clone(), - ); - builder.assert_bool(local.byte_flag[i]); - } + // check that the n-th byte flag is set, where n is the first byte that differs between the two inputs // ensure at most one byte flag is set let flag_sum = local.byte_flag[0] + local.byte_flag[1] + local.byte_flag[2] + local.byte_flag[3]; 
builder.assert_bool(flag_sum.clone()); - + // check that bytes before the first set byte flag are all equal // case: top bytes match builder .when_ne(local.byte_flag[0], AB::Expr::one()) @@ -67,7 +59,17 @@ where .assert_eq(local.input_1[3], local.input_2[3]); builder .when_ne(flag_sum.clone(), AB::Expr::one()) - .assert_eq(bit_comp, AB::Expr::zero()); + .assert_eq(bit_comp.clone(), AB::Expr::zero()); + + // Check bit decomposition of z = 256 + input_1[n] - input_2[n] + // when `n` is the first byte that differs between the two inputs. + for i in 0..4 { + builder.when(local.byte_flag[i]).assert_eq( + AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i], + bit_comp.clone(), + ); + builder.assert_bool(local.byte_flag[i]); + } builder.assert_bool(local.is_lt); builder.assert_bool(local.is_lte); diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 9ad51d9c..4c6c672d 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -190,59 +190,69 @@ fn fib_program() -> Vec> fn left_imm_ops_program() -> Vec> { let mut program = vec![]; - // imm32 -4(fp), 0, 0, 0, 3 - // lt32 -8(fp), 3, -4(fp), 1, 0 - // lte32 -12(fp), 3, -4(fp), 1, 0 - // stop program.extend([ + // imm32 -4(fp), 0, 0, 0, 3 + // ;(0, 0, 1, 0) == 256 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([-4, 0, 0, 0, 3]), }, + // imm32 -8(fp), 0, 0, 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([-8, 0, 0, 1, 0]), }, + // lt32 4(fp), 3, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([4, 3, -4, 1, 0]), }, + // lte32 8(fp), 3, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([8, 3, -4, 1, 0]), }, + // lt32 12(fp), 4, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([12, 4, -4, 1, 0]), }, + // lte32 16(fp), 4, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([16, 4, -4, 1, 0]), }, + // lt32 20(fp), 2, -4(fp), 1, 0 
InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([20, 2, -4, 1, 0]), }, + // lte32 24(fp), 2, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([24, 2, -4, 1, 0]), }, + // lt32 28(fp), 256, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([28, 256, -4, 1, 0]), }, + // lte32 32(fp), 256, -4(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([32, 256, -4, 1, 0]), }, + // lt32 36(fp), 3, -8(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([36, 3, -8, 1, 0]), }, + // lte32 40(fp), 3, -8(fp), 1, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands([40, 3, -8, 1, 0]), }, + // stop 0, 0, 0, 0, 0 InstructionWord { opcode: , Val>>::OPCODE, operands: Operands::default(), diff --git a/cpu/src/lib.rs b/cpu/src/lib.rs index ed5c9fac..52885c95 100644 --- a/cpu/src/lib.rs +++ b/cpu/src/lib.rs @@ -171,6 +171,7 @@ impl CpuChip { cols.pc = SC::Val::from_canonical_u32(self.registers[clk].pc); cols.fp = SC::Val::from_canonical_u32(self.registers[clk].fp); cols.clk = SC::Val::from_canonical_usize(clk); + self.set_instruction_values(clk, cols); match op { Operation::Store32 => { @@ -229,7 +230,6 @@ impl CpuChip { } } - self.set_instruction_values(clk, cols); self.set_memory_channel_values::(clk, cols, machine); row From a7e8369b7500b281b814321906cfd93e11d6e673 Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Fri, 3 May 2024 14:27:23 -0700 Subject: [PATCH 06/10] adds signed lt instructions to assembler --- assembler/grammar/assembly.pest | 2 +- assembler/src/lib.rs | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/assembler/grammar/assembly.pest b/assembler/grammar/assembly.pest index 82cea9e2..bceb2fa6 100644 --- a/assembler/grammar/assembly.pest +++ b/assembler/grammar/assembly.pest @@ -6,7 +6,7 @@ mnemonic = { "lw" | "sw" | "loadu8" | "loads8" | "storeu8" | "jalv" | "jal" | "beqi" | "beq" | "bnei" | "bne" | "imm32" | "stop" | "advread" | 
"advwrite" | "addi" | "add" | "subi" | "sub" | "muli" | "mul" | "mulhsi"| "mulhui"| "mulhs"| "mulhu" | "divi" | "div" | "sdiv" | "sdivi" | - "ilte" | "ltei" | "lte" | "ilt" | "lti" | "lt" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" | + "ilte" | "ltei" | "lte" | "ilt" | "lti" | "lt" | "sltei" | "slti" | "sle" | "slt" | "islt" | "isle" | "shli" | "shl" | "shri" | "shr" | "srai" | "sra" | "andi" | "and" | "ori" | "or" | "xori" | "xor" | "nei" | "ne" | "eqi" | "eq" | "feadd" | "fesub" | "femul" | "write" diff --git a/assembler/src/lib.rs b/assembler/src/lib.rs index 64a4a06d..fb0bce1e 100644 --- a/assembler/src/lib.rs +++ b/assembler/src/lib.rs @@ -10,7 +10,6 @@ pub struct AssemblyParser; pub fn assemble(input: &str) -> Result, String> { let parsed = AssemblyParser::parse(Rule::assembly, input).unwrap(); - // First pass: Record label locations let mut label_locations = HashMap::new(); let mut pc = 0; @@ -82,6 +81,8 @@ pub fn assemble(input: &str) -> Result, String> { "sdiv" | "sdivi" => SDIV32, "ilt" | "lt" | "lti" => LT32, "ilte" | "lte" | "ltei" => LTE32, + "islt" | "slt" | "slti" => SLT32, + "isle" | "sle" | "slei" => SLE32, "shl" | "shli" => SHL32, "shr" | "shri" => SHR32, "sra" | "srai" => SRA32, @@ -123,17 +124,17 @@ pub fn assemble(input: &str) -> Result, String> { } "add" | "sub" | "mul" | "mulhs" | "mulhu" | "div" | "lt" | "lte" | "shl" | "shr" | "sra" | "beq" | "bne" | "and" | "or" | "xor" | "ne" | "eq" - | "jal" | "jalv" => { + | "jal" | "jalv" | "slt" | "sle" => { // (a, b, c, 0, 0) operands.extend(vec![0; 2]); } "addi" | "subi" | "muli" | "mulhsi" | "mulhui" | "divi" | "sdivi" | "lti" | "ltei" | "shli" | "shri" | "srai" | "beqi" | "bnei" | "andi" | "ori" - | "xori" | "nei" | "eqi" => { + | "xori" | "nei" | "eqi" | "slti" | "slei" => { // (a, b, c, 0, 1) operands.extend(vec![0, 1]); } - "ilt" | "ilte" => { + "ilt" | "ilte" | "islt" | "isle" => { // (a, b, c, 1, 0) operands.extend(vec![1, 0]); } From 7a6d813f36bf28ba90e76342e0a76358a2e317d0 Mon Sep 17 
00:00:00 2001 From: Dan Dore Date: Fri, 3 May 2024 16:18:26 -0700 Subject: [PATCH 07/10] fixes soundness gap in LT32 chip --- alu_u32/src/lt/columns.rs | 3 +++ alu_u32/src/lt/mod.rs | 9 ++++++--- alu_u32/src/lt/stark.rs | 5 +++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/alu_u32/src/lt/columns.rs b/alu_u32/src/lt/columns.rs index 55fa10cc..a87a2499 100644 --- a/alu_u32/src/lt/columns.rs +++ b/alu_u32/src/lt/columns.rs @@ -21,6 +21,9 @@ pub struct Lt32Cols { pub is_lt: T, pub is_lte: T, + + // inverse of input_1[i] - input_2[i] where i is the first byte that differs + pub diff_inv: T, } pub const NUM_LT_COLS: usize = size_of::>(); diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 1177c225..c1348d04 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -119,6 +119,9 @@ impl Lt32Chip { F: PrimeField, { // Set the input columns + debug_assert_eq!(a.0.len(), 4); + debug_assert_eq!(b.0.len(), 4); + debug_assert_eq!(c.0.len(), 4); cols.input_1 = b.transform(F::from_canonical_u8); cols.input_2 = c.transform(F::from_canonical_u8); cols.output = F::from_canonical_u8(a[3]); @@ -133,9 +136,9 @@ impl Lt32Chip { for i in 0..10 { cols.bits[i] = F::from_canonical_u16(z >> i & 1); } - if n < 4 { - cols.byte_flag[n] = F::one(); - } + cols.byte_flag[n] = F::one(); + // b[n] != c[n] always here, so the difference is never zero. 
+ cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse(); } cols.multiplicity = F::one(); } diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 0402b890..48684df7 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -68,6 +68,11 @@ where AB::Expr::from_canonical_u32(256) + local.input_1[i] - local.input_2[i], bit_comp.clone(), ); + // ensure that when the n-th byte flag is set, the n-th bytes are actually different + builder.when(local.byte_flag[i]).assert_eq( + (local.input_1[i] - local.input_2[i]) * local.diff_inv, + AB::Expr::one(), + ); builder.assert_bool(local.byte_flag[i]); } From ee93d2d12dfb286ba7cf615276244dd11bc9073b Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Fri, 3 May 2024 19:29:43 -0700 Subject: [PATCH 08/10] implements constraints for signed inequality instructions --- alu_u32/src/lt/columns.rs | 11 ++- alu_u32/src/lt/mod.rs | 24 +++-- alu_u32/src/lt/stark.rs | 87 ++++++++++++++--- basic/tests/test_prover.rs | 194 ++++++++++++++++++++++++++++++++++++- 4 files changed, 294 insertions(+), 22 deletions(-) diff --git a/alu_u32/src/lt/columns.rs b/alu_u32/src/lt/columns.rs index a87a2499..48bd5ead 100644 --- a/alu_u32/src/lt/columns.rs +++ b/alu_u32/src/lt/columns.rs @@ -13,7 +13,7 @@ pub struct Lt32Cols { pub byte_flag: [T; 4], /// Bit decomposition of 256 + input_1 - input_2 - pub bits: [T; 10], + pub bits: [T; 9], pub output: T, @@ -21,9 +21,18 @@ pub struct Lt32Cols { pub is_lt: T, pub is_lte: T, + pub is_slt: T, + pub is_sle: T, // inverse of input_1[i] - input_2[i] where i is the first byte that differs pub diff_inv: T, + + // bit decomposition of top bytes for input_1 and input_2 + pub top_bits_1: [T; 8], + pub top_bits_2: [T; 8], + + // boolean flag for whether the sign of the two inputs is different + pub different_signs: T, } pub const NUM_LT_COLS: usize = size_of::>(); diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index c1348d04..90c7f831 100644 --- a/alu_u32/src/lt/mod.rs +++ 
b/alu_u32/src/lt/mod.rs @@ -60,6 +60,8 @@ where vec![ (LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)), (LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)), + (LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)), + (LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)), ], SC::Val::zero(), ); @@ -101,13 +103,11 @@ impl Lt32Chip { self.set_cols(cols, a, b, c); } Operation::Slt32(a, b, c) => { - // TODO: this is just a placeholder - cols.is_lt = F::one(); + cols.is_slt = F::one(); self.set_cols(cols, a, b, c); } Operation::Sle32(a, b, c) => { - // TODO: this is just a placeholder - cols.is_lte = F::one(); + cols.is_sle = F::one(); self.set_cols(cols, a, b, c); } } @@ -133,13 +133,25 @@ impl Lt32Chip { .find_map(|(n, (x, y))| if x == y { None } else { Some(n) }) { let z = 256u16 + b[n] as u16 - c[n] as u16; - for i in 0..10 { + for i in 0..9 { cols.bits[i] = F::from_canonical_u16(z >> i & 1); } cols.byte_flag[n] = F::one(); // b[n] != c[n] always here, so the difference is never zero. 
cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse(); } + // compute (little-endian) bit decomposition of the top bytes + for i in 0..8 { + cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1); + cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1); + } + // check if sign bits agree and set different_signs accordingly + cols.different_signs = if cols.top_bits_1[7] != cols.top_bits_2[7] { + F::one() + } else { + F::zero() + }; + cols.multiplicity = F::one(); } @@ -218,7 +230,6 @@ where let opcode = >::OPCODE; let comp = |a, b| a < b; let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); - state .lt_u32_mut() .operations @@ -281,7 +292,6 @@ where a_i <= b_i }; let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); - state .lt_u32_mut() .operations diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 48684df7..2fd58a2e 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -22,7 +22,7 @@ where let main = builder.main(); let local: &Lt32Cols = main.row_slice(0).borrow(); - let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32); + let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256].map(AB::Expr::from_canonical_u32); let bit_comp: AB::Expr = local .bits @@ -76,26 +76,89 @@ where builder.assert_bool(local.byte_flag[i]); } + // Check the bit decomposition of the top bytes: + let top_comp_1: AB::Expr = local + .top_bits_1 + .into_iter() + .zip(base_2.iter().cloned()) + .map(|(bit, base)| bit * base) + .sum(); + let top_comp_2: AB::Expr = local + .top_bits_2 + .into_iter() + .zip(base_2.iter().cloned()) + .map(|(bit, base)| bit * base) + .sum(); + builder.assert_eq(top_comp_1, local.input_1[0]); + builder.assert_eq(top_comp_2, local.input_2[0]); + + // Check that `different_signs` is set correctly by comparing sign bits. 
+ builder + .when(local.byte_flag[0]) + .when_ne(local.top_bits_1[7], local.top_bits_2[7]) + .assert_eq(local.different_signs, AB::Expr::one()); + builder + .when(local.different_signs) + .assert_eq(local.byte_flag[0], AB::Expr::one()); + // local.top_bits_1[7] and local.top_bits_2[7] are boolean; their sum is 1 iff they are unequal. + builder + .when(local.different_signs) + .assert_eq(local.top_bits_1[7] + local.top_bits_2[7], AB::Expr::one()); + builder.assert_bool(local.is_lt); builder.assert_bool(local.is_lte); - builder.assert_bool(local.is_lt + local.is_lte); + builder.assert_bool(local.is_slt); + builder.assert_bool(local.is_sle); + builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle); + + let is_signed = local.is_slt + local.is_sle; + let is_unsigned = AB::Expr::one() - is_signed; + let same_sign = AB::Expr::one() - local.different_signs; + let are_equal = AB::Expr::one() - flag_sum.clone(); // Output constraints - // local.bits[8] is 1 iff input_1 > input_2: output should be 0 - builder.when(local.bits[8]).assert_zero(local.output); - // output should be 1 if is_lte & input_1 == input_2 + // Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1 + // when both inputs have the same sign, signed and unsigned inequality agree. builder - .when(local.is_lte) - .when_ne(flag_sum.clone(), AB::Expr::one()) + .when(local.bits[8]) + .when(is_unsigned.clone() + same_sign.clone()) + .assert_zero(local.output); + // when the inputs have different signs, signed inequality is the opposite of unsigned inequality. + builder + .when(local.bits[8]) + .when(local.different_signs) .assert_one(local.output); - // output should be 0 if is_lt & input_1 == input_2 + + // Case 1: input_1 < input_2 as unsigned ints; equivalently, local.bits[8] == are_equal == 0.
builder - .when(local.is_lt) - .when_ne(flag_sum, AB::Expr::one()) + // when are_equal == 1, we have already enforced that local.bits[8] == 0 + .when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one()) + .when(is_unsigned.clone() + same_sign.clone()) + .assert_one(local.output); + builder + .when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one()) + .when(local.different_signs) .assert_zero(local.output); - // Check bit decomposition - for bit in local.bits.into_iter() { + // Case 2: input_1 == input_2; equivalently, are_equal == 1 + // output should be 1 if is_lte or is_sle + builder + .when(are_equal.clone()) + .when(local.is_lte + local.is_sle) + .assert_one(local.output); + // output should be 0 if is_lt or is_slt + builder + .when(are_equal.clone()) + .when(local.is_lt + local.is_slt) + .assert_zero(local.output); + + // Check "bit" values are all boolean + for bit in local + .bits + .into_iter() + .chain(local.top_bits_1.into_iter()) + .chain(local.top_bits_2.into_iter()) + { builder.assert_bool(bit); } } diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 4c6c672d..58fde79c 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -3,7 +3,7 @@ extern crate core; use p3_baby_bear::BabyBear; use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; -use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction}; +use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction, Sle32Instruction, Slt32Instruction}; use valida_basic::BasicMachine; use valida_cpu::{ BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, @@ -261,6 +261,122 @@ fn left_imm_ops_program() -> Vec() -> Vec> { + let mut program = vec![]; + + // imm32 -4(fp), 0, 0, 0, 1 + // imm32 -8(fp), 255, 255, 255, 255 + // imm32 -12(fp), 255, 255, 255, 254 + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 0, 0, 1]), + }, + 
InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, 255, 255, 255, 255]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-12, 255, 255, 255, 254]), + }, + ]); + + // slt32 4(fp), -12(fp), -8(fp), 0, 0 + // slt32 8(fp), -12(fp), -4(fp), 0, 0 + // slt32 12(fp), -4(fp), -1, 0, 1 + // slt32 16(fp), -1, -8(fp), 1, 0 + // sle32 20(fp), -1, -8(fp), 1, 0 + // slt32 24(fp), -1, -12(fp), 1, 0 + // slt32 28(fp), -8(fp), -12(fp), 0, 0 + // slt32 32(fp), -8(fp), -4(fp), 0, 0 + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([4, -12, -8, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([8, -12, -4, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([12, -4, -1, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([16, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([20, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([24, -1, -12, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([28, -8, -12, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([32, -8, -4, 0, 0]), + }, + ]); + + // lt32 36(fp), -12(fp), -8(fp), 0, 0 + // lt32 40(fp), -12(fp), -4(fp), 0, 0 + // lt32 44(fp), -4(fp), -1, 0, 1 + // lt32 48(fp), -1, -8(fp), 1, 0 + // lte32 52(fp), -1, -8(fp), 1, 0 + // lt32 56(fp), -1, -12(fp), 1, 0 + // lt32 60(fp), -8(fp), -12(fp), 0, 0 + // lt32 64(fp), -8(fp), -4(fp), 0, 0 + // stop + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([36, -12, -8, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([40, -12, -4, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([44, -4, -1, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([48, -1, -8, 1, 0]), + 
}, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([52, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([56, -1, -12, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([60, -8, -12, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([64, -8, -4, 0, 0]), + }, + // stop 0, 0, 0, 0, 0 + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands::default(), + }, + ]); + + program +} + fn prove_program(program: Vec>) -> BasicMachine { let mut machine = BasicMachine::::default(); let rom = ProgramROM::new(program); @@ -351,7 +467,6 @@ fn prove_left_imm_ops() { let program = left_imm_ops_program::(); let machine = prove_program(program); - assert_eq!( *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), Word([0, 0, 0, 0]) // 3 < 3 (false) @@ -393,3 +508,78 @@ fn prove_left_imm_ops() { Word([0, 0, 0, 1]) // 3 <= 256 (false) ); } + +#[test] +fn prove_signed_inequality() { + let program = signed_inequality_program::(); + + let machine = prove_program(program); + + // signed inequalities + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), + Word([0, 0, 0, 1]) // -2 < -1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 8)).unwrap(), + Word([0, 0, 0, 1]) // -2 < 1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 12)).unwrap(), + Word([0, 0, 0, 0]) // 1 < -1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 16)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 20)).unwrap(), + Word([0, 0, 0, 1]) // -1 <= -1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 24)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -2 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 28)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -2 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 32)).unwrap(), + Word([0, 0, 0, 1]) // -1 < 1 
(true) + ); + + // unsigned inequalities + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 36)).unwrap(), + Word([0, 0, 0, 1]) // 0xFFFFFFFE < 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 40)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFE < 1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 44)).unwrap(), + Word([0, 0, 0, 1]) // 1 < 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 48)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFF (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 52)).unwrap(), + Word([0, 0, 0, 1]) // 0xFFFFFFFF <= 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 56)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFE (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 60)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFE (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 64)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 1 (false) + ); +} From aaa207a9654d29a72fa03d632bac600ac3a389f5 Mon Sep 17 00:00:00 2001 From: Dan Dore Date: Fri, 3 May 2024 20:05:03 -0700 Subject: [PATCH 09/10] fix: set 'different_signs' column only for signed instructions --- alu_u32/src/lt/mod.rs | 2 ++ alu_u32/src/lt/stark.rs | 16 ++++++++++------ basic/tests/test_prover.rs | 3 ++- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 90c7f831..33c4b013 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -97,10 +97,12 @@ impl Lt32Chip { Operation::Lt32(a, b, c) => { cols.is_lt = F::one(); self.set_cols(cols, a, b, c); + cols.different_signs = F::zero(); } Operation::Lte32(a, b, c) => { cols.is_lte = F::one(); self.set_cols(cols, a, b, c); + cols.different_signs = F::zero(); } Operation::Slt32(a, b, c) => { cols.is_slt = F::one(); diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 2fd58a2e..c875b711 100644 ---
a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -92,9 +92,18 @@ where builder.assert_eq(top_comp_1, local.input_1[0]); builder.assert_eq(top_comp_2, local.input_2[0]); + let is_signed = local.is_slt + local.is_sle; + let is_unsigned = AB::Expr::one() - is_signed.clone(); + let same_sign = AB::Expr::one() - local.different_signs; + let are_equal = AB::Expr::one() - flag_sum.clone(); + + builder + .when(is_unsigned.clone()) + .assert_zero(local.different_signs); + // Check that `different_signs` is set correctly by comparing sign bits. builder - .when(local.byte_flag[0]) + .when(is_signed.clone()) .when_ne(local.top_bits_1[7], local.top_bits_2[7]) .assert_eq(local.different_signs, AB::Expr::one()); builder @@ -111,11 +120,6 @@ where builder.assert_bool(local.is_sle); builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle); - let is_signed = local.is_slt + local.is_sle; - let is_unsigned = AB::Expr::one() - is_signed; - let same_sign = AB::Expr::one() - local.different_signs; - let are_equal = AB::Expr::one() - flag_sum.clone(); - // Output constraints // Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1 // when both inputs have the same sign, signed and unsigned inequality agree. 
diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 58fde79c..611616af 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -290,6 +290,7 @@ fn signed_inequality_program() -> Vec, Val>>::OPCODE, @@ -370,7 +371,7 @@ fn signed_inequality_program() -> Vec, Val>>::OPCODE, - operands: Operands::default(), + operands: Operands([0, 0, 0, 0, 0]), }, ]); From 5a4bf94af010f8e4dcf1aa87b5777b201c34661a Mon Sep 17 00:00:00 2001 From: Morgan Thomas Date: Sat, 4 May 2024 16:25:12 -0400 Subject: [PATCH 10/10] refactor: logic to set different_signs --- alu_u32/src/lt/mod.rs | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index 33c4b013..05d1e83d 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -96,28 +96,32 @@ impl Lt32Chip { match op { Operation::Lt32(a, b, c) => { cols.is_lt = F::one(); - self.set_cols(cols, a, b, c); - cols.different_signs = F::zero(); + self.set_cols(cols, false, a, b, c); } Operation::Lte32(a, b, c) => { cols.is_lte = F::one(); - self.set_cols(cols, a, b, c); - cols.different_signs = F::zero(); + self.set_cols(cols, false, a, b, c); } Operation::Slt32(a, b, c) => { cols.is_slt = F::one(); - self.set_cols(cols, a, b, c); + self.set_cols(cols, true, a, b, c); } Operation::Sle32(a, b, c) => { cols.is_sle = F::one(); - self.set_cols(cols, a, b, c); + self.set_cols(cols, true, a, b, c); } } row } - fn set_cols(&self, cols: &mut Lt32Cols, a: &Word, b: &Word, c: &Word) - where + fn set_cols( + &self, + cols: &mut Lt32Cols, + is_signed: bool, + a: &Word, + b: &Word, + c: &Word, + ) where F: PrimeField, { // Set the input columns @@ -148,8 +152,12 @@ impl Lt32Chip { cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1); } // check if sign bits agree and set different_signs accordingly - cols.different_signs = if cols.top_bits_1[7] != cols.top_bits_2[7] { - F::one() + cols.different_signs = if is_signed { + if 
cols.top_bits_1[7] != cols.top_bits_2[7] { + F::one() + } else { + F::zero() + } } else { F::zero() };