diff --git a/src/tests/account/ethereum/test_eth_account.cairo b/src/tests/account/ethereum/test_eth_account.cairo index a239de7f6..1f26956e0 100644 --- a/src/tests/account/ethereum/test_eth_account.cairo +++ b/src/tests/account/ethereum/test_eth_account.cairo @@ -316,7 +316,7 @@ fn test_multicall() { to: erc20.contract_address, selector: selectors::transfer, calldata: calldata2.span() }; - // Bundle calls and exeute + // Bundle calls and execute calls.append(call1); calls.append(call2); let ret = account.__execute__(calls); diff --git a/src/tests/account/starknet/test_account.cairo b/src/tests/account/starknet/test_account.cairo index 13dd7f36e..7bfa8e377 100644 --- a/src/tests/account/starknet/test_account.cairo +++ b/src/tests/account/starknet/test_account.cairo @@ -308,7 +308,7 @@ fn test_multicall() { to: erc20.contract_address, selector: selectors::transfer, calldata: calldata2.span() }; - // Bundle calls and exeute + // Bundle calls and execute let calls = array![call1, call2]; let ret = account.__execute__(calls); diff --git a/src/tests/presets.cairo b/src/tests/presets.cairo index 5edfe8b18..87e7cb2f4 100644 --- a/src/tests/presets.cairo +++ b/src/tests/presets.cairo @@ -1,4 +1,4 @@ -// mod test_account; +mod test_account; mod test_erc1155; mod test_erc20; mod test_erc721; diff --git a/src/tests/presets/test_account.cairo b/src/tests/presets/test_account.cairo index 61dc1c803..3aaf59543 100644 --- a/src/tests/presets/test_account.cairo +++ b/src/tests/presets/test_account.cairo @@ -1,68 +1,65 @@ use core::num::traits::Zero; -use openzeppelin::account::AccountComponent::{OwnerAdded, OwnerRemoved}; use openzeppelin::account::interface::ISRC6_ID; use openzeppelin::introspection::interface::ISRC5_ID; use openzeppelin::presets::AccountUpgradeable; +use openzeppelin::presets::interfaces::account::{ + AccountUpgradeableABISafeDispatcher, AccountUpgradeableABISafeDispatcherTrait +}; use openzeppelin::presets::interfaces::{ AccountUpgradeableABIDispatcher, AccountUpgradeableABIDispatcherTrait }; use openzeppelin::tests::account::starknet::common::{ - assert_only_event_owner_added, assert_event_owner_removed -}; -use openzeppelin::tests::account::starknet::common::{ - deploy_erc20, SIGNED_TX_DATA, SignedTransactionData + get_accept_ownership_signature, deploy_erc20, SIGNED_TX_DATA, }; -use openzeppelin::tests::mocks::account_mocks::SnakeAccountMock; -use openzeppelin::tests::upgrades::common::assert_only_event_upgraded; +use openzeppelin::tests::account::starknet::common::{AccountSpyHelpers, SignedTransactionData}; +use openzeppelin::tests::upgrades::common::UpgradeableSpyHelpers; +use openzeppelin::tests::utils::constants::stark::{KEY_PAIR, KEY_PAIR_2}; use openzeppelin::tests::utils::constants::{ - PUBKEY, NEW_PUBKEY, SALT, ZERO, CALLER, RECIPIENT, OTHER, QUERY_OFFSET, QUERY_VERSION, - MIN_TRANSACTION_VERSION, CLASS_HASH_ZERO + SALT, QUERY_OFFSET, QUERY_VERSION, MIN_TRANSACTION_VERSION }; +use openzeppelin::tests::utils::constants::{ZERO, CALLER, RECIPIENT, OTHER, CLASS_HASH_ZERO}; +use openzeppelin::tests::utils::signing::{StarkKeyPair, StarkKeyPairExt}; use openzeppelin::tests::utils; -use openzeppelin::token::erc20::interface::{IERC20DispatcherTrait, IERC20Dispatcher}; +use openzeppelin::token::erc20::interface::IERC20DispatcherTrait; use openzeppelin::utils::selectors; use openzeppelin::utils::serde::SerializedAppend; +use snforge_std::{ + cheat_signature_global, cheat_transaction_version_global, cheat_transaction_hash_global +}; +use snforge_std::{spy_events, test_address, start_cheat_caller_address}; use starknet::account::Call; -use starknet::testing; use starknet::{ContractAddress, ClassHash}; -fn CLASS_HASH() -> felt252 { - AccountUpgradeable::TEST_CLASS_HASH -} - -fn V2_CLASS_HASH() -> ClassHash { - SnakeAccountMock::TEST_CLASS_HASH.try_into().unwrap() -} - // // Setup // -fn setup_dispatcher() -> AccountUpgradeableABIDispatcher { - let calldata = array![PUBKEY]; - let target = utils::deploy(CLASS_HASH(), calldata); - utils::drop_event(target); +fn declare_v2_class() -> ClassHash { + utils::declare_class("SnakeAccountMock").class_hash +} + +fn setup_dispatcher(key_pair: StarkKeyPair) -> (ContractAddress, AccountUpgradeableABIDispatcher) { + let calldata = array![key_pair.public_key]; + let account_address = utils::declare_and_deploy("AccountUpgradeable", calldata); + let dispatcher = AccountUpgradeableABIDispatcher { contract_address: account_address }; - AccountUpgradeableABIDispatcher { contract_address: target } + (account_address, dispatcher) } fn setup_dispatcher_with_data( - data: Option<@SignedTransactionData> -) -> AccountUpgradeableABIDispatcher { - testing::set_version(MIN_TRANSACTION_VERSION); + key_pair: StarkKeyPair, data: SignedTransactionData +) -> (AccountUpgradeableABIDispatcher, felt252) { + let account_class = utils::declare_class("AccountUpgradeable"); + let calldata = array![key_pair.public_key]; + let contract_address = utils::deploy(account_class, calldata); + let account_dispatcher = AccountUpgradeableABIDispatcher { contract_address }; - let mut calldata = array![]; - if data.is_some() { - let data = data.unwrap(); - testing::set_signature(array![*data.r, *data.s].span()); - testing::set_transaction_hash(*data.transaction_hash); - - calldata.append(*data.public_key); - } else { - calldata.append(PUBKEY); - } - let address = utils::deploy(CLASS_HASH(), calldata); - AccountUpgradeableABIDispatcher { contract_address: address } + cheat_signature_global(array![data.r, data.s].span()); + cheat_transaction_hash_global(data.tx_hash); + cheat_transaction_version_global(MIN_TRANSACTION_VERSION); + start_cheat_caller_address(contract_address, ZERO()); + + (account_dispatcher, account_class.class_hash.into()) } // @@ -72,12 +69,15 @@ fn setup_dispatcher_with_data( #[test] fn test_constructor() { let mut state = AccountUpgradeable::contract_state_for_testing(); - AccountUpgradeable::constructor(ref state, PUBKEY); + let mut spy = spy_events(); + let key_pair = KEY_PAIR(); + let account_address = test_address(); + AccountUpgradeable::constructor(ref state, key_pair.public_key); - assert_only_event_owner_added(ZERO(), PUBKEY); + spy.assert_only_event_owner_added(account_address, key_pair.public_key); let public_key = AccountUpgradeable::AccountMixinImpl::get_public_key(@state); - assert_eq!(public_key, PUBKEY); + assert_eq!(public_key, key_pair.public_key); let supports_isrc5 = AccountUpgradeable::AccountMixinImpl::supports_interface(@state, ISRC5_ID); assert!(supports_isrc5); @@ -92,44 +92,66 @@ fn test_constructor() { #[test] fn test_public_key_setter_and_getter() { - let dispatcher = setup_dispatcher(); + let key_pair = KEY_PAIR(); + let (account_address, dispatcher) = setup_dispatcher(key_pair); + let mut spy = spy_events(); - testing::set_contract_address(dispatcher.contract_address); + let new_key_pair = KEY_PAIR_2(); + let signature = get_accept_ownership_signature( + account_address, key_pair.public_key, new_key_pair + ); + start_cheat_caller_address(account_address, account_address); + dispatcher.set_public_key(new_key_pair.public_key, signature); - dispatcher.set_public_key(NEW_PUBKEY, get_accept_ownership_signature()); - let public_key = dispatcher.get_public_key(); - assert_eq!(public_key, NEW_PUBKEY); + assert_eq!(dispatcher.get_public_key(), new_key_pair.public_key); - assert_event_owner_removed(dispatcher.contract_address, PUBKEY); - assert_only_event_owner_added(dispatcher.contract_address, NEW_PUBKEY); + spy.assert_event_owner_removed(dispatcher.contract_address, key_pair.public_key); + spy.assert_only_event_owner_added(dispatcher.contract_address, new_key_pair.public_key); } #[test] fn test_public_key_setter_and_getter_camel() { - let dispatcher = setup_dispatcher(); + let key_pair = KEY_PAIR(); + let (account_address, dispatcher) = setup_dispatcher(key_pair); + let mut spy = spy_events(); - testing::set_contract_address(dispatcher.contract_address); + let new_key_pair = KEY_PAIR_2(); + let signature = get_accept_ownership_signature( + account_address, key_pair.public_key, new_key_pair + ); + start_cheat_caller_address(account_address, account_address); + dispatcher.setPublicKey(new_key_pair.public_key, signature); - dispatcher.setPublicKey(NEW_PUBKEY, get_accept_ownership_signature()); - let public_key = dispatcher.getPublicKey(); - assert_eq!(public_key, NEW_PUBKEY); + assert_eq!(dispatcher.getPublicKey(), new_key_pair.public_key); - assert_event_owner_removed(dispatcher.contract_address, PUBKEY); - assert_only_event_owner_added(dispatcher.contract_address, NEW_PUBKEY); + spy.assert_event_owner_removed(dispatcher.contract_address, key_pair.public_key); + spy.assert_only_event_owner_added(dispatcher.contract_address, new_key_pair.public_key); } #[test] -#[should_panic(expected: ('Account: unauthorized', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: unauthorized',))] fn test_set_public_key_different_account() { - let dispatcher = setup_dispatcher(); - dispatcher.set_public_key(NEW_PUBKEY, get_accept_ownership_signature()); + let key_pair = KEY_PAIR(); + let (account_address, dispatcher) = setup_dispatcher(key_pair); + + let new_key_pair = KEY_PAIR_2(); + let signature = get_accept_ownership_signature( + account_address, key_pair.public_key, new_key_pair + ); + dispatcher.set_public_key(new_key_pair.public_key, signature); } #[test] -#[should_panic(expected: ('Account: unauthorized', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: unauthorized',))] fn test_setPublicKey_different_account() { - let dispatcher = setup_dispatcher(); - dispatcher.setPublicKey(NEW_PUBKEY, get_accept_ownership_signature()); + let key_pair = KEY_PAIR(); + let (account_address, dispatcher) = setup_dispatcher(key_pair); + + let new_key_pair = KEY_PAIR_2(); + let signature = get_accept_ownership_signature( + account_address, key_pair.public_key, new_key_pair + ); + dispatcher.setPublicKey(new_key_pair.public_key, signature); } // @@ -137,16 +159,12 @@ fn test_setPublicKey_different_account() { // fn is_valid_sig_dispatcher() -> (AccountUpgradeableABIDispatcher, felt252, Array) { - let dispatcher = setup_dispatcher(); - - let data = SIGNED_TX_DATA(); - let hash = data.transaction_hash; - let mut signature = array![data.r, data.s]; - - testing::set_contract_address(dispatcher.contract_address); - dispatcher.set_public_key(data.public_key, get_accept_ownership_signature()); + let key_pair = KEY_PAIR(); + let (_, dispatcher) = setup_dispatcher(key_pair); - (dispatcher, hash, signature) + let data = SIGNED_TX_DATA(key_pair); + let signature = array![data.r, data.s]; + (dispatcher, data.tx_hash, signature) } #[test] @@ -159,44 +177,64 @@ fn test_is_valid_signature() { #[test] fn test_is_valid_signature_bad_sig() { - let (dispatcher, hash, _) = is_valid_sig_dispatcher(); + let (dispatcher, tx_hash, _) = is_valid_sig_dispatcher(); + let bad_signature = array!['BAD', 'SIG']; - let bad_signature = array![0x987, 0x564]; - - let is_valid = dispatcher.is_valid_signature(hash, bad_signature.clone()); + let is_valid = dispatcher.is_valid_signature(tx_hash, bad_signature); assert!(is_valid.is_zero(), "Should reject invalid signature"); } +#[test] +fn test_is_valid_signature_invalid_len_sig() { + let (dispatcher, tx_hash, _) = is_valid_sig_dispatcher(); + let invalid_len_sig = array!['INVALID_LEN']; + + let is_valid = dispatcher.is_valid_signature(tx_hash, invalid_len_sig); + assert!(is_valid.is_zero(), "Should reject signature of invalid length"); +} + #[test] fn test_isValidSignature() { - let (dispatcher, hash, signature) = is_valid_sig_dispatcher(); + let (dispatcher, tx_hash, signature) = is_valid_sig_dispatcher(); - let is_valid = dispatcher.isValidSignature(hash, signature); + let is_valid = dispatcher.isValidSignature(tx_hash, signature); assert_eq!(is_valid, starknet::VALIDATED); } #[test] fn test_isValidSignature_bad_sig() { - let (dispatcher, hash, _) = is_valid_sig_dispatcher(); + let (dispatcher, tx_hash, _) = is_valid_sig_dispatcher(); + let bad_signature = array!['BAD', 'SIG']; - let bad_signature = array![0x987, 0x564]; - - let is_valid = dispatcher.isValidSignature(hash, bad_signature); + let is_valid = dispatcher.isValidSignature(tx_hash, bad_signature); assert!(is_valid.is_zero(), "Should reject invalid signature"); } +#[test] +fn test_isValidSignature_invalid_len_sig() { + let (dispatcher, tx_hash, _) = is_valid_sig_dispatcher(); + let invalid_len_sig = array!['INVALID_LEN']; + + let is_valid = dispatcher.isValidSignature(tx_hash, invalid_len_sig); + assert!(is_valid.is_zero(), "Should reject signature of invalid length"); +} + // // supports_interface // #[test] fn test_supports_interface() { - let dispatcher = setup_dispatcher(); + let key_pair = KEY_PAIR(); + let (_, dispatcher) = setup_dispatcher(key_pair); + let supports_isrc5 = dispatcher.supports_interface(ISRC5_ID); assert!(supports_isrc5); + let supports_isrc6 = dispatcher.supports_interface(ISRC6_ID); assert!(supports_isrc6); - let doesnt_support_0x123 = !dispatcher.supports_interface(0x123); + + let doesnt_support_0x123 = !dispatcher.supports_interface('DUMMY_INTERFACE_ID'); assert!(doesnt_support_0x123); } @@ -206,110 +244,120 @@ fn test_supports_interface() { #[test] fn test_validate_deploy() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); // `__validate_deploy__` does not directly use the passed arguments. Their // values are already integrated in the tx hash. The passed arguments in this // testing context are decoupled from the signature and have no effect on the test. - let is_valid = account.__validate_deploy__(CLASS_HASH(), SALT, PUBKEY); + let is_valid = account.__validate_deploy__(class_hash, SALT, key_pair.public_key); assert_eq!(is_valid, starknet::VALIDATED); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_deploy_invalid_signature_data() { - let mut data = SIGNED_TX_DATA(); - data.transaction_hash += 1; - let account = setup_dispatcher_with_data(Option::Some(@data)); + let key_pair = KEY_PAIR(); + let mut data = SIGNED_TX_DATA(key_pair); + data.tx_hash += 1; + let (account, class_hash) = setup_dispatcher_with_data(key_pair, data); - account.__validate_deploy__(CLASS_HASH(), SALT, PUBKEY); + account.__validate_deploy__(class_hash, SALT, key_pair.public_key); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_deploy_invalid_signature_length() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); - let mut signature = array![0x1]; + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); - testing::set_signature(signature.span()); + let invalid_len_sig = array!['INVALID_LEN']; + cheat_signature_global(invalid_len_sig.span()); - account.__validate_deploy__(CLASS_HASH(), SALT, PUBKEY); + account.__validate_deploy__(class_hash, SALT, key_pair.public_key); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_deploy_empty_signature() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); + let empty_sig = array![]; + cheat_signature_global(empty_sig.span()); - testing::set_signature(empty_sig.span()); - account.__validate_deploy__(CLASS_HASH(), SALT, PUBKEY); + account.__validate_deploy__(class_hash, SALT, key_pair.public_key); } #[test] fn test_validate_declare() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); // `__validate_declare__` does not directly use the class_hash argument. Its // value is already integrated in the tx hash. The class_hash argument in this // testing context is decoupled from the signature and has no effect on the test. - let is_valid = account.__validate_declare__(CLASS_HASH()); + let is_valid = account.__validate_declare__(class_hash); assert_eq!(is_valid, starknet::VALIDATED); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_declare_invalid_signature_data() { - let mut data = SIGNED_TX_DATA(); - data.transaction_hash += 1; - let account = setup_dispatcher_with_data(Option::Some(@data)); + let key_pair = KEY_PAIR(); + let mut data = SIGNED_TX_DATA(key_pair); + data.tx_hash += 1; + let (account, class_hash) = setup_dispatcher_with_data(key_pair, data); - account.__validate_declare__(CLASS_HASH()); + account.__validate_declare__(class_hash); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_declare_invalid_signature_length() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); - let mut signature = array![0x1]; + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); - testing::set_signature(signature.span()); + let invalid_len_sig = array!['INVALID_LEN']; + cheat_signature_global(invalid_len_sig.span()); - account.__validate_declare__(CLASS_HASH()); + account.__validate_declare__(class_hash); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_declare_empty_signature() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); - let empty_sig = array![]; + let key_pair = KEY_PAIR(); + let (account, class_hash) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); - testing::set_signature(empty_sig.span()); + let empty_sig = array![]; + cheat_signature_global(empty_sig.span()); - account.__validate_declare__(CLASS_HASH()); + account.__validate_declare__(class_hash); } fn test_execute_with_version(version: Option) { - let data = SIGNED_TX_DATA(); - let account = setup_dispatcher_with_data(Option::Some(@data)); + let key_pair = KEY_PAIR(); + let data = SIGNED_TX_DATA(key_pair); + let (account, _) = setup_dispatcher_with_data(key_pair, data); let erc20 = deploy_erc20(account.contract_address, 1000); // Craft call and add to calls array let amount: u256 = 200; + let recipient = RECIPIENT(); let mut calldata = array![]; - calldata.append_serde(RECIPIENT()); + calldata.append_serde(recipient); calldata.append_serde(amount); let call = Call { to: erc20.contract_address, selector: selectors::transfer, calldata: calldata.span() }; - let mut calls = array![]; - calls.append(call); + let calls = array![call]; // Handle version for test - if version.is_some() { - testing::set_version(version.unwrap()); + if let Option::Some(version) = version { + cheat_transaction_version_global(version) } // Execute @@ -317,7 +365,7 @@ fn test_execute_with_version(version: Option) { // Assert that the transfer was successful assert_eq!(erc20.balance_of(account.contract_address), 800, "Should have remainder"); - assert_eq!(erc20.balance_of(RECIPIENT()), amount, "Should have transferred"); + assert_eq!(erc20.balance_of(recipient), amount, "Should have transferred"); // Test return value let mut call_serialized_retval = *ret.at(0); @@ -327,7 +375,7 @@ fn test_execute_with_version(version: Option) { #[test] fn test_execute() { - test_execute_with_version(Option::None(())); + test_execute_with_version(Option::None); } #[test] @@ -341,7 +389,7 @@ fn test_execute_query_version() { } #[test] -#[should_panic(expected: ('Account: invalid tx version', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid tx version',))] fn test_execute_invalid_query_version() { test_execute_with_version(Option::Some(QUERY_OFFSET)); } @@ -352,34 +400,37 @@ fn test_execute_future_query_version() { } #[test] -#[should_panic(expected: ('Account: invalid tx version', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid tx version',))] fn test_execute_invalid_version() { test_execute_with_version(Option::Some(MIN_TRANSACTION_VERSION - 1)); } #[test] fn test_validate() { - let calls = array![]; - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); + let key_pair = KEY_PAIR(); + let (account, _) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); + let calls = array![]; let is_valid = account.__validate__(calls); assert_eq!(is_valid, starknet::VALIDATED); } #[test] -#[should_panic(expected: ('Account: invalid signature', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid signature',))] fn test_validate_invalid() { - let calls = array![]; - let mut data = SIGNED_TX_DATA(); - data.transaction_hash += 1; - let account = setup_dispatcher_with_data(Option::Some(@data)); + let key_pair = KEY_PAIR(); + let mut data = SIGNED_TX_DATA(key_pair); + data.tx_hash += 1; + let (account, _) = setup_dispatcher_with_data(key_pair, data); + let calls = array![]; account.__validate__(calls); } #[test] fn test_multicall() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); + let key_pair = KEY_PAIR(); + let (account, _) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); let erc20 = deploy_erc20(account.contract_address, 1000); let recipient1 = RECIPIENT(); let recipient2 = OTHER(); @@ -403,7 +454,7 @@ fn test_multicall() { to: erc20.contract_address, selector: selectors::transfer, calldata: calldata2.span() }; - // Bundle calls and exeute + // Bundle calls and execute calls.append(call1); calls.append(call2); let ret = account.__execute__(calls); @@ -426,23 +477,23 @@ fn test_multicall() { #[test] fn test_multicall_zero_calls() { - let account = setup_dispatcher_with_data(Option::Some(@SIGNED_TX_DATA())); - let mut calls = array![]; + let key_pair = KEY_PAIR(); + let (account, _) = setup_dispatcher_with_data(key_pair, SIGNED_TX_DATA(key_pair)); + let calls = array![]; let response = account.__execute__(calls); assert!(response.is_empty()); } #[test] -#[should_panic(expected: ('Account: invalid caller', 'ENTRYPOINT_FAILED'))] +#[should_panic(expected: ('Account: invalid caller',))] fn test_account_called_from_contract() { - let account = setup_dispatcher(); - let calls = array![]; - - testing::set_contract_address(account.contract_address); - testing::set_caller_address(CALLER()); + let key_pair = KEY_PAIR(); + let (account_address, dispatcher) = setup_dispatcher(key_pair); - account.__execute__(calls); + let calls = array![]; + start_cheat_caller_address(account_address, CALLER()); + dispatcher.__execute__(calls); } // @@ -450,89 +501,72 @@ fn test_account_called_from_contract() { // #[test] -#[should_panic(expected: ('Account: unauthorized', 'ENTRYPOINT_FAILED',))] +#[should_panic(expected: ('Account: unauthorized',))] fn test_upgrade_access_control() { - let v1 = setup_dispatcher(); - v1.upgrade(CLASS_HASH_ZERO()); + let key_pair = KEY_PAIR(); + let (_, v1_dispatcher) = setup_dispatcher(key_pair); + + v1_dispatcher.upgrade(CLASS_HASH_ZERO()); } #[test] -#[should_panic(expected: ('Class hash cannot be zero', 'ENTRYPOINT_FAILED',))] +#[should_panic(expected: ('Class hash cannot be zero',))] fn test_upgrade_with_class_hash_zero() { - let v1 = setup_dispatcher(); + let key_pair = KEY_PAIR(); + let (account_address, v1_dispatcher) = setup_dispatcher(key_pair); - set_contract_and_caller(v1.contract_address); - v1.upgrade(CLASS_HASH_ZERO()); + start_cheat_caller_address(account_address, account_address); + v1_dispatcher.upgrade(CLASS_HASH_ZERO()); } #[test] fn test_upgraded_event() { - let v1 = setup_dispatcher(); - let v2_class_hash = V2_CLASS_HASH(); + let key_pair = KEY_PAIR(); + let (account_address, v1_dispatcher) = setup_dispatcher(key_pair); + let mut spy = spy_events(); - set_contract_and_caller(v1.contract_address); - v1.upgrade(v2_class_hash); + let v2_class_hash = declare_v2_class(); + start_cheat_caller_address(account_address, account_address); + v1_dispatcher.upgrade(v2_class_hash); - assert_only_event_upgraded(v1.contract_address, v2_class_hash); + spy.assert_only_event_upgraded(account_address, v2_class_hash); } #[test] -#[should_panic(expected: ('ENTRYPOINT_NOT_FOUND',))] +#[feature("safe_dispatcher")] fn test_v2_missing_camel_selector() { - let v1 = setup_dispatcher(); - let v2_class_hash = V2_CLASS_HASH(); + let key_pair = KEY_PAIR(); + let (account_address, v1_dispatcher) = setup_dispatcher(key_pair); + + let v2_class_hash = declare_v2_class(); + start_cheat_caller_address(account_address, account_address); + v1_dispatcher.upgrade(v2_class_hash); - set_contract_and_caller(v1.contract_address); - v1.upgrade(v2_class_hash); + let safe_dispatcher = AccountUpgradeableABISafeDispatcher { contract_address: account_address }; + let result = safe_dispatcher.getPublicKey(); - let dispatcher = AccountUpgradeableABIDispatcher { contract_address: v1.contract_address }; - dispatcher.getPublicKey(); + utils::assert_entrypoint_not_found_error(result, selector!("getPublicKey"), account_address) } #[test] fn test_state_persists_after_upgrade() { - let v1 = setup_dispatcher(); - let v2_class_hash = V2_CLASS_HASH(); + let key_pair = KEY_PAIR(); + let (account_address, v1_dispatcher) = setup_dispatcher(key_pair); - set_contract_and_caller(v1.contract_address); - let dispatcher = AccountUpgradeableABIDispatcher { contract_address: v1.contract_address }; + let new_key_pair = KEY_PAIR_2(); + let accept_ownership_sig = get_accept_ownership_signature( + account_address, key_pair.public_key, new_key_pair + ); + start_cheat_caller_address(account_address, account_address); + v1_dispatcher.set_public_key(new_key_pair.public_key, accept_ownership_sig); - dispatcher.set_public_key(NEW_PUBKEY, get_accept_ownership_signature()); + let expected_public_key = new_key_pair.public_key; + let camel_public_key = v1_dispatcher.getPublicKey(); + assert_eq!(camel_public_key, expected_public_key); - let camel_public_key = dispatcher.getPublicKey(); - assert_eq!(camel_public_key, NEW_PUBKEY); - - v1.upgrade(v2_class_hash); - let snake_public_key = dispatcher.get_public_key(); - - assert_eq!(snake_public_key, camel_public_key); -} - -// -// Helpers -// + let v2_class_hash = declare_v2_class(); + v1_dispatcher.upgrade(v2_class_hash); + let snake_public_key = v1_dispatcher.get_public_key(); -fn set_contract_and_caller(address: ContractAddress) { - testing::set_contract_address(address); - testing::set_caller_address(address); -} - -fn get_accept_ownership_signature() -> Span { - // 0xecfdac5cd0e60434b672a97ba94520b9acfe629d123a883005e45afa25ccea = - // PoseidonTrait::new() - // .update_with('StarkNet Message') - // .update_with('accept_ownership') - // .update_with(dispatcher.contract_address) - // .update_with(PUBKEY) - // .finalize(); - - // This signature was computed using starknet js sdk from the following values: - // - private_key: '1234' - // - public_key: 0x26da8d11938b76025862be14fdb8b28438827f73e75e86f7bfa38b196951fa7 - // - msg_hash: 0xecfdac5cd0e60434b672a97ba94520b9acfe629d123a883005e45afa25ccea - array![ - 0x379fcc17e39513c19b5d97143e919ec3a5d9f59d4ae80fef83e037bc9275240, - 0x7719ade3541834755ed48a4373a19872c4032937344d832e2743df677b0a43 - ] - .span() + assert_eq!(snake_public_key, expected_public_key); } diff --git a/src/tests/presets/test_erc1155.cairo b/src/tests/presets/test_erc1155.cairo index 57c597612..62ff34fbb 100644 --- a/src/tests/presets/test_erc1155.cairo +++ b/src/tests/presets/test_erc1155.cairo @@ -868,11 +868,9 @@ fn test_v2_missing_camel_selector() { v1.upgrade(v2_class_hash); let safe_dispatcher = IERC1155CamelSafeDispatcher { contract_address: v1.contract_address }; - let panic_data = safe_dispatcher.balanceOf(owner, TOKEN_ID).unwrap_err(); + let result = safe_dispatcher.balanceOf(owner, TOKEN_ID); - utils::assert_entrypoint_not_found_error( - panic_data, selector!("balanceOf"), v1.contract_address - ) + utils::assert_entrypoint_not_found_error(result, selector!("balanceOf"), v1.contract_address) } #[test] diff --git a/src/tests/presets/test_erc20.cairo b/src/tests/presets/test_erc20.cairo index 8d10d7ab8..5630256d9 100644 --- a/src/tests/presets/test_erc20.cairo +++ b/src/tests/presets/test_erc20.cairo @@ -480,11 +480,9 @@ fn test_v2_missing_camel_selector() { let safe_dispatcher = ERC20UpgradeableABISafeDispatcher { contract_address: v1.contract_address }; - let mut panic_data = safe_dispatcher.totalSupply().unwrap_err(); + let result = safe_dispatcher.totalSupply(); - utils::assert_entrypoint_not_found_error( - panic_data, selector!("totalSupply"), v1.contract_address - ) + utils::assert_entrypoint_not_found_error(result, selector!("totalSupply"), v1.contract_address) } #[test] diff --git a/src/tests/presets/test_eth_account.cairo b/src/tests/presets/test_eth_account.cairo index 81ea869ea..c30da1c7e 100644 --- a/src/tests/presets/test_eth_account.cairo +++ b/src/tests/presets/test_eth_account.cairo @@ -499,11 +499,9 @@ fn test_v2_missing_camel_selector() { let safe_dispatcher = EthAccountUpgradeableABISafeDispatcher { contract_address: v1.contract_address }; - let panic_data = safe_dispatcher.getPublicKey().unwrap_err(); + let result = safe_dispatcher.getPublicKey(); - utils::assert_entrypoint_not_found_error( - panic_data, selector!("getPublicKey"), v1.contract_address - ) + utils::assert_entrypoint_not_found_error(result, selector!("getPublicKey"), v1.contract_address) } #[test] diff --git a/src/tests/utils/common.cairo b/src/tests/utils/common.cairo index 853be448b..f453a32fb 100644 --- a/src/tests/utils/common.cairo +++ b/src/tests/utils/common.cairo @@ -1,5 +1,5 @@ use core::to_byte_array::FormatAsByteArray; -use starknet::ContractAddress; +use starknet::{ContractAddress, SyscallResult}; /// Converts panic data into a string (ByteArray). /// @@ -9,7 +9,7 @@ pub fn panic_data_to_byte_array(panic_data: Array) -> ByteArray { let mut panic_data = panic_data.span(); // Remove BYTE_ARRAY_MAGIC from the panic data. - panic_data.pop_front().unwrap(); + panic_data.pop_front().expect('Empty panic data provided'); match Serde::::deserialize(ref panic_data) { Option::Some(string) => string, @@ -36,15 +36,23 @@ pub impl IntoBase16String> of IntoBase16StringTrait { } } -/// Asserts that the panic data is an "Entrypoint not found" error, following the starknet foundry -/// emitted error format. -pub fn assert_entrypoint_not_found_error( - panic_data: Array, selector: felt252, contract_address: ContractAddress +/// Asserts that the syscall result of a call failed with an "Entrypoint not found" error, +/// following the starknet foundry emitted error format. +pub fn assert_entrypoint_not_found_error>( + result: SyscallResult, selector: felt252, contract_address: ContractAddress ) { - let expected_panic_message = format!( - "Entry point selector {} not found in contract {}", - selector.into_base_16_string(), - contract_address.into_base_16_string() - ); - assert!(panic_data_to_byte_array(panic_data) == expected_panic_message); + if let Result::Err(panic_data) = result { + let expected_panic_message = format!( + "Entry point selector {} not found in contract {}", + selector.into_base_16_string(), + contract_address.into_base_16_string() + ); + let actual_panic_message = panic_data_to_byte_array(panic_data); + assert!( + actual_panic_message == expected_panic_message, + "Got unexpected panic message: ${actual_panic_message}" + ); + } else { + panic!("${selector} call was expected to fail, but succeeded"); + } } diff --git a/src/tests/utils/deployment.cairo b/src/tests/utils/deployment.cairo index dae9a2071..2b50ab15f 100644 --- a/src/tests/utils/deployment.cairo +++ b/src/tests/utils/deployment.cairo @@ -1,7 +1,6 @@ use core::starknet::SyscallResultTrait; use openzeppelin::tests::utils::panic_data_to_byte_array; -use snforge_std::{declare, get_class_hash, ContractClass, ContractClassTrait}; -use snforge_std::{start_cheat_caller_address, stop_cheat_caller_address}; +use snforge_std::{ContractClass, ContractClassTrait}; use starknet::ContractAddress; pub fn deploy(contract_class: ContractClass, calldata: Array) -> ContractAddress { @@ -24,23 +23,26 @@ pub fn deploy_at( pub fn deploy_another_at( existing: ContractAddress, target_address: ContractAddress, calldata: Array ) { - let class_hash = get_class_hash(existing); + let class_hash = snforge_std::get_class_hash(existing); let contract_class = ContractClassTrait::new(class_hash); deploy_at(contract_class, target_address, calldata) } pub fn declare_class(contract_name: ByteArray) -> ContractClass { - declare(contract_name).unwrap_syscall() + match snforge_std::declare(contract_name) { + Result::Ok(contract_class) => contract_class, + Result::Err(panic_data) => panic!("{}", panic_data_to_byte_array(panic_data)) + } } pub fn declare_and_deploy(contract_name: ByteArray, calldata: Array) -> ContractAddress { - let contract_class = declare(contract_name).unwrap_syscall(); + let contract_class = declare_class(contract_name); deploy(contract_class, calldata) } pub fn declare_and_deploy_at( contract_name: ByteArray, target_address: ContractAddress, calldata: Array ) { - let contract_class = declare(contract_name).unwrap_syscall(); + let contract_class = declare_class(contract_name); deploy_at(contract_class, target_address, calldata) }