diff --git a/src/ckb_chain/config.rs b/src/ckb_chain/config.rs index ef6dc292b..9fe7000c1 100644 --- a/src/ckb_chain/config.rs +++ b/src/ckb_chain/config.rs @@ -1,6 +1,7 @@ +use ckb_types::prelude::PackVec; use clap_serde_derive::ClapSerde; - use secp256k1::SecretKey; +use serde_with::serde_as; use std::{ io::{ErrorKind, Read}, path::PathBuf, @@ -9,14 +10,13 @@ use std::{ use std::str::FromStr; use ckb_types::prelude::Builder; +use ckb_types::prelude::Pack; use ckb_types::H256; use ckb_types::{ core::DepType, packed::{CellDep, OutPoint}, }; use ckb_types::{core::ScriptHashType, packed::CellDepVec}; -use ckb_types::{packed::Script, prelude::Pack}; -use clap::ValueEnum; use clap_serde_derive::clap::{self}; use molecule::prelude::Entity; use serde::Deserialize; @@ -48,10 +48,9 @@ pub struct CkbChainConfig { name = "CKB_UDT_WHITELIST", long = "ckb-udt-whitelist", env, - value_parser, help = "a list of supported UDT scripts" )] - udt_whitelist: Option, + pub udt_whitelist: Option, } impl CkbChainConfig { @@ -104,33 +103,32 @@ impl CkbChainConfig { std::io::Error::new(ErrorKind::InvalidData, "invalid secret key data").into() }) } - - pub fn udt_whitelist(&self) -> Vec { - let udt_infos: Vec = self - .udt_whitelist - .iter() - .map(|arg_info| arg_info.0.iter().map(|u| u.into())) - .flatten() - .collect(); - - return udt_infos; - } } -#[derive(Debug, Clone, Copy, ValueEnum, Deserialize, PartialEq, Eq)] -enum UdtScriptHashType { - Type, - Data, - Data1, - Data2, -} +serde_with::serde_conv!( + ScriptHashTypeWrapper, + ScriptHashType, + |_: &ScriptHashType| { panic!("no support to serialize") }, + |s: String| { + let v = match s.to_lowercase().as_str() { + "type" => ScriptHashType::Type, + "data" => ScriptHashType::Data, + "data1" => ScriptHashType::Data1, + "data2" => ScriptHashType::Data2, + _ => return Err("invalid hash type"), + }; + Ok(v) + } +); +#[serde_as] #[derive(Deserialize, Debug, Clone)] -struct UdtScript { - code_hash: H256, - hash_type: UdtScriptHashType, +pub struct UdtScript { + pub code_hash: H256, + #[serde_as(as = "ScriptHashTypeWrapper")] + pub hash_type: ScriptHashType, /// args may be used in pattern matching - args: String, + pub args: String, } #[derive(Deserialize, Clone, Debug)] @@ -140,44 +138,42 @@ struct UdtCellDep { index: u32, } -/// This is only used fro configuration file parsing -/// it will converted into UdtScriptInfo -#[derive(Deserialize, Clone, Debug)] -struct UdtArgInfo { - name: String, - script: UdtScript, - auto_accept_amount: Option, - cell_deps: Vec, +#[derive(Debug, Clone)] +pub struct CellDepsVec(pub CellDepVec); + +impl<'de> Deserialize<'de> for CellDepsVec { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let udt_cell_deps: Vec = Deserialize::deserialize(deserializer)?; + let cell_deps: Vec = udt_cell_deps.iter().map(|dep| CellDep::from(dep)).collect(); + Ok(CellDepsVec(cell_deps.pack())) + } } #[derive(Deserialize, Clone, Debug)] -struct UdtArgInfos(Vec); +pub struct UdtArgInfo { + pub name: String, + pub script: UdtScript, + pub auto_accept_amount: Option, + pub cell_deps: CellDepsVec, +} -impl FromStr for UdtArgInfos { - type Err = serde_json::Error; +#[derive(Deserialize, Clone, Debug)] +pub struct UdtCfgInfos(pub Vec); - fn from_str(s: &str) -> std::result::Result { - serde_json::from_str(s) +impl Default for UdtCfgInfos { + fn default() -> Self { + UdtCfgInfos(Vec::new()) } } -impl From<&UdtScript> for Script { - fn from(script: &UdtScript) -> Self { - let _type = match script.hash_type { - UdtScriptHashType::Data => ScriptHashType::Data, - UdtScriptHashType::Data1 => ScriptHashType::Data1, - UdtScriptHashType::Data2 => ScriptHashType::Data2, - UdtScriptHashType::Type => ScriptHashType::Type, - }; - let mut builder = Script::new_builder() - .code_hash(script.code_hash.pack()) - .hash_type(_type.into()); +impl FromStr for UdtCfgInfos { + type Err = serde_json::Error; - let arg = script.args.strip_prefix("0x").unwrap_or(&script.args); - if let Ok(packed_args) = H256::from_str(arg) { - builder = builder.args(packed_args.as_bytes().pack()); - } - builder.build() + fn from_str(s: &str) -> std::result::Result { + serde_json::from_str(s) } } @@ -199,32 +195,3 @@ impl From<&UdtCellDep> for CellDep { .build() } } - -#[derive(Clone, Debug)] -pub struct UdtScriptInfo { - pub name: String, - pub script: Script, - pub arg_pattern: String, - pub auto_accept_amount: Option, - pub cell_deps: CellDepVec, -} - -impl From<&UdtArgInfo> for UdtScriptInfo { - fn from(arg_info: &UdtArgInfo) -> Self { - let cell_deps: Vec = arg_info - .cell_deps - .iter() - .map(|dep| CellDep::from(dep)) - .collect(); - let cell_deps = CellDepVec::new_builder().set(cell_deps).build(); - let script: Script = (&arg_info.script).into(); - let arg_pattern = arg_info.script.args.clone(); - UdtScriptInfo { - name: arg_info.name.clone(), - auto_accept_amount: arg_info.auto_accept_amount, - script, - arg_pattern, - cell_deps, - } - } -} diff --git a/src/ckb_chain/contracts.rs b/src/ckb_chain/contracts.rs index 4de824889..bbb123541 100644 --- a/src/ckb_chain/contracts.rs +++ b/src/ckb_chain/contracts.rs @@ -17,7 +17,10 @@ use ckb_testtool::{ckb_types::bytes::Bytes, context::Context}; #[cfg(test)] use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use super::config::UdtScriptInfo; +use super::{ + config::{UdtArgInfo, UdtCfgInfos}, + CkbChainConfig, +}; #[cfg(test)] #[derive(Clone, Debug)] @@ -99,7 +102,7 @@ impl MockContext { contracts_context: Arc::new(ContractsInfo { contract_default_scripts: map, script_cell_deps, - udt_whitelist: vec![], + udt_whitelist: UdtCfgInfos::default(), }), }; debug!("Created mock context to test transactions."); @@ -129,7 +132,7 @@ pub enum Contract { struct ContractsInfo { contract_default_scripts: HashMap, script_cell_deps: HashMap, - udt_whitelist: Vec, + udt_whitelist: UdtCfgInfos, } #[derive(Clone)] @@ -218,7 +221,7 @@ impl From for ContractsContext { } impl ContractsContext { - pub fn new(network: CkbNetwork, udt_whitelist: Vec) -> Self { + pub fn new(network: CkbNetwork, udt_whitelist: UdtCfgInfos) -> Self { match network { #[cfg(test)] CkbNetwork::Mocknet => { @@ -429,7 +432,7 @@ impl ContractsContext { res.build() } - fn get_udt_whitelist(&self) -> &Vec { + fn get_udt_whitelist(&self) -> &UdtCfgInfos { match self { #[cfg(test)] Self::Mock(mock) => &mock.contracts_context.udt_whitelist, @@ -447,15 +450,16 @@ impl ContractsContext { .build() } - pub(crate) fn get_udt_info(&self, udt_script: &Script) -> Option<&UdtScriptInfo> { - for udt in self.get_udt_whitelist().iter() { - if udt.script.code_hash() == udt_script.code_hash() - && udt.script.hash_type() == udt_script.hash_type() + pub(crate) fn get_udt_info(&self, udt_script: &Script) -> Option<&UdtArgInfo> { + for udt in &self.get_udt_whitelist().0 { + let _type: ScriptHashType = udt_script.hash_type().try_into().expect("valid hash type"); + if udt.script.code_hash.pack() == udt_script.code_hash() + && udt.script.hash_type == _type { let args = format!("0x{:x}", udt_script.args().raw_data()); - let pattern = Regex::new(&udt.arg_pattern).expect("invalid expressio"); + let pattern = Regex::new(&udt.script.args).expect("invalid expressio"); if pattern.is_match(&args) { - return Some(udt); + return Some(&udt); } } } @@ -465,13 +469,16 @@ impl ContractsContext { pub fn init_contracts_context( network: Option, - udt_whitelist: Option>, + ckb_chain_config: Option<&CkbChainConfig>, ) -> &'static ContractsContext { static INSTANCE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); + let udt_whitelist = ckb_chain_config + .map(|config| config.udt_whitelist.clone()) + .unwrap_or_default(); INSTANCE.get_or_init(|| { ContractsContext::new( network.unwrap_or(DEFAULT_CONTRACT_NETWORK), - udt_whitelist.unwrap_or(vec![]), + udt_whitelist.unwrap_or_default(), ) }); INSTANCE.get().unwrap() @@ -490,7 +497,7 @@ pub fn get_cell_deps_by_contracts(contracts: Vec) -> CellDepVec { init_contracts_context(None, None).get_cell_deps(contracts) } -fn get_udt_info(script: &Script) -> Option<&UdtScriptInfo> { +fn get_udt_info(script: &Script) -> Option<&UdtArgInfo> { init_contracts_context(None, None).get_udt_info(script) } @@ -499,7 +506,7 @@ pub fn check_udt_script(script: &Script) -> bool { } pub fn get_udt_cell_deps(script: &Script) -> Option { - get_udt_info(script).map(|udt| udt.cell_deps.clone()) + get_udt_info(script).map(|udt| udt.cell_deps.0.clone()) } pub fn is_udt_type_auto_accept(script: &Script, amount: u128) -> bool { diff --git a/src/main.rs b/src/main.rs index e95fbe9dc..9d0bb6732 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,8 +50,7 @@ pub async fn main() { let ckb_chain_config = config.ckb_chain.expect("ckb-chain service is required for ckb service. \ Add ckb-chain service to the services list in the config file and relevant configuration to the ckb_chain section of the config file."); - let _ = - init_contracts_context(ckb_config.network, Some(ckb_chain_config.udt_whitelist())); + let _ = init_contracts_context(ckb_config.network, Some(&ckb_chain_config)); let ckb_chain_actor = Actor::spawn_linked( Some("ckb-chain".to_string()),