Skip to content

Commit

Permalink
feat(avm)!: revert/rethrow oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
fcarreiro committed Oct 24, 2024
1 parent ac8e6d7 commit 7c45995
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 53 deletions.
74 changes: 50 additions & 24 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,29 +316,11 @@ pub fn brillig_to_avm(
});
}
BrilligOpcode::Trap { revert_data } => {
let bits_needed =
*[bits_needed_for(&revert_data.pointer), bits_needed_for(&revert_data.size)]
.iter()
.max()
.unwrap();
let avm_opcode = match bits_needed {
8 => AvmOpcode::REVERT_8,
16 => AvmOpcode::REVERT_16,
_ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed),
};
avm_instrs.push(AvmInstruction {
opcode: avm_opcode,
indirect: Some(
AddressingModeBuilder::default()
.indirect_operand(&revert_data.pointer)
.build(),
),
operands: vec![
make_operand(bits_needed, &revert_data.pointer.to_usize()),
make_operand(bits_needed, &revert_data.size),
],
..Default::default()
});
generate_revert_instruction(
&mut avm_instrs,
&revert_data.pointer,
&revert_data.size,
);
}
BrilligOpcode::Cast { destination, source, bit_size } => {
handle_cast(&mut avm_instrs, source, destination, *bit_size);
Expand Down Expand Up @@ -418,6 +400,7 @@ fn handle_foreign_call(
}
"avmOpcodeCalldataCopy" => handle_calldata_copy(avm_instrs, destinations, inputs),
"avmOpcodeReturn" => handle_return(avm_instrs, destinations, inputs),
"avmOpcodeRevert" => handle_revert(avm_instrs, destinations, inputs),
"avmOpcodeStorageRead" => handle_storage_read(avm_instrs, destinations, inputs),
"avmOpcodeStorageWrite" => handle_storage_write(avm_instrs, destinations, inputs),
"debugLog" => handle_debug_log(avm_instrs, destinations, inputs),
Expand Down Expand Up @@ -929,6 +912,32 @@ fn generate_cast_instruction(
}
}

/// Generates an AVM REVERT instruction.
fn generate_revert_instruction(
avm_instrs: &mut Vec<AvmInstruction>,
revert_data_pointer: &MemoryAddress,
revert_data_size: &MemoryAddress,
) {
let bits_needed =
*[revert_data_pointer, revert_data_size].map(bits_needed_for).iter().max().unwrap();
let avm_opcode = match bits_needed {
8 => AvmOpcode::REVERT_8,
16 => AvmOpcode::REVERT_16,
_ => panic!("REVERT only support 8 or 16 bit encodings, got: {}", bits_needed),
};
avm_instrs.push(AvmInstruction {
opcode: avm_opcode,
indirect: Some(
AddressingModeBuilder::default().indirect_operand(revert_data_pointer).build(),
),
operands: vec![
make_operand(bits_needed, &revert_data_pointer.to_usize()),
make_operand(bits_needed, &revert_data_size.to_usize()),
],
..Default::default()
});
}

/// Generates an AVM MOV instruction.
fn generate_mov_instruction(
indirect: Option<AvmOperand>,
Expand Down Expand Up @@ -1214,7 +1223,6 @@ fn handle_return(
assert!(inputs.len() == 1);
assert!(destinations.len() == 0);

// First arg is the size, which is ignored because it's redundant.
let (return_data_offset, return_data_size) = match inputs[0] {
ValueOrArray::HeapArray(HeapArray { pointer, size }) => (pointer, size as u32),
_ => panic!("Return instruction's args input should be a HeapArray"),
Expand All @@ -1233,6 +1241,24 @@ fn handle_return(
});
}

// #[oracle(avmOpcodeRevert)]
// unconstrained fn revert_opcode(revertdata: [Field]) {}
fn handle_revert(
avm_instrs: &mut Vec<AvmInstruction>,
destinations: &Vec<ValueOrArray>,
inputs: &Vec<ValueOrArray>,
) {
assert!(inputs.len() == 1);
assert!(destinations.len() == 0);

let (revert_data_offset, revert_data_size_offset) = match inputs[0] {
ValueOrArray::HeapVector(HeapVector { pointer, size }) => (&pointer, &size),
_ => panic!("Revert instruction's args input should be a HeapVector"),
};

generate_revert_instruction(avm_instrs, revert_data_offset, revert_data_size_offset);
}

/// Emit a storage write opcode
/// The current implementation writes an array of values into storage ( contiguous slots in memory )
fn handle_storage_write(
Expand Down
7 changes: 7 additions & 0 deletions noir-projects/aztec-nr/aztec/src/context/public_context.nr
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ unconstrained fn calldata_copy_opcode<let N: u32>(cdoffset: u32, copy_size: u32)
#[oracle(avmOpcodeReturn)]
unconstrained fn return_opcode<let N: u32>(returndata: [Field; N]) {}

// This opcode reverts using the exact data given. In general it should only be used
// to do rethrows, where the revert data is the same as the original revert data.
// For normal reverts, use Noir's `assert` which, on top of reverting, will also add
// an error selector to the revert data.
#[oracle(avmOpcodeRevert)]
unconstrained fn revert_opcode(revertdata: [Field]) {}

#[oracle(avmOpcodeCall)]
unconstrained fn call_opcode<let RET_SIZE: u32>(
gas: [Field; 2], // gas allocation: [l2_gas, da_gas]
Expand Down
14 changes: 12 additions & 2 deletions noir/noir-repo/acvm-repo/acvm/tests/solver.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

use acir::brillig::{BitSize, IntegerBitSize};
use acir::brillig::{BitSize, HeapVector, IntegerBitSize};
use acir::{
acir_field::GenericFieldElement,
brillig::{BinaryFieldOp, HeapArray, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray},
Expand Down Expand Up @@ -667,7 +667,12 @@ fn unsatisfied_opcode_resolved_brillig() {
let jmp_if_opcode =
BrilligOpcode::JumpIf { condition: MemoryAddress::direct(2), location: location_of_stop };

let trap_opcode = BrilligOpcode::Trap { revert_data: HeapArray::default() };
let trap_opcode = BrilligOpcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(3),
},
};
let stop_opcode = BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 0 };

let brillig_bytecode = BrilligBytecode {
Expand All @@ -682,6 +687,11 @@ fn unsatisfied_opcode_resolved_brillig() {
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
BrilligOpcode::Const {
destination: MemoryAddress::direct(3),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
calldata_copy_opcode,
equal_opcode,
jmp_if_opcode,
Expand Down
2 changes: 1 addition & 1 deletion noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ pub enum BrilligOpcode<F> {
BlackBox(BlackBoxOp),
/// Used to denote execution failure, returning data after the offset
Trap {
revert_data: HeapArray,
revert_data: HeapVector,
},
/// Stop execution, returning data after the offset
Stop {
Expand Down
19 changes: 15 additions & 4 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,11 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
self.increment_program_counter()
}
Opcode::Trap { revert_data } => {
if revert_data.size > 0 {
let revert_data_size = self.memory.read(revert_data.size).to_usize();
if revert_data_size > 0 {
self.trap(
self.memory.read_ref(revert_data.pointer).unwrap_direct(),
revert_data.size,
revert_data_size,
)
} else {
self.trap(0, 0)
Expand Down Expand Up @@ -904,8 +905,18 @@ mod tests {
size_address: MemoryAddress::direct(0),
offset_address: MemoryAddress::direct(1),
},
Opcode::Jump { location: 5 },
Opcode::Trap { revert_data: HeapArray::default() },
Opcode::Jump { location: 6 },
Opcode::Const {
destination: MemoryAddress::direct(0),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
},
Opcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(0),
},
},
Opcode::BinaryFieldOp {
op: BinaryFieldOp::Equals,
lhs: MemoryAddress::direct(0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,17 @@ pub(crate) mod tests {
// uses unresolved jumps which requires a block to be constructed in SSA and
// we don't need this for Brillig IR tests
context.push_opcode(BrilligOpcode::JumpIf { condition: r_equality, location: 8 });
context.push_opcode(BrilligOpcode::Trap { revert_data: HeapArray::default() });
context.push_opcode(BrilligOpcode::Const {
destination: MemoryAddress::direct(0),
bit_size: BitSize::Integer(IntegerBitSize::U32),
value: FieldElement::from(0u64),
});
context.push_opcode(BrilligOpcode::Trap {
revert_data: HeapVector {
pointer: MemoryAddress::direct(0),
size: MemoryAddress::direct(0),
},
});

context.stop_instruction();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use acvm::{
acir::brillig::{HeapArray, MemoryAddress},
acir::brillig::{HeapVector, MemoryAddress},
AcirField,
};

Expand Down Expand Up @@ -157,12 +157,12 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
assert!(condition.bit_size == 1);

self.codegen_if_not(condition.address, |ctx| {
let revert_data = HeapArray {
pointer: ctx.allocate_register(),
// + 1 due to the revert data id being the first item returned
size: Self::flattened_tuple_size(&revert_data_types) + 1,
};
ctx.codegen_allocate_immediate_mem(revert_data.pointer, revert_data.size);
// + 1 due to the revert data id being the first item returned
let revert_data_size = Self::flattened_tuple_size(&revert_data_types) + 1;
let revert_data_size_var = ctx.make_usize_constant_instruction(revert_data_size.into());
let revert_data =
HeapVector { pointer: ctx.allocate_register(), size: revert_data_size_var.address };
ctx.codegen_allocate_immediate_mem(revert_data.pointer, revert_data_size);

let current_revert_data_pointer = ctx.allocate_register();
ctx.mov_instruction(current_revert_data_pointer, revert_data.pointer);
Expand Down Expand Up @@ -208,6 +208,7 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
);
}
ctx.trap_instruction(revert_data);
ctx.deallocate_single_addr(revert_data_size_var);
ctx.deallocate_register(revert_data.pointer);
ctx.deallocate_register(current_revert_data_pointer);
});
Expand All @@ -223,7 +224,12 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
assert!(condition.bit_size == 1);

self.codegen_if_not(condition.address, |ctx| {
ctx.trap_instruction(HeapArray::default());
let revert_data_size_var = ctx.make_usize_constant_instruction(F::zero());
ctx.trap_instruction(HeapVector {
pointer: MemoryAddress::direct(0),
size: revert_data_size_var.address,
});
ctx.deallocate_single_addr(revert_data_size_var);
if let Some(assert_message) = assert_message {
ctx.obj.add_assert_message_to_last_opcode(assert_message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl DebugShow {
}

/// Emits a `trap` instruction.
pub(crate) fn trap_instruction(&self, revert_data: HeapArray) {
pub(crate) fn trap_instruction(&self, revert_data: HeapVector) {
debug_println!(self.enable_debug_trace, " TRAP {}", revert_data);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use acvm::{
acir::{
brillig::{
BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapArray, HeapValueType,
BinaryFieldOp, BinaryIntOp, BitSize, BlackBoxOp, HeapValueType, HeapVector,
MemoryAddress, Opcode as BrilligOpcode, ValueOrArray,
},
AcirField,
Expand Down Expand Up @@ -425,7 +425,7 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
self.deallocate_single_addr(offset_var);
}

pub(super) fn trap_instruction(&mut self, revert_data: HeapArray) {
pub(super) fn trap_instruction(&mut self, revert_data: HeapVector) {
self.debug_show.trap_instruction(revert_data);

self.push_opcode(BrilligOpcode::Trap { revert_data });
Expand Down
9 changes: 5 additions & 4 deletions yarn-project/simulator/src/avm/opcodes/external_calls.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ describe('External Calls', () => {
Opcode.REVERT_16, // opcode
0x01, // indirect
...Buffer.from('1234', 'hex'), // returnOffset
...Buffer.from('a234', 'hex'), // retSize
...Buffer.from('a234', 'hex'), // retSizeOffset
]);
const inst = new Revert(/*indirect=*/ 0x01, /*returnOffset=*/ 0x1234, /*retSize=*/ 0xa234).as(
const inst = new Revert(/*indirect=*/ 0x01, /*returnOffset=*/ 0x1234, /*retSizeOffset=*/ 0xa234).as(
Opcode.REVERT_16,
Revert.wireFormat16,
);
Expand All @@ -305,9 +305,10 @@ describe('External Calls', () => {
const returnData = [...'assert message'].flatMap(c => new Field(c.charCodeAt(0)));
returnData.unshift(new Field(0n)); // Prepend an error selector

context.machineState.memory.setSlice(0, returnData);
context.machineState.memory.set(0, new Uint32(returnData.length));
context.machineState.memory.setSlice(10, returnData);

const instruction = new Revert(/*indirect=*/ 0, /*returnOffset=*/ 0, returnData.length);
const instruction = new Revert(/*indirect=*/ 0, /*returnOffset=*/ 10, returnData.length);
await instruction.execute(context);

expect(context.machineState.getHalted()).toBe(true);
Expand Down
14 changes: 8 additions & 6 deletions yarn-project/simulator/src/avm/opcodes/external_calls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,24 @@ export class Revert extends Instruction {
OperandType.UINT16,
];

constructor(private indirect: number, private returnOffset: number, private retSize: number) {
constructor(private indirect: number, private returnOffset: number, private retSizeOffset: number) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memory = context.machineState.memory.track(this.type);
context.machineState.consumeGas(this.gasCost(this.retSize));

const operands = [this.returnOffset];
const operands = [this.returnOffset, this.retSizeOffset];
const addressing = Addressing.fromWire(this.indirect, operands.length);
const [returnOffset] = addressing.resolve(operands, memory);
const [returnOffset, retSizeOffset] = addressing.resolve(operands, memory);

const output = memory.getSlice(returnOffset, this.retSize).map(word => word.toFr());
memory.checkTag(TypeTag.UINT32, retSizeOffset);
const retSize = memory.get(retSizeOffset).toNumber();
context.machineState.consumeGas(this.gasCost(retSize));
const output = memory.getSlice(returnOffset, retSize).map(word => word.toFr());

context.machineState.revert(output);
memory.assert({ reads: this.retSize, addressing });
memory.assert({ reads: retSize + 1, addressing });
}
}

Expand Down

0 comments on commit 7c45995

Please sign in to comment.