diff --git a/Cargo.lock b/Cargo.lock index d64c3c1b7..22bfe827c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4678,7 +4678,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" name = "smt_trie" version = "0.1.1" dependencies = [ - "ethereum-types", + "alloy", "hex-literal", "plonky2", "rand", diff --git a/smt_trie/Cargo.toml b/smt_trie/Cargo.toml index 6df15c11c..c1a4cdb9d 100644 --- a/smt_trie/Cargo.toml +++ b/smt_trie/Cargo.toml @@ -12,7 +12,7 @@ homepage.workspace = true keywords.workspace = true [dependencies] -ethereum-types.workspace = true +alloy.workspace = true plonky2.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive", "rc"] } diff --git a/smt_trie/src/bits.rs b/smt_trie/src/bits.rs index 0fbcac0e1..40758ba38 100644 --- a/smt_trie/src/bits.rs +++ b/smt_trie/src/bits.rs @@ -1,6 +1,6 @@ use std::ops::Add; -use ethereum_types::{BigEndianHash, H256, U256}; +use alloy::primitives::{B256, U256}; use serde::{Deserialize, Serialize}; pub type Bit = bool; @@ -22,11 +22,11 @@ impl From for Bits { } } -impl From for Bits { - fn from(packed: H256) -> Self { +impl From for Bits { + fn from(packed: B256) -> Self { Bits { count: 256, - packed: packed.into_uint(), + packed: packed.into(), } } } @@ -38,7 +38,7 @@ impl Add for Bits { assert!(self.count + rhs.count <= 256, "Overflow"); Self { count: self.count + rhs.count, - packed: self.packed * (U256::one() << rhs.count) + rhs.packed, + packed: self.packed * (U256::from(1) << rhs.count) + rhs.packed, } } } @@ -47,7 +47,7 @@ impl Bits { pub const fn empty() -> Self { Bits { count: 0, - packed: U256::zero(), + packed: U256::ZERO, } } @@ -57,7 +57,7 @@ impl Bits { pub fn pop_next_bit(&mut self) -> Bit { assert!(!self.is_empty(), "Cannot pop from empty bits"); - let b = !(self.packed & U256::one()).is_zero(); + let b = !(self.packed & U256::from(1)).is_zero(); self.packed >>= 1; self.count -= 1; b @@ -65,11 +65,11 @@ impl Bits { pub fn get_bit(&self, i: usize) -> Bit { assert!(i < self.count, "Index out of bounds"); - !(self.packed & (U256::one() << (self.count - 1 - i))).is_zero() + !(self.packed & (U256::from(1) << (self.count - 1 - i))).is_zero() } pub fn push_bit(&mut self, bit: Bit) { - self.packed = self.packed * 2 + U256::from(bit as u64); + self.packed = self.packed * U256::from(2) + U256::from(bit as u64); self.count += 1; } diff --git a/smt_trie/src/code.rs b/smt_trie/src/code.rs index dd6b142b9..d3aefb864 100644 --- a/smt_trie/src/code.rs +++ b/smt_trie/src/code.rs @@ -1,6 +1,6 @@ /// Functions to hash contract bytecode using Poseidon. /// See `hashContractBytecode()` in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::types::Field; use plonky2::hash::poseidon::{self, Poseidon}; diff --git a/smt_trie/src/keys.rs b/smt_trie/src/keys.rs index 1f122adbb..e254361bd 100644 --- a/smt_trie/src/keys.rs +++ b/smt_trie/src/keys.rs @@ -2,7 +2,7 @@ /// This module contains functions to generate keys for the SMT. /// See https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. -use ethereum_types::{Address, U256}; +use alloy::primitives::{Address, U256}; use plonky2::{field::types::Field, hash::poseidon::Poseidon}; use crate::smt::{Key, F}; @@ -74,8 +74,8 @@ pub fn key_storage(addr: Address, slot: U256) -> Key { let capacity: [F; 4] = { let mut arr = [F::ZERO; 12]; for i in 0..4 { - arr[2 * i] = F::from_canonical_u32(slot.0[i] as u32); - arr[2 * i + 1] = F::from_canonical_u32((slot.0[i] >> 32) as u32); + arr[2 * i] = F::from_canonical_u32(slot.as_limbs()[i] as u32); + arr[2 * i + 1] = F::from_canonical_u32((slot.as_limbs()[i] >> 32) as u32); } F::poseidon(arr)[0..4].try_into().unwrap() }; diff --git a/smt_trie/src/smt.rs b/smt_trie/src/smt.rs index f9ea73319..c192a8f49 100644 --- a/smt_trie/src/smt.rs +++ b/smt_trie/src/smt.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::{Poseidon, PoseidonHash}; @@ -145,7 +145,7 @@ impl Smt { .copied() .unwrap_or_default() .is_zero()); - U256::zero() + U256::ZERO }; } else { let b = keys.get_bit(level as usize); @@ -347,7 +347,7 @@ impl Smt { /// Delete the key in the SMT. pub fn delete(&mut self, key: Key) { self.kv_store.remove(&key); - self.set(key, U256::zero()); + self.set(key, U256::ZERO); } /// Set the key to the hash in the SMT. @@ -416,7 +416,7 @@ impl Smt { &self, keys: I, ) -> Vec { - let mut v = vec![U256::zero(); 2]; // For empty hash node. + let mut v = vec![U256::ZERO; 2]; // For empty hash node. let key = Key(self.root.elements); let mut keys_to_include = HashSet::new(); @@ -433,7 +433,7 @@ impl Smt { serialize(self, key, &mut v, Bits::empty(), &keys_to_include); if v.len() == 2 { - v.extend([U256::zero(); 2]); + v.extend([U256::ZERO; 2]); } v } @@ -457,7 +457,7 @@ fn serialize( if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() { let index = v.len(); - v.push(HASH_TYPE.into()); + v.push(U256::from_limbs([0, 0, 0, HASH_TYPE.into()])); v.push(key2u(key)); index } else if let Some(node) = smt.db.get_node(&key) { @@ -473,7 +473,7 @@ fn serialize( let rem_key = Key(node.0[0..4].try_into().unwrap()); let val = limbs2f(val_a); let index = v.len(); - v.push(LEAF_TYPE.into()); + v.push(U256::from_limbs([0, 0, 0, LEAF_TYPE.into()])); v.push(key2u(rem_key)); v.push(val); index @@ -481,14 +481,24 @@ fn serialize( let key_left = Key(node.0[0..4].try_into().unwrap()); let key_right = Key(node.0[4..8].try_into().unwrap()); let index = v.len(); - v.push(INTERNAL_TYPE.into()); - v.push(U256::zero()); - v.push(U256::zero()); - let i_left = - serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into(); + v.push(U256::from_limbs([0, 0, 0, INTERNAL_TYPE.into()])); + v.push(U256::ZERO); + v.push(U256::ZERO); + let i_left = U256::from(serialize( + smt, + key_left, + v, + cur_bits.add_bit(false), + keys_to_include, + )); v[index + 1] = i_left; - let i_right = - serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into(); + let i_right = U256::from(serialize( + smt, + key_right, + v, + cur_bits.add_bit(true), + keys_to_include, + )); v[index + 2] = i_right; index } @@ -507,15 +517,16 @@ pub fn hash_serialize_u256(v: &[U256]) -> U256 { } fn _hash_serialize(v: &[U256], ptr: usize) -> HashOut { - assert!(v[ptr] <= u8::MAX.into()); - match v[ptr].as_u64() as u8 { + let byte: u8 = v[ptr].try_into().unwrap(); + match byte { HASH_TYPE => u2h(v[ptr + 1]), INTERNAL_TYPE => { let mut node = Node([F::ZERO; 12]); for b in 0..2 { let child_index = v[ptr + 1 + b]; - let child_hash = _hash_serialize(v, child_index.as_usize()); + let child_index: usize = child_index.try_into().unwrap(); + let child_hash = _hash_serialize(v, child_index); node.0[b * 4..(b + 1) * 4].copy_from_slice(&child_hash.elements); } F::poseidon(node.0)[0..4].try_into().unwrap() diff --git a/smt_trie/src/smt_test.rs b/smt_trie/src/smt_test.rs index c086e17dc..c500646e0 100644 --- a/smt_trie/src/smt_test.rs +++ b/smt_trie/src/smt_test.rs @@ -1,4 +1,4 @@ -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::types::{Field, Sample}; use plonky2::hash::hash_types::HashOut; use rand::seq::SliceRandom; @@ -18,11 +18,11 @@ fn test_add_and_rem() { let mut smt = Smt::::default(); let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k, v); assert_eq!(v, smt.get(k)); - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -48,7 +48,7 @@ fn test_add_and_rem_hermez() { .map(F::from_canonical_u64) ); - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -60,8 +60,8 @@ fn test_update_element_1() { let mut smt = Smt::::default(); let k = Key(F::rand_array()); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); + let v1 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); + let v2 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k, v1); let root = smt.root; smt.set(k, v2); @@ -79,12 +79,12 @@ fn test_add_shared_element_2() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); assert_ne!(k1, k2, "Unlucky"); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); + let v1 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); + let v2 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k1, v1); smt.set(k2, v2); - smt.set(k1, U256::zero()); - smt.set(k2, U256::zero()); + smt.set(k1, U256::ZERO); + smt.set(k2, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -98,15 +98,15 @@ fn test_add_shared_element_3() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); let k3 = Key(F::rand_array()); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); - let v3 = U256(thread_rng().gen()); + let v1 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); + let v2 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); + let v3 = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k1, v1); smt.set(k2, v2); smt.set(k3, v3); - smt.set(k1, U256::zero()); - smt.set(k2, U256::zero()); - smt.set(k3, U256::zero()); + smt.set(k1, U256::ZERO); + smt.set(k2, U256::ZERO); + smt.set(k3, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -120,7 +120,7 @@ fn test_add_remove_128() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k, v); (k, v) }) @@ -129,7 +129,7 @@ fn test_add_remove_128() { smt.set(k, v); } for &(k, _) in &kvs { - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); } assert_eq!(smt.root.elements, [F::ZERO; 4]); @@ -144,7 +144,7 @@ fn test_should_read_random() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k, v); (k, v) }) @@ -226,21 +226,25 @@ fn test_leaf_one_level_depth() { ] .map(F::from_canonical_u64)); - let v0 = U256::from_dec_str( + let v0 = U256::from_str_radix( "8163644824788514136399898658176031121905718480550577527648513153802600646339", + 10, ) .unwrap(); - let v1 = U256::from_dec_str( + let v1 = U256::from_str_radix( "115792089237316195423570985008687907853269984665640564039457584007913129639934", + 10, ) .unwrap(); - let v2 = U256::from_dec_str( + let v2 = U256::from_str_radix( "115792089237316195423570985008687907853269984665640564039457584007913129639935", + 10, ) .unwrap(); - let v3 = U256::from_dec_str("7943875943875408").unwrap(); - let v4 = U256::from_dec_str( + let v3 = U256::from_str_radix("7943875943875408", 10).unwrap(); + let v4 = U256::from_str_radix( "35179347944617143021579132182092200136526168785636368258055676929581544372820", + 10, ) .unwrap(); @@ -269,10 +273,10 @@ fn test_no_write_0() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from_le_bytes(thread_rng().gen::<[u8; 32]>()); smt.set(k1, v); let root = smt.root; - smt.set(k2, U256::zero()); + smt.set(k2, U256::ZERO); assert_eq!(smt.root, root); let ser = smt.serialize(); @@ -286,7 +290,7 @@ fn test_set_hash_first_level() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(random()); + let v = U256::from_le_bytes(random::<[u8; 32]>()); smt.set(k, v); (k, v) }) @@ -299,11 +303,11 @@ fn test_set_hash_first_level() { let mut hash_smt = Smt::::default(); let zero = Bits { count: 1, - packed: U256::zero(), + packed: U256::ZERO, }; let one = Bits { count: 1, - packed: U256::one(), + packed: U256::from(1), }; hash_smt.set_hash( zero, @@ -334,7 +338,7 @@ fn test_set_hash_order() { .map(|i| { let k = Bits { count: level, - packed: i.into(), + packed: U256::from(i), }; let hash = HashOut { elements: F::rand_array(), @@ -353,7 +357,7 @@ fn test_set_hash_order() { break key; } }; - let val = U256(random()); + let val = U256::from_le_bytes(random::<[u8; 32]>()); smt.set(key, val); let mut second_smt = Smt::::default(); @@ -375,7 +379,7 @@ fn test_serialize_and_prune() { for _ in 0..128 { let k = Key(F::rand_array()); - let v = U256(random()); + let v = U256::from_le_bytes(random::<[u8; 32]>()); smt.set(k, v); } @@ -399,9 +403,9 @@ fn test_serialize_and_prune() { assert_eq!( trivial_ser, vec![ - U256::zero(), - U256::zero(), - HASH_TYPE.into(), + U256::ZERO, + U256::ZERO, + U256::from_le_bytes([0, 0, 0, 0, 0, 0, 0, HASH_TYPE]), hashout2u(smt.root) ] ); diff --git a/smt_trie/src/utils.rs b/smt_trie/src/utils.rs index 267b6b8e9..eba016817 100644 --- a/smt_trie/src/utils.rs +++ b/smt_trie/src/utils.rs @@ -1,4 +1,4 @@ -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::Poseidon; @@ -31,7 +31,9 @@ pub(crate) fn hash_key_hash(k: Key, h: [F; 4]) -> [F; 4] { /// Split a U256 into 8 32-bit limbs in little-endian order. pub(crate) fn f2limbs(x: U256) -> [F; 8] { - std::array::from_fn(|i| F::from_canonical_u32((x >> (32 * i)).low_u32())) + std::array::from_fn(|i| { + F::from_canonical_u32(*(x >> (32 * i)).as_limbs().last().unwrap() as u32) + }) } /// Pack 8 32-bit limbs in little-endian order into a U256. @@ -39,7 +41,7 @@ pub(crate) fn limbs2f(limbs: [F; 8]) -> U256 { limbs .into_iter() .enumerate() - .fold(U256::zero(), |acc, (i, x)| { + .fold(U256::ZERO, |acc, (i, x)| { acc + (U256::from(x.to_canonical_u64()) << (i * 32)) }) } @@ -51,19 +53,19 @@ pub fn hashout2u(h: HashOut) -> U256 { /// Convert a `Key` to a `U256`. pub fn key2u(key: Key) -> U256 { - U256(key.0.map(|x| x.to_canonical_u64())) + U256::from_limbs(key.0.map(|x| x.to_canonical_u64())) } /// Convert a `U256` to a `Hashout`. pub(crate) fn u2h(x: U256) -> HashOut { HashOut { - elements: x.0.map(F::from_canonical_u64), + elements: x.as_limbs().map(F::from_canonical_u64), } } /// Convert a `U256` to a `Key`. pub(crate) fn u2k(x: U256) -> Key { - Key(x.0.map(F::from_canonical_u64)) + Key(x.as_limbs().map(F::from_canonical_u64)) } /// Given a node, return the index of the unique non-zero sibling, or -1 if diff --git a/trace_decoder/src/world.rs b/trace_decoder/src/world.rs index fa68854e4..a514117cf 100644 --- a/trace_decoder/src/world.rs +++ b/trace_decoder/src/world.rs @@ -335,9 +335,8 @@ impl World for Type2World { Ok(()) } fn root(&mut self) -> H256 { - let mut it = [0; 32]; - smt_trie::utils::hashout2u(self.as_smt().root).to_big_endian(&mut it); - H256(it) + let root = smt_trie::utils::hashout2u(self.as_smt().root); + H256::from_slice(root.as_le_slice()) } } @@ -398,11 +397,16 @@ impl Type2World { (code_length, key_code_length), ] { if let Some(value) = value { - smt.set(key_fn(*addr), *value); + let addr = compat::address(*addr); + let value = compat::u256(*value); + smt.set(key_fn(addr), value); } } for (slot, value) in storage { - smt.set(key_storage(*addr, *slot), *value); + let addr = compat::address(*addr); + let slot = compat::u256(*slot); + let value = compat::u256(*value); + smt.set(key_storage(addr, slot), value); } } smt @@ -418,3 +422,18 @@ impl Type2World { } } } + +// TODO(serge): Remove this module once this crate uses alloy types. +mod compat { + use alloy::primitives::{Address, U256}; + + pub(crate) fn address(addr: ethereum_types::H160) -> Address { + Address::from_slice(addr.as_bytes()) + } + + pub(crate) fn u256(value: ethereum_types::U256) -> U256 { + let mut buf = [0u8; 32]; + value.to_little_endian(&mut buf); + U256::from_le_bytes(buf) + } +}