diff --git a/src/backend/src/lib.rs b/src/backend/src/lib.rs index 24c958aa1c..d3c58afc39 100644 --- a/src/backend/src/lib.rs +++ b/src/backend/src/lib.rs @@ -1,4 +1,5 @@ use crate::guards::{caller_is_allowed, caller_is_not_anonymous}; +use crate::token::{add_to_user_token, remove_from_user_token}; use candid::{CandidType, Deserialize, Nat, Principal}; use core::ops::Deref; use ethers_core::abi::ethereum_types::{Address, H160, U256, U64}; @@ -25,6 +26,7 @@ use std::cell::RefCell; use std::str::FromStr; mod guards; +mod token; type VMem = VirtualMemory; type ConfigCell = StableCell>, VMem>; @@ -34,7 +36,6 @@ const CONFIG_MEMORY_ID: MemoryId = MemoryId::new(0); const USER_TOKEN_MEMORY_ID: MemoryId = MemoryId::new(1); const MAX_SYMBOL_LENGTH: usize = 20; -const MAX_TOKEN_LIST_LENGTH: usize = 100; thread_local! { static MEMORY_MANAGER: RefCell> = RefCell::new( @@ -366,42 +367,23 @@ fn add_user_token(token: Token) { } } let stored_principal = StoredPrincipal(ic_cdk::caller()); - mutate_state(|s| { - let Candid(mut tokens) = s.user_token.get(&stored_principal).unwrap_or_default(); - match tokens.iter().position(|t| { - t.chain_id == token.chain_id && parse_eth_address(&t.contract_address) == addr - }) { - Some(p) => { - tokens[p] = token; - } - None => { - if tokens.len() == MAX_TOKEN_LIST_LENGTH { - ic_cdk::trap(&format!( - "Token list length should not exceed {MAX_TOKEN_LIST_LENGTH}" - )); - } - tokens.push(token); - } - } - s.user_token.insert(stored_principal, Candid(tokens)) - }); + + let find = + |t: &Token| t.chain_id == token.chain_id && parse_eth_address(&t.contract_address) == addr; + + mutate_state(|s| add_to_user_token(stored_principal, &mut s.user_token, &token, &find)); } #[update(guard = "caller_is_not_anonymous")] fn remove_user_token(token_id: TokenId) { let addr = parse_eth_address(&token_id.contract_address); let stored_principal = StoredPrincipal(ic_cdk::caller()); - mutate_state(|s| match s.user_token.get(&stored_principal) { - None => (), - Some(Candid(mut tokens)) => { - if let Some(p) = tokens.iter().position(|t| { - t.chain_id == token_id.chain_id && parse_eth_address(&t.contract_address) == addr - }) { - tokens.swap_remove(p); - s.user_token.insert(stored_principal, Candid(tokens)); - } - } - }); + + let find = |t: &Token| { + t.chain_id == token_id.chain_id && parse_eth_address(&t.contract_address) == addr + }; + + mutate_state(|s| remove_from_user_token(stored_principal, &mut s.user_token, &find)); } #[query(guard = "caller_is_not_anonymous")] diff --git a/src/backend/src/token.rs b/src/backend/src/token.rs new file mode 100644 index 0000000000..ecd64e23e7 --- /dev/null +++ b/src/backend/src/token.rs @@ -0,0 +1,50 @@ +use crate::{Candid, StoredPrincipal, VMem}; +use candid::{CandidType, Deserialize}; +use ic_stable_structures::StableBTreeMap; + +const MAX_TOKEN_LIST_LENGTH: usize = 100; + +pub fn add_to_user_token( + stored_principal: StoredPrincipal, + user_token: &mut StableBTreeMap>, VMem>, + token: &T, + find: &dyn Fn(&T) -> bool, +) where + T: for<'a> Deserialize<'a> + CandidType + Clone, +{ + let Candid(mut tokens) = user_token.get(&stored_principal).unwrap_or_default(); + + match tokens.iter().position(find) { + Some(p) => { + tokens[p] = token.clone(); + } + None => { + if tokens.len() == MAX_TOKEN_LIST_LENGTH { + ic_cdk::trap(&format!( + "Token list length should not exceed {MAX_TOKEN_LIST_LENGTH}" + )); + } + tokens.push(token.clone()); + } + } + + user_token.insert(stored_principal, Candid(tokens)); +} + +pub fn remove_from_user_token( + stored_principal: StoredPrincipal, + user_token: &mut StableBTreeMap>, VMem>, + find: &dyn Fn(&T) -> bool, +) where + T: for<'a> Deserialize<'a> + CandidType, +{ + match user_token.get(&stored_principal) { + None => (), + Some(Candid(mut tokens)) => { + if let Some(p) = tokens.iter().position(find) { + tokens.swap_remove(p); + user_token.insert(stored_principal, Candid(tokens)); + } + } + } +}