Skip to content

Commit

Permalink
code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyukang committed Jun 25, 2024
1 parent adcd789 commit 3ca56d1
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 101 deletions.
135 changes: 50 additions & 85 deletions src/ckb_chain/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use ckb_types::prelude::PackVec;
use clap_serde_derive::ClapSerde;

use secp256k1::SecretKey;
use serde_with::serde_as;
use std::{
io::{ErrorKind, Read},
path::PathBuf,
Expand All @@ -9,14 +10,13 @@ use std::{
use std::str::FromStr;

use ckb_types::prelude::Builder;
use ckb_types::prelude::Pack;
use ckb_types::H256;
use ckb_types::{
core::DepType,
packed::{CellDep, OutPoint},
};
use ckb_types::{core::ScriptHashType, packed::CellDepVec};
use ckb_types::{packed::Script, prelude::Pack};
use clap::ValueEnum;
use clap_serde_derive::clap::{self};
use molecule::prelude::Entity;
use serde::Deserialize;
Expand Down Expand Up @@ -48,10 +48,9 @@ pub struct CkbChainConfig {
name = "CKB_UDT_WHITELIST",
long = "ckb-udt-whitelist",
env,
value_parser,
help = "a list of supported UDT scripts"
)]
udt_whitelist: Option<UdtArgInfos>,
pub udt_whitelist: Option<UdtCfgInfos>,
}

impl CkbChainConfig {
Expand Down Expand Up @@ -104,33 +103,32 @@ impl CkbChainConfig {
std::io::Error::new(ErrorKind::InvalidData, "invalid secret key data").into()
})
}

pub fn udt_whitelist(&self) -> Vec<UdtScriptInfo> {
let udt_infos: Vec<UdtScriptInfo> = self
.udt_whitelist
.iter()
.map(|arg_info| arg_info.0.iter().map(|u| u.into()))
.flatten()
.collect();

return udt_infos;
}
}

#[derive(Debug, Clone, Copy, ValueEnum, Deserialize, PartialEq, Eq)]
enum UdtScriptHashType {
Type,
Data,
Data1,
Data2,
}
serde_with::serde_conv!(
ScriptHashTypeWrapper,
ScriptHashType,
|_: &ScriptHashType| { panic!("no support to serialize") },
|s: String| {
let v = match s.to_lowercase().as_str() {
"type" => ScriptHashType::Type,
"data" => ScriptHashType::Data,
"data1" => ScriptHashType::Data1,
"data2" => ScriptHashType::Data2,
_ => return Err("invalid hash type"),
};
Ok(v)
}
);

#[serde_as]
#[derive(Deserialize, Debug, Clone)]
struct UdtScript {
code_hash: H256,
hash_type: UdtScriptHashType,
pub struct UdtScript {
pub code_hash: H256,
#[serde_as(as = "ScriptHashTypeWrapper")]
pub hash_type: ScriptHashType,
/// args may be used in pattern matching
args: String,
pub args: String,
}

#[derive(Deserialize, Clone, Debug)]
Expand All @@ -140,44 +138,40 @@ struct UdtCellDep {
index: u32,
}

/// This is only used fro configuration file parsing
/// it will converted into UdtScriptInfo
serde_with::serde_conv!(
CellDepVecWrapper,
CellDepVec,
|_: &CellDepVec| { panic!("no support to serialize") },
|s: Vec<UdtCellDep>| -> Result<CellDepVec, &'static str> {
let cell_deps: Vec<CellDep> = s.iter().map(|dep| CellDep::from(dep)).collect();
Ok(cell_deps.pack())
}
);

#[serde_as]
#[derive(Deserialize, Clone, Debug)]
struct UdtArgInfo {
name: String,
script: UdtScript,
auto_accept_amount: Option<u128>,
cell_deps: Vec<UdtCellDep>,
pub struct UdtArgInfo {
pub name: String,
pub script: UdtScript,
pub auto_accept_amount: Option<u128>,
#[serde_as(as = "CellDepVecWrapper")]
pub cell_deps: CellDepVec,
}

#[derive(Deserialize, Clone, Debug)]
struct UdtArgInfos(Vec<UdtArgInfo>);
pub struct UdtCfgInfos(pub Vec<UdtArgInfo>);

impl FromStr for UdtArgInfos {
type Err = serde_json::Error;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
serde_json::from_str(s)
impl Default for UdtCfgInfos {
fn default() -> Self {
UdtCfgInfos(Vec::new())
}
}

impl From<&UdtScript> for Script {
fn from(script: &UdtScript) -> Self {
let _type = match script.hash_type {
UdtScriptHashType::Data => ScriptHashType::Data,
UdtScriptHashType::Data1 => ScriptHashType::Data1,
UdtScriptHashType::Data2 => ScriptHashType::Data2,
UdtScriptHashType::Type => ScriptHashType::Type,
};
let mut builder = Script::new_builder()
.code_hash(script.code_hash.pack())
.hash_type(_type.into());
impl FromStr for UdtCfgInfos {
type Err = serde_json::Error;

let arg = script.args.strip_prefix("0x").unwrap_or(&script.args);
if let Ok(packed_args) = H256::from_str(arg) {
builder = builder.args(packed_args.as_bytes().pack());
}
builder.build()
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
serde_json::from_str(s)
}
}

Expand All @@ -199,32 +193,3 @@ impl From<&UdtCellDep> for CellDep {
.build()
}
}

#[derive(Clone, Debug)]
pub struct UdtScriptInfo {
pub name: String,
pub script: Script,
pub arg_pattern: String,
pub auto_accept_amount: Option<u128>,
pub cell_deps: CellDepVec,
}

impl From<&UdtArgInfo> for UdtScriptInfo {
fn from(arg_info: &UdtArgInfo) -> Self {
let cell_deps: Vec<CellDep> = arg_info
.cell_deps
.iter()
.map(|dep| CellDep::from(dep))
.collect();
let cell_deps = CellDepVec::new_builder().set(cell_deps).build();
let script: Script = (&arg_info.script).into();
let arg_pattern = arg_info.script.args.clone();
UdtScriptInfo {
name: arg_info.name.clone(),
auto_accept_amount: arg_info.auto_accept_amount,
script,
arg_pattern,
cell_deps,
}
}
}
35 changes: 21 additions & 14 deletions src/ckb_chain/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use ckb_testtool::{ckb_types::bytes::Bytes, context::Context};
#[cfg(test)]
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use super::config::UdtScriptInfo;
use super::{
config::{UdtArgInfo, UdtCfgInfos},
CkbChainConfig,
};

#[cfg(test)]
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -99,7 +102,7 @@ impl MockContext {
contracts_context: Arc::new(ContractsInfo {
contract_default_scripts: map,
script_cell_deps,
udt_whitelist: vec![],
udt_whitelist: UdtCfgInfos::default(),
}),
};
debug!("Created mock context to test transactions.");
Expand Down Expand Up @@ -129,7 +132,7 @@ pub enum Contract {
struct ContractsInfo {
contract_default_scripts: HashMap<Contract, Script>,
script_cell_deps: HashMap<Contract, CellDepVec>,
udt_whitelist: Vec<UdtScriptInfo>,
udt_whitelist: UdtCfgInfos,
}

#[derive(Clone)]
Expand Down Expand Up @@ -218,7 +221,7 @@ impl From<MockContext> for ContractsContext {
}

impl ContractsContext {
pub fn new(network: CkbNetwork, udt_whitelist: Vec<UdtScriptInfo>) -> Self {
pub fn new(network: CkbNetwork, udt_whitelist: UdtCfgInfos) -> Self {
match network {
#[cfg(test)]
CkbNetwork::Mocknet => {
Expand Down Expand Up @@ -429,7 +432,7 @@ impl ContractsContext {
res.build()
}

fn get_udt_whitelist(&self) -> &Vec<UdtScriptInfo> {
fn get_udt_whitelist(&self) -> &UdtCfgInfos {
match self {
#[cfg(test)]
Self::Mock(mock) => &mock.contracts_context.udt_whitelist,
Expand All @@ -447,15 +450,16 @@ impl ContractsContext {
.build()
}

pub(crate) fn get_udt_info(&self, udt_script: &Script) -> Option<&UdtScriptInfo> {
for udt in self.get_udt_whitelist().iter() {
if udt.script.code_hash() == udt_script.code_hash()
&& udt.script.hash_type() == udt_script.hash_type()
pub(crate) fn get_udt_info(&self, udt_script: &Script) -> Option<&UdtArgInfo> {
for udt in &self.get_udt_whitelist().0 {
let _type: ScriptHashType = udt_script.hash_type().try_into().expect("valid hash type");
if udt.script.code_hash.pack() == udt_script.code_hash()
&& udt.script.hash_type == _type
{
let args = format!("0x{:x}", udt_script.args().raw_data());
let pattern = Regex::new(&udt.arg_pattern).expect("invalid expressio");
let pattern = Regex::new(&udt.script.args).expect("invalid expressio");
if pattern.is_match(&args) {
return Some(udt);
return Some(&udt);
}
}
}
Expand All @@ -465,13 +469,16 @@ impl ContractsContext {

pub fn init_contracts_context(
network: Option<CkbNetwork>,
udt_whitelist: Option<Vec<UdtScriptInfo>>,
ckb_chain_config: Option<&CkbChainConfig>,
) -> &'static ContractsContext {
static INSTANCE: once_cell::sync::OnceCell<ContractsContext> = once_cell::sync::OnceCell::new();
let udt_whitelist = ckb_chain_config
.map(|config| config.udt_whitelist.clone())
.unwrap_or_default();
INSTANCE.get_or_init(|| {
ContractsContext::new(
network.unwrap_or(DEFAULT_CONTRACT_NETWORK),
udt_whitelist.unwrap_or(vec![]),
udt_whitelist.unwrap_or_default(),
)
});
INSTANCE.get().unwrap()
Expand All @@ -490,7 +497,7 @@ pub fn get_cell_deps_by_contracts(contracts: Vec<Contract>) -> CellDepVec {
init_contracts_context(None, None).get_cell_deps(contracts)
}

fn get_udt_info(script: &Script) -> Option<&UdtScriptInfo> {
fn get_udt_info(script: &Script) -> Option<&UdtArgInfo> {
init_contracts_context(None, None).get_udt_info(script)
}

Expand Down
3 changes: 1 addition & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ pub async fn main() {
let ckb_chain_config = config.ckb_chain.expect("ckb-chain service is required for ckb service. \
Add ckb-chain service to the services list in the config file and relevant configuration to the ckb_chain section of the config file.");

let _ =
init_contracts_context(ckb_config.network, Some(ckb_chain_config.udt_whitelist()));
let _ = init_contracts_context(ckb_config.network, Some(&ckb_chain_config));

let ckb_chain_actor = Actor::spawn_linked(
Some("ckb-chain".to_string()),
Expand Down

0 comments on commit 3ca56d1

Please sign in to comment.