Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
feat(brilig)!: Multiple foreign call inputs (#367)
Browse files Browse the repository at this point in the history
Co-authored-by: ludamad <[email protected]>
  • Loading branch information
vezenovm and ludamad authored Jun 13, 2023
1 parent c0544a9 commit 78d62b2
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 48 deletions.
2 changes: 1 addition & 1 deletion acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,5 @@ pub struct ForeignCallWaitInfo {
/// An identifier interpreted by the caller process
pub function: String,
/// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process
pub inputs: Vec<Value>,
pub inputs: Vec<Vec<Value>>,
}
27 changes: 15 additions & 12 deletions acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)),
input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
},
],
predicate: None,
Expand Down Expand Up @@ -535,8 +535,9 @@ mod tests {
"Should be waiting for a single input"
);
// As caller of VM, need to resolve foreign calls
let foreign_call_result =
vec![Value::from(foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse())];
let foreign_call_result = vec![Value::from(
foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse(),
)];
// Alter Brillig oracle opcode with foreign call resolution
let brillig: Brillig = foreign_call.resolve(foreign_call_result.into());
let mut next_opcodes_for_solving = vec![Opcode::Brillig(brillig)];
Expand Down Expand Up @@ -610,13 +611,13 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)),
input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
},
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3)),
input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2)),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(3))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(2))],
},
],
predicate: None,
Expand Down Expand Up @@ -669,7 +670,8 @@ mod tests {
"Should be waiting for a single input"
);

let x_plus_y_inverse = foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse();
let x_plus_y_inverse =
foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse();
// Alter Brillig oracle opcode
let brillig: Brillig = foreign_call.resolve(vec![Value::from(x_plus_y_inverse)].into());

Expand All @@ -693,7 +695,8 @@ mod tests {
"Should be waiting for a single input"
);

let i_plus_j_inverse = foreign_call.foreign_call_wait_info.inputs[0].to_field().inverse();
let i_plus_j_inverse =
foreign_call.foreign_call_wait_info.inputs[0][0].to_field().inverse();
assert_ne!(x_plus_y_inverse, i_plus_j_inverse);
// Alter Brillig oracle opcode
let brillig = foreign_call.resolve(vec![Value::from(i_plus_j_inverse)].into());
Expand Down Expand Up @@ -756,8 +759,8 @@ mod tests {
// Oracles are named 'foreign calls' in brillig
brillig_vm::Opcode::ForeignCall {
function: "invert".into(),
destination: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1)),
input: RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0)),
destinations: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(1))],
inputs: vec![RegisterValueOrArray::RegisterIndex(RegisterIndex::from(0))],
},
],
predicate: Some(Expression::default()),
Expand Down
154 changes: 123 additions & 31 deletions brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ pub enum VMStatus {
/// Interpreted by simulator context
function: String,
/// Input values
inputs: Vec<Value>,
/// Each input is a list of values as an input can be either a single value or a memory pointer
inputs: Vec<Vec<Value>>,
},
}

Expand All @@ -44,11 +45,18 @@ pub enum VMStatus {
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct ForeignCallResult {
/// Resolved output values of the foreign call.
pub values: Vec<Value>,
/// Each output is its own list of values as an output can be either a single value or a memory pointer
pub values: Vec<Vec<Value>>,
}

impl From<Vec<Value>> for ForeignCallResult {
fn from(values: Vec<Value>) -> Self {
ForeignCallResult { values: vec![values] }
}
}

impl From<Vec<Vec<Value>>> for ForeignCallResult {
fn from(values: Vec<Vec<Value>>) -> Self {
ForeignCallResult { values }
}
}
Expand Down Expand Up @@ -110,7 +118,7 @@ impl VM {

/// Sets the status of the VM to `ForeignCallWait`.
/// Indicating that the VM is now waiting for a foreign call to be resolved.
fn wait_for_foreign_call(&mut self, function: String, inputs: Vec<Value>) -> VMStatus {
fn wait_for_foreign_call(&mut self, function: String, inputs: Vec<Vec<Value>>) -> VMStatus {
self.status(VMStatus::ForeignCallWait { function, inputs })
}

Expand Down Expand Up @@ -176,7 +184,7 @@ impl VM {
self.fail("return opcode hit, but callstack already empty".to_string())
}
}
Opcode::ForeignCall { function, destination, input } => {
Opcode::ForeignCall { function, destinations, inputs } => {
if self.foreign_call_counter >= self.foreign_call_results.len() {
// When this opcode is called, it is possible that the results of a foreign call are
// not yet known (not enough entries in `foreign_call_results`).
Expand All @@ -185,33 +193,45 @@ impl VM {
// resolved inputs back to the caller. Once the caller pushes to `foreign_call_results`,
// they can then make another call to the VM that starts at this opcode
// but has the necessary results to proceed with execution.
let resolved_inputs = self.get_register_value_or_memory_values(*input);
let resolved_inputs = inputs
.iter()
.map(|input| self.get_register_value_or_memory_values(*input))
.collect::<Vec<_>>();
return self.wait_for_foreign_call(function.clone(), resolved_inputs);
}

let ForeignCallResult { values } =
&self.foreign_call_results[self.foreign_call_counter];
match destination {
RegisterValueOrArray::RegisterIndex(index) => {
assert_eq!(
values.len(),
1,
"Function result size does not match brillig bytecode"
);
self.registers.set(*index, values[0])
}
RegisterValueOrArray::HeapArray(index, size) => {
let destination_value = self.registers.get(*index);
assert_eq!(
values.len(),
*size,
"Function result size does not match brillig bytecode"
);
for (i, value) in values.iter().enumerate() {
self.memory[destination_value.to_usize() + i] = *value;

for (destination, values) in destinations.iter().zip(values) {
match destination {
RegisterValueOrArray::RegisterIndex(index) => {
assert_eq!(
values.len(),
1,
"Function result size does not match brillig bytecode"
);
self.registers.set(*index, values[0])
}
RegisterValueOrArray::HeapArray(index, size) => {
let destination_value = self.registers.get(*index);
assert_eq!(
values.len(),
*size,
"Function result size does not match brillig bytecode"
);
for (i, value) in values.iter().enumerate() {
self.memory[destination_value.to_usize() + i] = *value;
}
}
}
}

// This check must come after resolving the foreign call outputs as `fail` uses a mutable reference
if destinations.len() != values.len() {
self.fail(format!("{} output values were provided as a foreign call result for {} destination slots", values.len(), destinations.len()));
}

self.foreign_call_counter += 1;
self.increment_program_counter()
}
Expand Down Expand Up @@ -804,8 +824,8 @@ mod tests {
// Call foreign function "double" with the input register
Opcode::ForeignCall {
function: "double".into(),
destination: RegisterValueOrArray::RegisterIndex(r_result),
input: RegisterValueOrArray::RegisterIndex(r_input),
destinations: vec![RegisterValueOrArray::RegisterIndex(r_result)],
inputs: vec![RegisterValueOrArray::RegisterIndex(r_input)],
},
];

Expand All @@ -816,13 +836,13 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "double".into(),
inputs: vec![Value::from(5u128)]
inputs: vec![vec![Value::from(5u128)]]
}
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult {
values: vec![Value::from(10u128)], // Result of doubling 5u128
values: vec![vec![Value::from(10u128)]], // Result of doubling 5u128
});

// Resume VM
Expand Down Expand Up @@ -859,8 +879,8 @@ mod tests {
// *output = matrix_2x2_transpose(*input)
Opcode::ForeignCall {
function: "matrix_2x2_transpose".into(),
destination: RegisterValueOrArray::HeapArray(r_output, initial_matrix.len()),
input: RegisterValueOrArray::HeapArray(r_input, initial_matrix.len()),
destinations: vec![RegisterValueOrArray::HeapArray(r_output, initial_matrix.len())],
inputs: vec![RegisterValueOrArray::HeapArray(r_input, initial_matrix.len())],
},
];

Expand All @@ -871,12 +891,84 @@ mod tests {
vm.status,
VMStatus::ForeignCallWait {
function: "matrix_2x2_transpose".into(),
inputs: initial_matrix
inputs: vec![initial_matrix]
}
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] });

// Resume VM
brillig_execute(&mut vm);

// Check that VM finished once resumed
assert_eq!(vm.status, VMStatus::Finished);

// Check result in memory
let result_values = vm.memory[0..4].to_vec();
assert_eq!(result_values, expected_result);

// Ensure the foreign call counter has been incremented
assert_eq!(vm.foreign_call_counter, 1);
}

#[test]
fn foreign_call_opcode_multiple_array_inputs_result() {
let r_input_a = RegisterIndex::from(0);
let r_input_b = RegisterIndex::from(1);
let r_output = RegisterIndex::from(2);

// Define a simple 2x2 matrix in memory
let matrix_a =
vec![Value::from(1u128), Value::from(2u128), Value::from(3u128), Value::from(4u128)];

let matrix_b = vec![
Value::from(10u128),
Value::from(11u128),
Value::from(12u128),
Value::from(13u128),
];

// Transpose of the matrix (but arbitrary for this test, the 'correct value')
let expected_result = vec![
Value::from(34u128),
Value::from(37u128),
Value::from(78u128),
Value::from(85u128),
];

let matrix_mul_program = vec![
// input = 0
Opcode::Const { destination: r_input_a, value: Value::from(0u128) },
// input = 0
Opcode::Const { destination: r_input_b, value: Value::from(4u128) },
// output = 0
Opcode::Const { destination: r_output, value: Value::from(0u128) },
// *output = matrix_2x2_transpose(*input)
Opcode::ForeignCall {
function: "matrix_2x2_transpose".into(),
destinations: vec![RegisterValueOrArray::HeapArray(r_output, matrix_a.len())],
inputs: vec![
RegisterValueOrArray::HeapArray(r_input_a, matrix_a.len()),
RegisterValueOrArray::HeapArray(r_input_b, matrix_b.len()),
],
},
];
let mut initial_memory = matrix_a.clone();
initial_memory.extend(matrix_b.clone());
let mut vm = brillig_execute_and_get_vm(initial_memory, matrix_mul_program);

// Check that VM is waiting
assert_eq!(
vm.status,
VMStatus::ForeignCallWait {
function: "matrix_2x2_transpose".into(),
inputs: vec![matrix_a, matrix_b]
}
);

// Push result we're waiting for
vm.foreign_call_results.push(ForeignCallResult { values: expected_result.clone() });
vm.foreign_call_results.push(ForeignCallResult { values: vec![expected_result.clone()] });

// Resume VM
brillig_execute(&mut vm);
Expand Down
8 changes: 4 additions & 4 deletions brillig_vm/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ pub enum Opcode {
/// Interpreted by caller context, ie this will have different meanings depending on
/// who the caller is.
function: String,
/// Destination register (may be a memory pointer).
destination: RegisterValueOrArray,
/// Input register (may be a memory pointer).
input: RegisterValueOrArray,
/// Destination registers (may be single values or memory pointers).
destinations: Vec<RegisterValueOrArray>,
/// Input registers (may be single values or memory pointers).
inputs: Vec<RegisterValueOrArray>,
},
Mov {
destination: RegisterIndex,
Expand Down

0 comments on commit 78d62b2

Please sign in to comment.