Skip to content

Commit

Permalink
Removes program_id parameter from InvokeContext::push().
Browse files Browse the repository at this point in the history
  • Loading branch information
Lichtso committed Oct 12, 2021
1 parent 327dad3 commit 2167377
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 78 deletions.
10 changes: 1 addition & 9 deletions program-runtime/src/instruction_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,22 +594,14 @@ impl InstructionProcessor {
.get(0)
.ok_or(InstructionError::GenericError)?;

let program_id = instruction.program_id(&message.account_keys);

// Verify the calling program hasn't misbehaved
invoke_context.verify_and_update(instruction, account_indices, caller_write_privileges)?;

// clear the return data
invoke_context.set_return_data(Vec::new())?;

// Invoke callee
invoke_context.push(
program_id,
message,
instruction,
program_indices,
Some(account_indices),
)?;
invoke_context.push(message, instruction, program_indices, Some(account_indices))?;

let mut instruction_processor = InstructionProcessor::default();
for (program_id, process_instruction) in invoke_context.get_programs().iter() {
Expand Down
89 changes: 28 additions & 61 deletions runtime/src/message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ impl<'a> ThisInvokeContext<'a> {
impl<'a> InvokeContext for ThisInvokeContext<'a> {
fn push(
&mut self,
key: &Pubkey,
message: &Message,
instruction: &CompiledInstruction,
program_indices: &[usize],
Expand All @@ -126,6 +125,23 @@ impl<'a> InvokeContext for ThisInvokeContext<'a> {
return Err(InstructionError::CallDepth);
}

if let Some(index_of_program_id) = program_indices.last() {
let program_id = &self.accounts[*index_of_program_id].0;
let contains = self
.invoke_stack
.iter()
.any(|frame| frame.program_id() == program_id);
let is_last = if let Some(last_frame) = self.invoke_stack.last() {
last_frame.program_id() == program_id
} else {
false
};
if contains && !is_last {
// Reentrancy not allowed unless caller is calling itself
return Err(InstructionError::ReentrancyNotAllowed);
}
}

if self.invoke_stack.is_empty() {
self.pre_accounts = Vec::with_capacity(instruction.accounts.len());
let mut work = |_unique_index: usize, account_index: usize| {
Expand All @@ -140,20 +156,6 @@ impl<'a> InvokeContext for ThisInvokeContext<'a> {
instruction.visit_each_account(&mut work)?;
}

let contains = self
.invoke_stack
.iter()
.any(|frame| frame.program_id() == Some(key));
let is_last = if let Some(last_frame) = self.invoke_stack.last() {
last_frame.program_id() == Some(key)
} else {
false
};
if contains && !is_last {
// Reentrancy not allowed unless caller is calling itself
return Err(InstructionError::ReentrancyNotAllowed);
}

// Create the KeyedAccounts that will be passed to the program
let demote_program_write_locks = self
.feature_set
Expand Down Expand Up @@ -183,13 +185,9 @@ impl<'a> InvokeContext for ThisInvokeContext<'a> {
)
}))
.collect::<Vec<_>>();
let index_of_program_id = keyed_accounts
.iter()
.take(program_indices.len())
.position(|keyed_account| keyed_account.2 == key)
.unwrap();

self.invoke_stack.push(InvokeContextStackFrame::new(
index_of_program_id,
program_indices.len(),
create_keyed_accounts_unified(keyed_accounts.as_slice()),
));
Ok(())
Expand Down Expand Up @@ -270,7 +268,7 @@ impl<'a> InvokeContext for ThisInvokeContext<'a> {
let program_id = self
.invoke_stack
.last()
.and_then(|frame| frame.program_id())
.map(|frame| frame.program_id())
.ok_or(InstructionError::CallDepth)?;
let rent = &self.rent;
let logger = &self.logger;
Expand Down Expand Up @@ -333,7 +331,7 @@ impl<'a> InvokeContext for ThisInvokeContext<'a> {
fn get_caller(&self) -> Result<&Pubkey, InstructionError> {
self.invoke_stack
.last()
.and_then(|frame| frame.program_id())
.map(|frame| frame.program_id())
.ok_or(InstructionError::CallDepth)
}
fn remove_first_keyed_account(&mut self) -> Result<(), InstructionError> {
Expand Down Expand Up @@ -557,7 +555,7 @@ impl MessageProcessor {

invoke_context.set_instruction_index(instruction_index);
let result = invoke_context
.push(program_id, message, instruction, program_indices, None)
.push(message, instruction, program_indices, None)
.and_then(|_| {
instruction_processor
.process_instruction(&instruction.data, &mut invoke_context)?;
Expand Down Expand Up @@ -701,10 +699,9 @@ mod tests {

// Check call depth increases and has a limit
let mut depth_reached = 0;
for program_id in invoke_stack.iter() {
for _ in 0..invoke_stack.len() {
if Err(InstructionError::CallDepth)
== invoke_context.push(
program_id,
&message,
&message.instructions[0],
&[MAX_DEPTH + depth_reached],
Expand Down Expand Up @@ -804,13 +801,7 @@ mod tests {
&fee_calculator,
);
invoke_context
.push(
&accounts[0].0,
&message,
&message.instructions[0],
&[0],
None,
)
.push(&message, &message.instructions[0], &[0], None)
.unwrap();
assert!(invoke_context
.verify(&message, &message.instructions[0], &[0])
Expand Down Expand Up @@ -1261,13 +1252,7 @@ mod tests {
&fee_calculator,
);
invoke_context
.push(
&caller_program_id,
&message,
&caller_instruction,
&program_indices[..1],
None,
)
.push(&message, &caller_instruction, &program_indices[..1], None)
.unwrap();

// not owned account modified by the caller (before the invoke)
Expand Down Expand Up @@ -1325,13 +1310,7 @@ mod tests {
Instruction::new_with_bincode(callee_program_id, &case.0, metas.clone());
let message = Message::new(&[callee_instruction], None);
invoke_context
.push(
&caller_program_id,
&message,
&caller_instruction,
&program_indices[..1],
None,
)
.push(&message, &caller_instruction, &program_indices[..1], None)
.unwrap();
let caller_write_privileges = message
.account_keys
Expand Down Expand Up @@ -1418,13 +1397,7 @@ mod tests {
&fee_calculator,
);
invoke_context
.push(
&caller_program_id,
&message,
&caller_instruction,
&program_indices,
None,
)
.push(&message, &caller_instruction, &program_indices, None)
.unwrap();

// not owned account modified by the invoker
Expand Down Expand Up @@ -1477,13 +1450,7 @@ mod tests {
Instruction::new_with_bincode(callee_program_id, &case.0, metas.clone());
let message = Message::new(&[callee_instruction.clone()], None);
invoke_context
.push(
&caller_program_id,
&message,
&caller_instruction,
&program_indices,
None,
)
.push(&message, &caller_instruction, &program_indices, None)
.unwrap();
assert_eq!(
InstructionProcessor::native_invoke(
Expand Down
13 changes: 5 additions & 8 deletions sdk/src/process_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ impl<'a> InvokeContextStackFrame<'a> {
}
}

pub fn program_id(&self) -> Option<&Pubkey> {
self.keyed_accounts
.get(self.number_of_program_accounts)
.map(|keyed_account| keyed_account.unsigned_key())
pub fn program_id(&self) -> &Pubkey {
self.keyed_accounts[self.number_of_program_accounts - 1].unsigned_key()
}
}

Expand All @@ -61,7 +59,6 @@ pub trait InvokeContext {
/// Push a stack frame onto the invocation stack
fn push(
&mut self,
key: &Pubkey,
message: &Message,
instruction: &CompiledInstruction,
program_indices: &[usize],
Expand Down Expand Up @@ -482,7 +479,8 @@ impl<'a> MockInvokeContext<'a> {
let number_of_program_accounts = keyed_accounts
.iter()
.position(|keyed_account| keyed_account.unsigned_key() == program_id)
.unwrap_or(0);
.unwrap_or(0)
+ 1;
invoke_context
.invoke_stack
.push(InvokeContextStackFrame::new(
Expand Down Expand Up @@ -511,7 +509,6 @@ pub fn mock_set_sysvar<T: Sysvar>(
impl<'a> InvokeContext for MockInvokeContext<'a> {
fn push(
&mut self,
_key: &Pubkey,
_message: &Message,
_instruction: &CompiledInstruction,
_program_indices: &[usize],
Expand Down Expand Up @@ -548,7 +545,7 @@ impl<'a> InvokeContext for MockInvokeContext<'a> {
fn get_caller(&self) -> Result<&Pubkey, InstructionError> {
self.invoke_stack
.last()
.and_then(|frame| frame.program_id())
.map(|frame| frame.program_id())
.ok_or(InstructionError::CallDepth)
}
fn remove_first_keyed_account(&mut self) -> Result<(), InstructionError> {
Expand Down

0 comments on commit 2167377

Please sign in to comment.