diff --git a/examples/capi.c b/examples/capi.c index 3d098931..785532cf 100644 --- a/examples/capi.c +++ b/examples/capi.c @@ -66,6 +66,14 @@ static void get_balance(struct evm_uint256be* result, *result = balance(context, address); } +static size_t get_code_size(struct evm_context* context, const struct evm_address* address) +{ + printf("EVM-C: CODESIZE @"); + print_address(address); + printf("\n"); + return 0; +} + static size_t get_code(const uint8_t** code, struct evm_context* context, const struct evm_address* address) @@ -121,6 +129,7 @@ static const struct evm_context_fn_table ctx_fn_table = { get_storage, set_storage, get_balance, + get_code_size, get_code, selfdestruct, call, diff --git a/include/evm.h b/include/evm.h index 32f633fe..786f2bf4 100644 --- a/include/evm.h +++ b/include/evm.h @@ -336,14 +336,21 @@ typedef void (*evm_get_balance_fn)(struct evm_uint256be* result, struct evm_context* context, const struct evm_address* address); +/// Get code size callback function. +/// +/// This callback function is used by an EVM to get the size of the code stored +/// in the account at the given address. For accounts not having a code, this +/// function returns 0. +typedef size_t (*evm_get_code_size_fn)(struct evm_context* context, + const struct evm_address* address); + /// Get code callback function. /// /// This callback function is used by an EVM to get the code of a contract of /// given address. /// -/// @param[out] result_code The pointer to the contract code. This argument is -/// optional. If NULL is provided, the host MUST only -/// return the code size. It will be freed by the Client. +/// @param[out] result_code The pointer to the contract code. +/// It will be freed by the Client. /// @param context The pointer to the Host execution context. /// @see ::evm_context. /// @param address The address of the contract. @@ -408,6 +415,7 @@ struct evm_context_fn_table { evm_get_storage_fn get_storage; evm_set_storage_fn set_storage; evm_get_balance_fn get_balance; + evm_get_code_size_fn get_code_size; evm_get_code_fn get_code; evm_selfdestruct_fn selfdestruct; evm_call_fn call; diff --git a/libevmjit/Ext.cpp b/libevmjit/Ext.cpp index c47dcb6d..2ca59629 100644 --- a/libevmjit/Ext.cpp +++ b/libevmjit/Ext.cpp @@ -136,6 +136,23 @@ llvm::Function* getGetBalanceFunc(llvm::Module* _module) return func; } +llvm::Function* getGetCodeSizeFunc(llvm::Module* _module) +{ + static const auto funcName = "evm.codesize"; + auto func = _module->getFunction(funcName); + if (!func) + { + auto addrTy = llvm::IntegerType::get(_module->getContext(), 160); + auto fty = llvm::FunctionType::get( + Type::Size, {Type::EnvPtr, addrTy->getPointerTo()}, false); + func = llvm::Function::Create(fty, llvm::Function::ExternalLinkage, funcName, _module); + func->addAttribute(2, llvm::Attribute::ReadOnly); + func->addAttribute(2, llvm::Attribute::NoAlias); + func->addAttribute(2, llvm::Attribute::NoCapture); + } + return func; +} + llvm::Function* getGetCodeFunc(llvm::Module* _module) { static const auto funcName = "evm.code"; @@ -482,13 +499,12 @@ MemoryRef Ext::extcode(llvm::Value* _address) llvm::Value* Ext::extcodesize(llvm::Value* _address) { - auto func = getGetCodeFunc(getModule()); + auto func = getGetCodeSizeFunc(getModule()); auto addrTy = m_builder.getIntNTy(160); auto address = Endianness::toBE(m_builder, m_builder.CreateTrunc(_address, addrTy)); auto pAddr = m_builder.CreateBitCast(getArgAlloca(), addrTy->getPointerTo()); m_builder.CreateStore(address, pAddr); - auto ignoreCode = llvm::ConstantPointerNull::get(Type::BytePtr->getPointerTo()); - auto size = createCABICall(func, {ignoreCode, getRuntimeManager().getEnvPtr(), pAddr}); + auto size = createCABICall(func, {getRuntimeManager().getEnvPtr(), pAddr}); return m_builder.CreateZExt(size, Type::Word); } diff --git a/libevmjit/JIT.cpp b/libevmjit/JIT.cpp index 70f4bd39..96c67678 100644 --- a/libevmjit/JIT.cpp +++ b/libevmjit/JIT.cpp @@ -249,6 +249,7 @@ class SymbolResolver : public llvm::SectionMemoryManager .Case("evm.sload", reinterpret_cast(jit.host->get_storage)) .Case("evm.sstore", reinterpret_cast(jit.host->set_storage)) .Case("evm.balance", reinterpret_cast(jit.host->get_balance)) + .Case("evm.codesize", reinterpret_cast(jit.host->get_code_size)) .Case("evm.code", reinterpret_cast(jit.host->get_code)) .Case("evm.selfdestruct", reinterpret_cast(jit.host->selfdestruct)) .Case("evm.call", reinterpret_cast(call_v2))