Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle calls to non-existent contracts in AVM witgen #10862

Merged
merged 4 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/bytecode_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,38 +111,43 @@ void AvmBytecodeTraceBuilder::build_bytecode_hash_columns()
{
// This is the main loop that will generate the bytecode trace
for (auto& contract_bytecode : all_contracts_bytecode) {
FF running_hash = FF::zero();
auto field_encoded_bytecode = encode_bytecode(contract_bytecode.bytecode);
// This size is already based on the number of fields
for (size_t i = 0; i < field_encoded_bytecode.size(); ++i) {
bytecode_hash_trace.push_back(BytecodeHashTraceEntry{
.field_encoded_bytecode = field_encoded_bytecode[i],
.running_hash = running_hash,
.bytecode_field_length_remaining = static_cast<uint16_t>(field_encoded_bytecode.size() - i),
});
// We pair-wise hash the i-th bytecode field with the running hash (which is the output of previous i-1
// round). I.e.
// initially running_hash = 0,
// the first round is running_hash = hash(bytecode[0], running_hash),
// the second round is running_hash = hash(bytecode[1],running_hash), and so on.
running_hash = poseidon2::hash({ field_encoded_bytecode[i], running_hash });
if (contract_bytecode.bytecode.size() == 0) {
vinfo("Excluding non-existent contract from bytecode hash columns...");
} else {
Comment on lines +114 to +116
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only real diff here. We are assuming that if the hinted bytecode is empty, the contract doesn't exist and should therefore be omitted from the bytecode hashing.

FF running_hash = FF::zero();
auto field_encoded_bytecode = encode_bytecode(contract_bytecode.bytecode);
// This size is already based on the number of fields
for (size_t i = 0; i < field_encoded_bytecode.size(); ++i) {
bytecode_hash_trace.push_back(BytecodeHashTraceEntry{
.field_encoded_bytecode = field_encoded_bytecode[i],
.running_hash = running_hash,
.bytecode_field_length_remaining = static_cast<uint16_t>(field_encoded_bytecode.size() - i),
});
// We pair-wise hash the i-th bytecode field with the running hash (which is the output of previous i-1
// round). I.e.
// initially running_hash = 0,
// the first round is running_hash = hash(bytecode[0], running_hash),
// the second round is running_hash = hash(bytecode[1],running_hash), and so on.
running_hash = poseidon2::hash({ field_encoded_bytecode[i], running_hash });
}
// Now running_hash actually contains the bytecode hash
BytecodeHashTraceEntry last_entry;
last_entry.bytecode_field_length_remaining = 0;
last_entry.running_hash = running_hash;
// Assert that the computed bytecode hash is the same as what we received as the hint
ASSERT(running_hash == contract_bytecode.contract_class_id_preimage.public_bytecode_commitment);

last_entry.class_id =
compute_contract_class_id(contract_bytecode.contract_class_id_preimage.artifact_hash,
contract_bytecode.contract_class_id_preimage.private_fn_root,
running_hash);
// Assert that the computed class id is the same as what we received as the hint
ASSERT(last_entry.class_id == contract_bytecode.contract_instance.contract_class_id);

last_entry.contract_address = compute_address_from_instance(contract_bytecode.contract_instance);
// Assert that the computed contract address is the same as what we received as the hint
ASSERT(last_entry.contract_address == contract_bytecode.contract_instance.address);
}
// Now running_hash actually contains the bytecode hash
BytecodeHashTraceEntry last_entry;
last_entry.bytecode_field_length_remaining = 0;
last_entry.running_hash = running_hash;
// Assert that the computed bytecode hash is the same as what we received as the hint
ASSERT(running_hash == contract_bytecode.contract_class_id_preimage.public_bytecode_commitment);

last_entry.class_id = compute_contract_class_id(contract_bytecode.contract_class_id_preimage.artifact_hash,
contract_bytecode.contract_class_id_preimage.private_fn_root,
running_hash);
// Assert that the computed class id is the same as what we received as the hint
ASSERT(last_entry.class_id == contract_bytecode.contract_instance.contract_class_id);

last_entry.contract_address = compute_address_from_instance(contract_bytecode.contract_instance);
// Assert that the computed contract address is the same as what we received as the hint
ASSERT(last_entry.contract_address == contract_bytecode.contract_instance.address);
}
}

Expand Down
3 changes: 2 additions & 1 deletion barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ enum class AvmError : uint32_t {
DUPLICATE_NULLIFIER,
SIDE_EFFECT_LIMIT_REACHED,
OUT_OF_GAS,
STATIC_CALL_ALTERATION
STATIC_CALL_ALTERATION,
NO_BYTECODE_FOUND,
};

} // namespace bb::avm_trace
65 changes: 48 additions & 17 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "barretenberg/vm/aztec_constants.hpp"
#include "barretenberg/vm/constants.hpp"
#include "barretenberg/vm/stats.hpp"
#include "errors.hpp"

#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -444,10 +445,25 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
.da_gas_left = da_gas_allocated_to_enqueued_call,
.internal_return_ptr_stack = {},
};
trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call);
// Find the bytecode based on contract address of the public call request
std::vector<uint8_t> bytecode =
trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership);
std::vector<uint8_t> bytecode;
try {
bytecode =
trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address, check_bytecode_membership);
} catch ([[maybe_unused]] const std::runtime_error& e) {
info("AVM enqueued call exceptionally halted. Error: No bytecode found for enqueued call");
// FIXME: properly handle case when bytecode is not found!
// For now, we add a dummy row in main trace to mutate later.
// Dummy row in main trace to mutate afterwards.
// This error was encountered before any opcodes were executed, but
// we need at least one row in the execution trace to then mutate and say "it halted and consumed all gas!"
trace_builder.op_add(0, 0, 0, 0, OpCode::ADD_8);
trace_builder.handle_exceptional_halt();
return AvmError::NO_BYTECODE_FOUND;
;
}

trace_builder.allocate_gas_for_call(l2_gas_allocated_to_enqueued_call, da_gas_allocated_to_enqueued_call);

Comment on lines -447 to 467
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i COULD have had get_bytecode return [byteode, error]... but this is way simpler...

// Copied version of pc maintained in trace builder. The value of pc is evolving based
// on opcode logic and therefore is not maintained here. However, the next opcode in the execution
Expand All @@ -456,12 +472,16 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
std::stack<uint32_t> debug_counter_stack;
uint32_t counter = 0;
trace_builder.set_call_ptr(context_id);
while ((pc = trace_builder.get_pc()) < bytecode.size()) {
while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) {
auto [inst, parse_error] = Deserialization::parse(bytecode, pc);

// FIXME: properly handle case when an instruction fails parsing
// especially first instruction in bytecode
if (!is_ok(error)) {
info("AVM failed to deserialize bytecode at pc: ", pc);
// FIXME: properly handle case when an instruction fails parsing!
// For now, we add a dummy row in main trace to mutate later.
// This error was encountered before any opcodes were executed, but
// we need at least one row in the execution trace to then mutate and say "it halted and consumed all gas!"
trace_builder.op_add(0, 0, 0, 0, OpCode::ADD_8);
error = parse_error;
break;
Comment on lines -462 to 486
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't tested this yet, but it should work! We can write a test for deserialization failures in a follow-up

}
Expand Down Expand Up @@ -855,12 +875,17 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
std::get<uint16_t>(inst.operands.at(3)),
std::get<uint16_t>(inst.operands.at(4)),
std::get<uint16_t>(inst.operands.at(5)));
// TODO: what if an error is encountered on return or call which have already modified stack?
// We hack it in here the logic to change contract address that we are processing
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
debug_counter_stack.push(counter);
counter = 0;
// If opcode errored, nested call won't happen. Don't retrieve bytecode, etc.
if (is_ok(error)) {
try {
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
}
debug_counter_stack.push(counter);
counter = 0;
}
break;
}
case OpCode::STATICCALL: {
Expand All @@ -870,11 +895,17 @@ AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder,
std::get<uint16_t>(inst.operands.at(3)),
std::get<uint16_t>(inst.operands.at(4)),
std::get<uint16_t>(inst.operands.at(5)));
// We hack it in here the logic to change contract address that we are processing
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/false);
debug_counter_stack.push(counter);
counter = 0;
// If opcode errored, nested call won't happen. Don't retrieve bytecode, etc.
if (is_ok(error)) {
try {
bytecode = trace_builder.get_bytecode(trace_builder.current_ext_call_ctx.contract_address,
/*check_membership=*/true);
} catch ([[maybe_unused]] const std::runtime_error& e) {
error = AvmError::NO_BYTECODE_FOUND;
}
debug_counter_stack.push(counter);
counter = 0;
}
break;
}
case OpCode::RETURN: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ void AvmGasTraceBuilder::constrain_gas_for_halt(bool exceptional_halt,
halting_entry.is_halt_or_first_row_in_nested_call = true;

gas_opcode_lookup_counter[halting_entry.opcode]--;

// clear this flag (in case the CALL opcode itself led to an exception)
next_row_is_first_in_nested_call = false;
}

void AvmGasTraceBuilder::constrain_gas_for_top_level_exceptional_halt(uint32_t l2_gas_allocated,
Expand All @@ -172,6 +175,9 @@ void AvmGasTraceBuilder::constrain_gas_for_top_level_exceptional_halt(uint32_t l
halting_entry.is_halt_or_first_row_in_nested_call = true;

gas_opcode_lookup_counter[halting_entry.opcode]--;

// clear this flag (in case the CALL opcode itself led to an exception)
next_row_is_first_in_nested_call = false;
}

void AvmGasTraceBuilder::finalize(std::vector<AvmFullRow<FF>>& main_trace)
Expand Down
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ std::string to_name(AvmError error)
return "SIDE EFFECT LIMIT REACHED";
case AvmError::OUT_OF_GAS:
return "OUT OF GAS";
case AvmError::NO_BYTECODE_FOUND:
return "NO BYTECODE FOUND";
default:
throw std::runtime_error("Invalid error type");
break;
Expand Down
95 changes: 49 additions & 46 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
// nullifier read hint for the contract address
NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint;

vinfo("contract address: ", contract_address);
vinfo("contract address nullifier: ", contract_address_nullifier);
vinfo("low leaf nullifier: ", nullifier_read_hint.low_leaf_preimage.nullifier);
vinfo("low leaf next nullifier: ", nullifier_read_hint.low_leaf_preimage.next_nullifier);
// If the hinted preimage matches the contract address nullifier, the membership check will prove its existence,
// otherwise the membership check will prove that a low-leaf exists that skips the contract address nullifier.
exists = nullifier_read_hint.low_leaf_preimage.nullifier == contract_address_nullifier;
Expand All @@ -199,18 +203,15 @@ std::vector<uint8_t> AvmTraceBuilder::get_bytecode(const FF contract_address, bo
} else {
// This was a non-membership proof!
// Enforce that the tree access membership checked a low-leaf that skips the contract address nullifier.
// Show that the contract address nullifier meets the non membership conditions (sandwich or max)
ASSERT(contract_address_nullifier < nullifier_read_hint.low_leaf_preimage.nullifier &&
(nullifier_read_hint.low_leaf_preimage.next_nullifier == FF::zero() ||
contract_address_nullifier > nullifier_read_hint.low_leaf_preimage.next_nullifier));
AvmMerkleTreeTraceBuilder::assert_nullifier_non_membership_check(nullifier_read_hint.low_leaf_preimage,
contract_address_nullifier);
}
}

if (exists) {
vinfo("Found bytecode for contract address: ", contract_address);
return bytecode_hint.bytecode;
}
// TODO(dbanks12): handle non-existent bytecode
vinfo("Bytecode not found for contract address: ", contract_address);
throw std::runtime_error("Bytecode not found");
}
Expand Down Expand Up @@ -3774,48 +3775,50 @@ AvmError AvmTraceBuilder::constrain_external_call(OpCode opcode,

pc += Deserialization::get_pc_increment(opcode);

// Save the current gas left in the context before pushing it to stack
// It will be used on RETURN/REVERT/halt to remember how much gas the caller had left.
current_ext_call_ctx.l2_gas_left = gas_trace_builder.get_l2_gas_left();
current_ext_call_ctx.da_gas_left = gas_trace_builder.get_da_gas_left();

// We push the current ext call ctx onto the stack and initialize a new one
current_ext_call_ctx.last_pc = pc;
current_ext_call_ctx.success_offset = resolved_success_offset;
current_ext_call_ctx.tree_snapshot = merkle_tree_trace_builder.get_tree_snapshots();
current_ext_call_ctx.public_data_unique_writes = merkle_tree_trace_builder.get_public_data_unique_writes();
external_call_ctx_stack.emplace(current_ext_call_ctx);

// Ext Ctx setup
std::vector<FF> calldata;
read_slice_from_memory(resolved_args_offset, args_size, calldata);

set_call_ptr(static_cast<uint8_t>(clk));

// Don't try allocating more than the gas that is actually left
const auto l2_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_l2.val), gas_trace_builder.get_l2_gas_left());
const auto da_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_da.val), gas_trace_builder.get_da_gas_left());
current_ext_call_ctx = ExtCallCtx{
.context_id = static_cast<uint8_t>(clk),
.parent_id = current_ext_call_ctx.context_id,
.is_static_call = opcode == OpCode::STATICCALL,
.contract_address = read_addr.val,
.calldata = calldata,
.nested_returndata = {},
.last_pc = 0,
.success_offset = 0,
.start_l2_gas_left = l2_gas_allocated_to_nested_call,
.start_da_gas_left = da_gas_allocated_to_nested_call,
.l2_gas_left = l2_gas_allocated_to_nested_call,
.da_gas_left = da_gas_allocated_to_nested_call,
.internal_return_ptr_stack = {},
.tree_snapshot = {},
};
if (is_ok(error)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just moved into if is ok

// Save the current gas left in the context before pushing it to stack
// It will be used on RETURN/REVERT/halt to remember how much gas the caller had left.
current_ext_call_ctx.l2_gas_left = gas_trace_builder.get_l2_gas_left();
current_ext_call_ctx.da_gas_left = gas_trace_builder.get_da_gas_left();

// We push the current ext call ctx onto the stack and initialize a new one
current_ext_call_ctx.last_pc = pc;
current_ext_call_ctx.success_offset = resolved_success_offset,
current_ext_call_ctx.tree_snapshot = merkle_tree_trace_builder.get_tree_snapshots();
current_ext_call_ctx.public_data_unique_writes = merkle_tree_trace_builder.get_public_data_unique_writes();
external_call_ctx_stack.emplace(current_ext_call_ctx);

// Ext Ctx setup
std::vector<FF> calldata;
read_slice_from_memory(resolved_args_offset, args_size, calldata);

set_call_ptr(static_cast<uint8_t>(clk));

// Don't try allocating more than the gas that is actually left
const auto l2_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_l2.val), gas_trace_builder.get_l2_gas_left());
const auto da_gas_allocated_to_nested_call =
std::min(static_cast<uint32_t>(read_gas_da.val), gas_trace_builder.get_da_gas_left());
current_ext_call_ctx = ExtCallCtx{
.context_id = static_cast<uint8_t>(clk),
.parent_id = current_ext_call_ctx.context_id,
.is_static_call = opcode == OpCode::STATICCALL,
.contract_address = read_addr.val,
.calldata = calldata,
.nested_returndata = {},
.last_pc = 0,
.success_offset = 0,
.start_l2_gas_left = l2_gas_allocated_to_nested_call,
.start_da_gas_left = da_gas_allocated_to_nested_call,
.l2_gas_left = l2_gas_allocated_to_nested_call,
.da_gas_left = da_gas_allocated_to_nested_call,
.internal_return_ptr_stack = {},
.tree_snapshot = {},
};

allocate_gas_for_call(l2_gas_allocated_to_nested_call, da_gas_allocated_to_nested_call);
set_pc(0);
allocate_gas_for_call(l2_gas_allocated_to_nested_call, da_gas_allocated_to_nested_call);
set_pc(0);
}

return error;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,8 @@ contract AvmTest {
dep::aztec::oracle::debug_log::debug_log("pedersen_hash_with_index");
let _ = pedersen_hash_with_index(args_field);
dep::aztec::oracle::debug_log::debug_log("test_get_contract_instance");
// address should match yarn-project/simulator/src/public/fixtures/index.ts's
// MockedAvmTestContractDataSource.otherContractInstance
test_get_contract_instance(AztecAddress::from_field(0x4444));
dep::aztec::oracle::debug_log::debug_log("get_address");
let _ = get_address();
Expand Down
Loading
Loading