Skip to content

Commit

Permalink
[KGA-26] [KGA-17] feat: add checks on Starknet return values for Dual…
Browse files Browse the repository at this point in the history
…VMToken (#1616)

code-423n4/2024-09-kakarot-findings#51
code-423n4/2024-09-kakarot-findings#42

Checks the return value for DualVMToken. If the SN call returned false,
panicking the EVM call will make the starknet tx revert as a side
effect.

---------

Co-authored-by: Clément Walter <[email protected]>
  • Loading branch information
enitrat and ClementWalter authored Nov 20, 2024
1 parent ae79059 commit 9b4a593
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 35 deletions.
1 change: 1 addition & 0 deletions cairo/token/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod starknet_token;
mod non_standard_starknet_token;
131 changes: 131 additions & 0 deletions cairo/token/src/non_standard_starknet_token.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//! A non-standard implementation of the ERC20 Token on Starknet.
//! Used for testing purposes with the DualVMToken.
//! Applied the following changes:
//! - `transfer` and `transfer_from` always return false
//! - `approve` always returns false
//! - `name`, `symbol` return a felt instead of a ByteArray

#[starknet::interface]
trait IERC20FeltMetadata<TState> {
fn name(self: @TState) -> felt252;
fn symbol(self: @TState) -> felt252;
fn decimals(self: @TState) -> u8;
}


#[starknet::contract]
mod NonStandardStarknetToken {
use openzeppelin::token::erc20::ERC20Component;
use openzeppelin::token::erc20::interface::{IERC20};
use super::IERC20FeltMetadata;
use starknet::ContractAddress;

component!(path: ERC20Component, storage: erc20, event: ERC20Event);

impl ERC20InternalImpl = ERC20Component::InternalImpl<ContractState>;

#[storage]
struct Storage {
#[substorage(v0)]
erc20: ERC20Component::Storage,
decimals: u8,
name: felt252,
symbol: felt252,
}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
#[flat]
ERC20Event: ERC20Component::Event
}

#[constructor]
fn constructor(
ref self: ContractState,
name: felt252,
symbol: felt252,
decimals: u8,
initial_supply: u256,
recipient: ContractAddress
) {
self._set_decimals(decimals);

// ERC20 initialization
self.name.write(name);
self.symbol.write(symbol);
self.erc20._mint(recipient, initial_supply);
}

#[external(v0)]
fn mint(ref self: ContractState, to: ContractAddress, amount: u256) {
self.erc20._mint(to, amount);
}

#[abi(embed_v0)]
impl ERC20MetadataImpl of IERC20FeltMetadata<ContractState> {
fn name(self: @ContractState) -> felt252 {
self.name.read()
}

fn symbol(self: @ContractState) -> felt252 {
self.symbol.read()
}

fn decimals(self: @ContractState) -> u8 {
self.decimals.read()
}
}

#[abi(embed_v0)]
impl ERC20 of IERC20<ContractState> {
/// Returns the value of tokens in existence.
fn total_supply(self: @ContractState) -> u256 {
self.erc20.ERC20_total_supply.read()
}

/// Returns the amount of tokens owned by `account`.
fn balance_of(self: @ContractState, account: ContractAddress) -> u256 {
self.erc20.ERC20_balances.read(account)
}

/// Returns the remaining number of tokens that `spender` is
/// allowed to spend on behalf of `owner` through `transfer_from`.
/// This is zero by default.
/// This value changes when `approve` or `transfer_from` are called.
fn allowance(
self: @ContractState, owner: ContractAddress, spender: ContractAddress
) -> u256 {
self.erc20.ERC20_allowances.read((owner, spender))
}


/// Modified to always return false
fn transfer(ref self: ContractState, recipient: ContractAddress, amount: u256) -> bool {
false
}


/// Modified to always return false
fn transfer_from(
ref self: ContractState,
sender: ContractAddress,
recipient: ContractAddress,
amount: u256
) -> bool {
false
}

/// Modified to always return false
fn approve(ref self: ContractState, spender: ContractAddress, amount: u256) -> bool {
false
}
}

#[generate_trait]
impl InternalImpl of InternalTrait {
fn _set_decimals(ref self: ContractState, decimals: u8) {
self.decimals.write(decimals);
}
}
}
9 changes: 8 additions & 1 deletion cairo/token/src/starknet_token.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ mod StarknetToken {
}

#[constructor]
fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray, decimals: u8, initial_supply: u256, recipient: ContractAddress) {
fn constructor(
ref self: ContractState,
name: ByteArray,
symbol: ByteArray,
decimals: u8,
initial_supply: u256,
recipient: ContractAddress
) {
self._set_decimals(decimals);

// ERC20 initialization
Expand Down
34 changes: 31 additions & 3 deletions kakarot_scripts/compile_kakarot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# %% Imports
import logging
import multiprocessing as mp
import re
from datetime import datetime

from kakarot_scripts.constants import COMPILED_CONTRACTS, DECLARED_CONTRACTS, NETWORK
from kakarot_scripts.constants import (
CAIRO_DIR,
COMPILED_CONTRACTS,
CONTRACTS,
DECLARED_CONTRACTS,
NETWORK,
)
from kakarot_scripts.utils.starknet import (
compile_contract,
compile_cairo_zero_contract,
compile_scarb_package,
compute_deployed_class_hash,
dump_class_hashes,
locate_scarb_root,
)

mp.set_start_method("fork")
Expand All @@ -22,8 +31,27 @@ def main():
# %% Compile
logger.info(f"ℹ️ Compiling contracts for network {NETWORK['name']}")
initial_time = datetime.now()

# Split contracts into Cairo 0 and Cairo 1 to avoid
# re-compiling the same package multiple times.
cairo0_contracts = []
cairo1_packages = set()

for contract in COMPILED_CONTRACTS:
contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get(
re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower()
)
if contract_path.is_relative_to(CAIRO_DIR):
cairo1_packages.add(locate_scarb_root(contract_path))
else:
cairo0_contracts.append(contract)

with mp.Pool() as pool:
pool.map(compile_contract, COMPILED_CONTRACTS)
cairo0_task = pool.map_async(compile_cairo_zero_contract, cairo0_contracts)
cairo1_task = pool.map_async(compile_scarb_package, cairo1_packages)

cairo0_task.wait()
cairo1_task.wait()
logger.info("ℹ️ Computing deployed class hashes")
with mp.Pool() as pool:
class_hashes = pool.map(compute_deployed_class_hash, DECLARED_CONTRACTS)
Expand Down
2 changes: 2 additions & 0 deletions kakarot_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class ChainId(IntEnum):
{"contract_name": "OpenzeppelinAccount", "is_account_contract": True},
{"contract_name": "replace_class", "is_account_contract": False},
{"contract_name": "StarknetToken", "is_account_contract": False},
{"contract_name": "NonStandardStarknetToken", "is_account_contract": False},
{"contract_name": "uninitialized_account_fixture", "is_account_contract": False},
{"contract_name": "uninitialized_account", "is_account_contract": False},
{"contract_name": "UniversalLibraryCaller", "is_account_contract": False},
Expand All @@ -244,6 +245,7 @@ class ChainId(IntEnum):
"OpenzeppelinAccount",
"replace_class",
"StarknetToken",
"NonStandardStarknetToken",
"uninitialized_account_fixture",
"uninitialized_account",
"UniversalLibraryCaller",
Expand Down
72 changes: 44 additions & 28 deletions kakarot_scripts/utils/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,40 +341,56 @@ def get_tx_url(tx_hash: int) -> str:
return f"{NETWORK['explorer_url']}/tx/0x{tx_hash:064x}"


def compile_contract(contract):
def locate_scarb_root(contract_path):
current_dir = contract_path.parent
while current_dir != current_dir.parent:
scarb_toml = current_dir / "Scarb.toml"
if scarb_toml.exists():
return current_dir
current_dir = current_dir.parent
return None


def compile_scarb_package(package_path):
logger.info(f"ℹ️ Compiling package {package_path}")
start = datetime.now()
output = subprocess.run(
"scarb build", shell=True, cwd=package_path, capture_output=True
)
if output.returncode != 0:
raise RuntimeError(
f"❌ {package_path} raised:\n{output.stderr}.\nOutput:\n{output.stdout}"
)

elapsed = datetime.now() - start
logger.info(f"✅ {package_path} compiled in {elapsed.total_seconds():.2f}s")


def compile_cairo_zero_contract(contract):
logger.info(f"⏳ Compiling {contract['contract_name']}")
start = datetime.now()
contract_path = CONTRACTS.get(contract["contract_name"]) or CONTRACTS.get(
re.sub("(?!^)([A-Z]+)", r"_\1", contract["contract_name"]).lower()
)

if contract_path.is_relative_to(CAIRO_DIR):
output = subprocess.run(
"scarb build", shell=True, cwd=contract_path.parent, capture_output=True
)
else:
output = subprocess.run(
[
"starknet-compile-deprecated",
contract_path,
"--output",
BUILD_DIR / f"{contract['contract_name']}.json",
"--cairo_path",
str(CAIRO_ZERO_DIR),
*(
["--no_debug_info"]
if NETWORK["type"] is not NetworkType.DEV
else []
),
*(["--account_contract"] if contract["is_account_contract"] else []),
*(
["--disable_hint_validation"]
if NETWORK["type"] is NetworkType.DEV
else []
),
],
capture_output=True,
)
output = subprocess.run(
[
"starknet-compile-deprecated",
contract_path,
"--output",
BUILD_DIR / f"{contract['contract_name']}.json",
"--cairo_path",
str(CAIRO_ZERO_DIR),
*(["--no_debug_info"] if NETWORK["type"] is not NetworkType.DEV else []),
*(["--account_contract"] if contract["is_account_contract"] else []),
*(
["--disable_hint_validation"]
if NETWORK["type"] is NetworkType.DEV
else []
),
],
capture_output=True,
)

if output.returncode != 0:
raise RuntimeError(
Expand Down
17 changes: 14 additions & 3 deletions solidity_contracts/src/CairoPrecompiles/DualVmToken.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ contract DualVmToken is NoDelegateCall {

/// @dev Emitted when an invalid starknet address is used
error InvalidStarknetAddress();
/// @dev Emitted when the return value of a starknet transfer is `false`.
error TransferFailed();
/// @dev Emitted when the return value of a starknet approval is `false`.
error ApprovalFailed();

/*//////////////////////////////////////////////////////////////
METADATA ACCESS
Expand Down Expand Up @@ -244,7 +248,10 @@ contract DualVmToken is NoDelegateCall {
approveCallData[1] = uint256(amountLow);
approveCallData[2] = uint256(amountHigh);

starknetToken.delegatecallCairo("approve", approveCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("approve", approveCallData), (bool));
if (!success) {
revert ApprovalFailed();
}
}

/// @dev Transfer tokens to an evm account
Expand Down Expand Up @@ -285,7 +292,10 @@ contract DualVmToken is NoDelegateCall {
transferCallData[1] = uint256(amountLow);
transferCallData[2] = uint256(amountHigh);

starknetToken.delegatecallCairo("transfer", transferCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("transfer", transferCallData), (bool));
if (!success) {
revert TransferFailed();
}
}

/// @dev Transfer tokens from one evm address to another
Expand Down Expand Up @@ -369,6 +379,7 @@ contract DualVmToken is NoDelegateCall {
transferFromCallData[2] = uint256(amountLow);
transferFromCallData[3] = uint256(amountHigh);

starknetToken.delegatecallCairo("transfer_from", transferFromCallData);
bool success = abi.decode(starknetToken.delegatecallCairo("transfer_from", transferFromCallData), (bool));
require(success, "Transfer failed");
}
}
Loading

0 comments on commit 9b4a593

Please sign in to comment.