Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Save Brillig execution state in ACVM #3026

Merged
merged 9 commits into from
Oct 12, 2023
154 changes: 98 additions & 56 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,59 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};

pub(super) struct BrilligSolver;
pub(super) enum BrilligSolverStatus {
Finished,
InProgress,
ForeignCallWait(ForeignCallWaitInfo),
}

impl BrilligSolver {
pub(super) fn solve<B: BlackBoxFunctionSolver>(
initial_witness: &mut WitnessMap,
pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> {
vm: VM<'b, B>,
acir_index: usize,
}

impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
/// Evaluates if the Brillig block should be skipped entirely
pub(super) fn should_skip(
witness: &WitnessMap,
brillig: &Brillig,
foreign_call_results: Vec<ForeignCallResult>,
bb_solver: &B,
acir_index: usize,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
// If the predicate is `None`, then we simply return the value 1
) -> Result<bool, OpcodeResolutionError> {
// If the predicate is `None`, the block should never be skipped
// If the predicate is `Some` but we cannot find a value, then we return stalled
let pred_value = match &brillig.predicate {
Some(pred) => get_value(pred, initial_witness),
None => Ok(FieldElement::one()),
}?;
match &brillig.predicate {
Some(pred) => Ok(get_value(pred, witness)?.is_zero()),
None => Ok(false),
}
}

// A zero predicate indicates the oracle should be skipped, and its outputs zeroed.
if pred_value.is_zero() {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
pub(super) fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
}
}
}
Ok(())
}

/// Constructs a solver for a Brillig block given the bytecode and initial
/// witness.
pub(super) fn new(
initial_witness: &mut WitnessMap,
brillig: &'b Brillig,
bb_solver: &'b B,
acir_index: usize,
) -> Result<Self, OpcodeResolutionError> {
// Set input values
let mut input_register_values: Vec<Value> = Vec::new();
let mut input_memory: Vec<Value> = Vec::new();
Expand Down Expand Up @@ -75,80 +105,92 @@ impl BrilligSolver {
}

// Instantiate a Brillig VM given the solved input registers and memory
// along with the Brillig bytecode, and any present foreign call results.
// along with the Brillig bytecode.
let input_registers = Registers::load(input_register_values);
let mut vm = VM::new(
input_registers,
input_memory,
&brillig.bytecode,
foreign_call_results,
bb_solver,
);
let vm = VM::new(input_registers, input_memory, &brillig.bytecode, vec![], bb_solver);
Ok(Self { vm, acir_index })
}

// Run the Brillig VM on these inputs, bytecode, etc!
let vm_status = vm.process_opcodes();
pub(super) fn solve(&mut self) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
let status = self.vm.process_opcodes();
self.handle_vm_status(status)
}

// Check the status of the Brillig VM.
fn handle_vm_status(
&self,
vm_status: VMStatus,
) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
// Check the status of the Brillig VM and return a resolution.
// It may be finished, in-progress, failed, or may be waiting for results of a foreign call.
// Return the "resolution" to the caller who may choose to make subsequent calls
// (when it gets foreign call results for example).
match vm_status {
VMStatus::Finished => {
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, register_value.to_field(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), initial_witness)?;
}
}
}
}
Ok(None)
}
VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"),
VMStatus::Finished => Ok(BrilligSolverStatus::Finished),
VMStatus::InProgress => Ok(BrilligSolverStatus::InProgress),
VMStatus::Failure { message, call_stack } => {
Err(OpcodeResolutionError::BrilligFunctionFailed {
message,
call_stack: call_stack
.iter()
.map(|brillig_index| OpcodeLocation::Brillig {
acir_index,
acir_index: self.acir_index,
brillig_index: *brillig_index,
})
.collect(),
})
}
VMStatus::ForeignCallWait { function, inputs } => {
Ok(Some(ForeignCallWaitInfo { function, inputs }))
Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
}
}
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
pub(super) fn finalize(
self,
witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
// Finish the Brillig execution by writing the outputs to the witness map
let vm_status = self.vm.get_status();
match vm_status {
VMStatus::Finished => {
self.write_brillig_outputs(witness, brillig)?;
Ok(())
}
_ => panic!("Brillig VM has not completed execution"),
}
}

fn write_brillig_outputs(
&self,
witness_map: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
// Write VM execution results into the witness map
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = self.vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
insert_value(witness, register_value.to_field(), witness_map)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &self.vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), witness_map)?;
}
}
}
}
Ok(())
}

pub(super) fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
match self.vm.get_status() {
VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result),
_ => unreachable!("Brillig VM is not waiting for a foreign call"),
}
}
}

/// Encapsulates a request from a Brillig VM process that encounters a [foreign call opcode][acir::brillig_vm::Opcode::ForeignCall]
Expand Down
73 changes: 47 additions & 26 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use acir::{
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{
arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives,
arithmetic::ArithmeticSolver,
brillig::{BrilligSolver, BrilligSolverStatus},
directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::{BlackBoxFunctionSolver, Language};
Expand Down Expand Up @@ -141,9 +143,7 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {

witness_map: WitnessMap,

/// Results of oracles/functions external to brillig like a database read.
// Each element of this vector corresponds to a single foreign call but may contain several values.
foreign_call_results: HashMap<usize, Vec<ForeignCallResult>>,
brillig_solver: Option<BrilligSolver<'a, B>>,
}

impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
Expand All @@ -156,7 +156,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
foreign_call_results: HashMap::default(),
brillig_solver: None,
}
}

Expand Down Expand Up @@ -221,10 +221,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
panic!("ACVM is not expecting a foreign call response as no call was made");
}

// We want to inject the foreign call result into the brillig opcode which initiated the call.
let foreign_call_results =
self.foreign_call_results.entry(self.instruction_pointer).or_default();
foreign_call_results.push(foreign_call_result);
let brillig_solver = self.brillig_solver.as_mut().expect("No active Brillig solver");
brillig_solver.resolve_pending_foreign_call(foreign_call_result);

// Now that the foreign call has been resolved then we can resume execution.
self.status(ACVMStatus::InProgress);
Expand Down Expand Up @@ -260,23 +258,10 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
let foreign_call_results = self
.foreign_call_results
.get(&self.instruction_pointer)
.cloned()
.unwrap_or_default();
match BrilligSolver::solve(
&mut self.witness_map,
brillig,
foreign_call_results,
self.backend,
self.instruction_pointer,
) {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
}
}
Opcode::Brillig(_) => match self.solve_brillig_opcode() {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
res => res.map(|_| ()),
},
};
match resolution {
Ok(()) => {
Expand Down Expand Up @@ -310,6 +295,42 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
}
}
}

fn solve_brillig_opcode(
&mut self,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
let Opcode::Brillig(brillig) = &self.opcodes[self.instruction_pointer] else {
unreachable!("Not executing a Brillig opcode");
};
let witness = &mut self.witness_map;
if BrilligSolver::<B>::should_skip(witness, brillig)? {
BrilligSolver::<B>::zero_out_brillig_outputs(witness, brillig).map(|_| None)
} else {
// If we're resuming execution after resolving a foreign call then
// there will be a cached `BrilligSolver` to avoid recomputation.
let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() {
Some(solver) => solver,
None => {
BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)?
}
};
match solver.solve()? {
BrilligSolverStatus::ForeignCallWait(foreign_call) => {
// Cache the current state of the solver
self.brillig_solver = Some(solver);
Ok(Some(foreign_call))
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
}
BrilligSolverStatus::InProgress => {
unreachable!("Brillig solver still in progress")
}
BrilligSolverStatus::Finished => {
// Write execution outputs
solver.finalize(witness, brillig)?;
Ok(None)
}
}
}
}
}

// Returns the concrete value for a particular witness
Expand Down
22 changes: 17 additions & 5 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
status
}

pub fn get_status(&self) -> VMStatus {
self.status.clone()
}

/// Sets the current status of the VM to Finished (completed execution).
fn finish(&mut self) -> VMStatus {
self.status(VMStatus::Finished)
Expand All @@ -127,6 +131,14 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> {
self.status(VMStatus::ForeignCallWait { function, inputs })
}

pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) {
if self.foreign_call_counter < self.foreign_call_results.len() {
panic!("No unresolved foreign calls");
}
self.foreign_call_results.push(foreign_call_result);
self.status(VMStatus::InProgress);
}
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved

/// Sets the current status of the VM to `fail`.
/// Indicating that the VM encountered a `Trap` Opcode
/// or an invalid state.
Expand Down Expand Up @@ -926,7 +938,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(
vm.resolve_foreign_call(
Value::from(10u128).into(), // Result of doubling 5u128
);

Expand Down Expand Up @@ -987,7 +999,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1060,7 +1072,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult {
vm.resolve_foreign_call(ForeignCallResult {
values: vec![ForeignCallParam::Array(output_string.clone())],
});

Expand Down Expand Up @@ -1122,7 +1134,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down Expand Up @@ -1207,7 +1219,7 @@ mod tests {
);

// Push result we're waiting for
vm.foreign_call_results.push(expected_result.clone().into());
vm.resolve_foreign_call(expected_result.clone().into());

// Resume VM
brillig_execute(&mut vm);
Expand Down
Loading