-
Notifications
You must be signed in to change notification settings - Fork 219
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: refactor foreign call executors (#6659)
- Loading branch information
1 parent
ecaf63d
commit f0c5fe9
Showing
17 changed files
with
601 additions
and
510 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
use acvm::{ | ||
acir::brillig::{ForeignCallParam, ForeignCallResult}, | ||
pwg::ForeignCallWaitInfo, | ||
AcirField, | ||
}; | ||
use noirc_printable_type::{decode_string_value, ForeignCallError}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use super::{ForeignCall, ForeignCallExecutor}; | ||
|
||
/// This struct represents an oracle mock. It can be used for testing programs that use oracles. | ||
#[derive(Debug, PartialEq, Eq, Clone)] | ||
struct MockedCall<F> { | ||
/// The id of the mock, used to update or remove it | ||
id: usize, | ||
/// The oracle it's mocking | ||
name: String, | ||
/// Optionally match the parameters | ||
params: Option<Vec<ForeignCallParam<F>>>, | ||
/// The parameters with which the mock was last called | ||
last_called_params: Option<Vec<ForeignCallParam<F>>>, | ||
/// The result to return when this mock is called | ||
result: ForeignCallResult<F>, | ||
/// How many times should this mock be called before it is removed | ||
times_left: Option<u64>, | ||
} | ||
|
||
impl<F> MockedCall<F> { | ||
fn new(id: usize, name: String) -> Self { | ||
Self { | ||
id, | ||
name, | ||
params: None, | ||
last_called_params: None, | ||
result: ForeignCallResult { values: vec![] }, | ||
times_left: None, | ||
} | ||
} | ||
} | ||
|
||
impl<F: PartialEq> MockedCall<F> { | ||
fn matches(&self, name: &str, params: &[ForeignCallParam<F>]) -> bool { | ||
self.name == name && (self.params.is_none() || self.params.as_deref() == Some(params)) | ||
} | ||
} | ||
|
||
#[derive(Debug, Default)] | ||
pub(crate) struct MockForeignCallExecutor<F> { | ||
/// Mocks have unique ids used to identify them in Noir, allowing to update or remove them. | ||
last_mock_id: usize, | ||
/// The registered mocks | ||
mocked_responses: Vec<MockedCall<F>>, | ||
} | ||
|
||
impl<F: AcirField> MockForeignCallExecutor<F> { | ||
fn extract_mock_id( | ||
foreign_call_inputs: &[ForeignCallParam<F>], | ||
) -> Result<(usize, &[ForeignCallParam<F>]), ForeignCallError> { | ||
let (id, params) = | ||
foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; | ||
let id = | ||
usize::try_from(id.unwrap_field().try_to_u64().expect("value does not fit into u64")) | ||
.expect("value does not fit into usize"); | ||
Ok((id, params)) | ||
} | ||
|
||
fn find_mock_by_id(&self, id: usize) -> Option<&MockedCall<F>> { | ||
self.mocked_responses.iter().find(|response| response.id == id) | ||
} | ||
|
||
fn find_mock_by_id_mut(&mut self, id: usize) -> Option<&mut MockedCall<F>> { | ||
self.mocked_responses.iter_mut().find(|response| response.id == id) | ||
} | ||
|
||
fn parse_string(param: &ForeignCallParam<F>) -> String { | ||
let fields: Vec<_> = param.fields().to_vec(); | ||
decode_string_value(&fields) | ||
} | ||
} | ||
|
||
impl<F: AcirField + Serialize + for<'a> Deserialize<'a>> ForeignCallExecutor<F> | ||
for MockForeignCallExecutor<F> | ||
{ | ||
fn execute( | ||
&mut self, | ||
foreign_call: &ForeignCallWaitInfo<F>, | ||
) -> Result<ForeignCallResult<F>, ForeignCallError> { | ||
let foreign_call_name = foreign_call.function.as_str(); | ||
match ForeignCall::lookup(foreign_call_name) { | ||
Some(ForeignCall::CreateMock) => { | ||
let mock_oracle_name = Self::parse_string(&foreign_call.inputs[0]); | ||
assert!(ForeignCall::lookup(&mock_oracle_name).is_none()); | ||
let id = self.last_mock_id; | ||
self.mocked_responses.push(MockedCall::new(id, mock_oracle_name)); | ||
self.last_mock_id += 1; | ||
|
||
Ok(F::from(id).into()) | ||
} | ||
Some(ForeignCall::SetMockParams) => { | ||
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; | ||
self.find_mock_by_id_mut(id) | ||
.unwrap_or_else(|| panic!("Unknown mock id {}", id)) | ||
.params = Some(params.to_vec()); | ||
|
||
Ok(ForeignCallResult::default()) | ||
} | ||
Some(ForeignCall::GetMockLastParams) => { | ||
let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; | ||
let mock = | ||
self.find_mock_by_id(id).unwrap_or_else(|| panic!("Unknown mock id {}", id)); | ||
|
||
let last_called_params = mock | ||
.last_called_params | ||
.clone() | ||
.unwrap_or_else(|| panic!("Mock {} was never called", mock.name)); | ||
|
||
Ok(last_called_params.into()) | ||
} | ||
Some(ForeignCall::SetMockReturns) => { | ||
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; | ||
self.find_mock_by_id_mut(id) | ||
.unwrap_or_else(|| panic!("Unknown mock id {}", id)) | ||
.result = ForeignCallResult { values: params.to_vec() }; | ||
|
||
Ok(ForeignCallResult::default()) | ||
} | ||
Some(ForeignCall::SetMockTimes) => { | ||
let (id, params) = Self::extract_mock_id(&foreign_call.inputs)?; | ||
let times = | ||
params[0].unwrap_field().try_to_u64().expect("Invalid bit size of times"); | ||
|
||
self.find_mock_by_id_mut(id) | ||
.unwrap_or_else(|| panic!("Unknown mock id {}", id)) | ||
.times_left = Some(times); | ||
|
||
Ok(ForeignCallResult::default()) | ||
} | ||
Some(ForeignCall::ClearMock) => { | ||
let (id, _) = Self::extract_mock_id(&foreign_call.inputs)?; | ||
self.mocked_responses.retain(|response| response.id != id); | ||
Ok(ForeignCallResult::default()) | ||
} | ||
_ => { | ||
let mock_response_position = self | ||
.mocked_responses | ||
.iter() | ||
.position(|response| response.matches(foreign_call_name, &foreign_call.inputs)); | ||
|
||
if let Some(response_position) = mock_response_position { | ||
// If the program has registered a mocked response to this oracle call then we prefer responding | ||
// with that. | ||
|
||
let mock = self | ||
.mocked_responses | ||
.get_mut(response_position) | ||
.expect("Invalid position of mocked response"); | ||
|
||
mock.last_called_params = Some(foreign_call.inputs.clone()); | ||
|
||
let result = mock.result.values.clone(); | ||
|
||
if let Some(times_left) = &mut mock.times_left { | ||
*times_left -= 1; | ||
if *times_left == 0 { | ||
self.mocked_responses.remove(response_position); | ||
} | ||
} | ||
|
||
Ok(result.into()) | ||
} else { | ||
Err(ForeignCallError::NoHandler(foreign_call_name.to_string())) | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
use std::path::PathBuf; | ||
|
||
use acvm::{acir::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, AcirField}; | ||
use mocker::MockForeignCallExecutor; | ||
use noirc_printable_type::ForeignCallError; | ||
use print::PrintForeignCallExecutor; | ||
use rand::Rng; | ||
use rpc::RPCForeignCallExecutor; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
mod mocker; | ||
mod print; | ||
mod rpc; | ||
|
||
pub trait ForeignCallExecutor<F> { | ||
fn execute( | ||
&mut self, | ||
foreign_call: &ForeignCallWaitInfo<F>, | ||
) -> Result<ForeignCallResult<F>, ForeignCallError>; | ||
} | ||
|
||
/// This enumeration represents the Brillig foreign calls that are natively supported by nargo. | ||
/// After resolution of a foreign call, nargo will restart execution of the ACVM | ||
pub enum ForeignCall { | ||
Print, | ||
CreateMock, | ||
SetMockParams, | ||
GetMockLastParams, | ||
SetMockReturns, | ||
SetMockTimes, | ||
ClearMock, | ||
} | ||
|
||
impl std::fmt::Display for ForeignCall { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
write!(f, "{}", self.name()) | ||
} | ||
} | ||
|
||
impl ForeignCall { | ||
pub(crate) fn name(&self) -> &'static str { | ||
match self { | ||
ForeignCall::Print => "print", | ||
ForeignCall::CreateMock => "create_mock", | ||
ForeignCall::SetMockParams => "set_mock_params", | ||
ForeignCall::GetMockLastParams => "get_mock_last_params", | ||
ForeignCall::SetMockReturns => "set_mock_returns", | ||
ForeignCall::SetMockTimes => "set_mock_times", | ||
ForeignCall::ClearMock => "clear_mock", | ||
} | ||
} | ||
|
||
pub(crate) fn lookup(op_name: &str) -> Option<ForeignCall> { | ||
match op_name { | ||
"print" => Some(ForeignCall::Print), | ||
"create_mock" => Some(ForeignCall::CreateMock), | ||
"set_mock_params" => Some(ForeignCall::SetMockParams), | ||
"get_mock_last_params" => Some(ForeignCall::GetMockLastParams), | ||
"set_mock_returns" => Some(ForeignCall::SetMockReturns), | ||
"set_mock_times" => Some(ForeignCall::SetMockTimes), | ||
"clear_mock" => Some(ForeignCall::ClearMock), | ||
_ => None, | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Default)] | ||
pub struct DefaultForeignCallExecutor<F> { | ||
/// The executor for any [`ForeignCall::Print`] calls. | ||
printer: Option<PrintForeignCallExecutor>, | ||
mocker: MockForeignCallExecutor<F>, | ||
external: Option<RPCForeignCallExecutor>, | ||
} | ||
|
||
impl<F: Default> DefaultForeignCallExecutor<F> { | ||
pub fn new( | ||
show_output: bool, | ||
resolver_url: Option<&str>, | ||
root_path: Option<PathBuf>, | ||
package_name: Option<String>, | ||
) -> Self { | ||
let id = rand::thread_rng().gen(); | ||
let printer = if show_output { Some(PrintForeignCallExecutor) } else { None }; | ||
let external_resolver = resolver_url.map(|resolver_url| { | ||
RPCForeignCallExecutor::new(resolver_url, id, root_path, package_name) | ||
}); | ||
DefaultForeignCallExecutor { | ||
printer, | ||
mocker: MockForeignCallExecutor::default(), | ||
external: external_resolver, | ||
} | ||
} | ||
} | ||
|
||
impl<F: AcirField + Serialize + for<'a> Deserialize<'a>> ForeignCallExecutor<F> | ||
for DefaultForeignCallExecutor<F> | ||
{ | ||
fn execute( | ||
&mut self, | ||
foreign_call: &ForeignCallWaitInfo<F>, | ||
) -> Result<ForeignCallResult<F>, ForeignCallError> { | ||
let foreign_call_name = foreign_call.function.as_str(); | ||
match ForeignCall::lookup(foreign_call_name) { | ||
Some(ForeignCall::Print) => { | ||
if let Some(printer) = &mut self.printer { | ||
printer.execute(foreign_call) | ||
} else { | ||
Ok(ForeignCallResult::default()) | ||
} | ||
} | ||
Some( | ||
ForeignCall::CreateMock | ||
| ForeignCall::SetMockParams | ||
| ForeignCall::GetMockLastParams | ||
| ForeignCall::SetMockReturns | ||
| ForeignCall::SetMockTimes | ||
| ForeignCall::ClearMock, | ||
) => self.mocker.execute(foreign_call), | ||
|
||
None => { | ||
// First check if there's any defined mock responses for this foreign call. | ||
match self.mocker.execute(foreign_call) { | ||
Err(ForeignCallError::NoHandler(_)) => (), | ||
response_or_error => return response_or_error, | ||
}; | ||
|
||
if let Some(external_resolver) = &mut self.external { | ||
// If the user has registered an external resolver then we forward any remaining oracle calls there. | ||
match external_resolver.execute(foreign_call) { | ||
Err(ForeignCallError::NoHandler(_)) => (), | ||
response_or_error => return response_or_error, | ||
}; | ||
} | ||
|
||
// If all executors have no handler for the given foreign call then we cannot | ||
// return a correct response to the ACVM. The best we can do is to return an empty response, | ||
// this allows us to ignore any foreign calls which exist solely to pass information from inside | ||
// the circuit to the environment (e.g. custom logging) as the execution will still be able to progress. | ||
// | ||
// We optimistically return an empty response for all oracle calls as the ACVM will error | ||
// should a response have been required. | ||
Ok(ForeignCallResult::default()) | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
use acvm::{acir::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, AcirField}; | ||
use noirc_printable_type::{ForeignCallError, PrintableValueDisplay}; | ||
|
||
use super::{ForeignCall, ForeignCallExecutor}; | ||
|
||
#[derive(Debug, Default)] | ||
pub(super) struct PrintForeignCallExecutor; | ||
|
||
impl<F: AcirField> ForeignCallExecutor<F> for PrintForeignCallExecutor { | ||
fn execute( | ||
&mut self, | ||
foreign_call: &ForeignCallWaitInfo<F>, | ||
) -> Result<ForeignCallResult<F>, ForeignCallError> { | ||
let foreign_call_name = foreign_call.function.as_str(); | ||
match ForeignCall::lookup(foreign_call_name) { | ||
Some(ForeignCall::Print) => { | ||
let skip_newline = foreign_call.inputs[0].unwrap_field().is_zero(); | ||
|
||
let foreign_call_inputs = foreign_call | ||
.inputs | ||
.split_first() | ||
.ok_or(ForeignCallError::MissingForeignCallInputs)? | ||
.1; | ||
|
||
let display_values: PrintableValueDisplay<F> = foreign_call_inputs.try_into()?; | ||
let display_string = | ||
format!("{display_values}{}", if skip_newline { "" } else { "\n" }); | ||
|
||
print!("{display_string}"); | ||
|
||
Ok(ForeignCallResult::default()) | ||
} | ||
_ => Err(ForeignCallError::NoHandler(foreign_call_name.to_string())), | ||
} | ||
} | ||
} |
Oops, something went wrong.