Skip to content

Commit

Permalink
Replace ethereum_types with alloy::primitives in smt_trie crate
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad committed Nov 5, 2024
1 parent 4a747b2 commit 84e3271
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion smt_trie/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ homepage.workspace = true
keywords.workspace = true

[dependencies]
ethereum-types.workspace = true
alloy.workspace = true
plonky2.workspace = true
rand.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
Expand Down
18 changes: 9 additions & 9 deletions smt_trie/src/bits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::Add;

use ethereum_types::{BigEndianHash, H256, U256};
use alloy::primitives::{B256, U256};
use serde::{Deserialize, Serialize};

pub type Bit = bool;
Expand All @@ -22,11 +22,11 @@ impl From<U256> for Bits {
}
}

impl From<H256> for Bits {
fn from(packed: H256) -> Self {
impl From<B256> for Bits {
fn from(packed: B256) -> Self {
Bits {
count: 256,
packed: packed.into_uint(),
packed: packed.into(),
}
}
}
Expand All @@ -38,7 +38,7 @@ impl Add for Bits {
assert!(self.count + rhs.count <= 256, "Overflow");
Self {
count: self.count + rhs.count,
packed: self.packed * (U256::one() << rhs.count) + rhs.packed,
packed: self.packed * (U256::from(1) << rhs.count) + rhs.packed,
}
}
}
Expand All @@ -47,7 +47,7 @@ impl Bits {
pub const fn empty() -> Self {
Bits {
count: 0,
packed: U256::zero(),
packed: U256::ZERO,
}
}

Expand All @@ -57,19 +57,19 @@ impl Bits {

pub fn pop_next_bit(&mut self) -> Bit {
assert!(!self.is_empty(), "Cannot pop from empty bits");
let b = !(self.packed & U256::one()).is_zero();
let b = !(self.packed & U256::from(1)).is_zero();
self.packed >>= 1;
self.count -= 1;
b
}

pub fn get_bit(&self, i: usize) -> Bit {
assert!(i < self.count, "Index out of bounds");
!(self.packed & (U256::one() << (self.count - 1 - i))).is_zero()
!(self.packed & (U256::from(1) << (self.count - 1 - i))).is_zero()
}

pub fn push_bit(&mut self, bit: Bit) {
self.packed = self.packed * 2 + U256::from(bit as u64);
self.packed = self.packed * U256::from(2) + U256::from(bit as u64);
self.count += 1;
}

Expand Down
2 changes: 1 addition & 1 deletion smt_trie/src/code.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// Functions to hash contract bytecode using Poseidon.
/// See `hashContractBytecode()` in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation.
use ethereum_types::U256;
use alloy::primitives::U256;
use plonky2::field::types::Field;
use plonky2::hash::poseidon::{self, Poseidon};

Expand Down
6 changes: 3 additions & 3 deletions smt_trie/src/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

/// This module contains functions to generate keys for the SMT.
/// See https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation.
use ethereum_types::{Address, U256};
use alloy::primitives::{Address, U256};
use plonky2::{field::types::Field, hash::poseidon::Poseidon};

use crate::smt::{Key, F};
Expand Down Expand Up @@ -74,8 +74,8 @@ pub fn key_storage(addr: Address, slot: U256) -> Key {
let capacity: [F; 4] = {
let mut arr = [F::ZERO; 12];
for i in 0..4 {
arr[2 * i] = F::from_canonical_u32(slot.0[i] as u32);
arr[2 * i + 1] = F::from_canonical_u32((slot.0[i] >> 32) as u32);
arr[2 * i] = F::from_canonical_u32(slot.as_limbs()[i] as u32);
arr[2 * i + 1] = F::from_canonical_u32((slot.as_limbs()[i] >> 32) as u32);
}
F::poseidon(arr)[0..4].try_into().unwrap()
};
Expand Down
45 changes: 28 additions & 17 deletions smt_trie/src/smt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet};

use ethereum_types::U256;
use alloy::primitives::U256;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::poseidon::{Poseidon, PoseidonHash};
Expand Down Expand Up @@ -145,7 +145,7 @@ impl<D: Db> Smt<D> {
.copied()
.unwrap_or_default()
.is_zero());
U256::zero()
U256::ZERO
};
} else {
let b = keys.get_bit(level as usize);
Expand Down Expand Up @@ -347,7 +347,7 @@ impl<D: Db> Smt<D> {
/// Delete the key in the SMT.
pub fn delete(&mut self, key: Key) {
self.kv_store.remove(&key);
self.set(key, U256::zero());
self.set(key, U256::ZERO);
}

/// Set the key to the hash in the SMT.
Expand Down Expand Up @@ -416,7 +416,7 @@ impl<D: Db> Smt<D> {
&self,
keys: I,
) -> Vec<U256> {
let mut v = vec![U256::zero(); 2]; // For empty hash node.
let mut v = vec![U256::ZERO; 2]; // For empty hash node.
let key = Key(self.root.elements);

let mut keys_to_include = HashSet::new();
Expand All @@ -433,7 +433,7 @@ impl<D: Db> Smt<D> {

serialize(self, key, &mut v, Bits::empty(), &keys_to_include);
if v.len() == 2 {
v.extend([U256::zero(); 2]);
v.extend([U256::ZERO; 2]);
}
v
}
Expand All @@ -457,7 +457,7 @@ fn serialize<D: Db>(

if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() {
let index = v.len();
v.push(HASH_TYPE.into());
v.push(U256::from_limbs([0, 0, 0, HASH_TYPE.into()]));
v.push(key2u(key));
index
} else if let Some(node) = smt.db.get_node(&key) {
Expand All @@ -473,22 +473,32 @@ fn serialize<D: Db>(
let rem_key = Key(node.0[0..4].try_into().unwrap());
let val = limbs2f(val_a);
let index = v.len();
v.push(LEAF_TYPE.into());
v.push(U256::from_limbs([0, 0, 0, LEAF_TYPE.into()]));
v.push(key2u(rem_key));
v.push(val);
index
} else {
let key_left = Key(node.0[0..4].try_into().unwrap());
let key_right = Key(node.0[4..8].try_into().unwrap());
let index = v.len();
v.push(INTERNAL_TYPE.into());
v.push(U256::zero());
v.push(U256::zero());
let i_left =
serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into();
v.push(U256::from_limbs([0, 0, 0, INTERNAL_TYPE.into()]));
v.push(U256::ZERO);
v.push(U256::ZERO);
let i_left = U256::from(serialize(
smt,
key_left,
v,
cur_bits.add_bit(false),
keys_to_include,
));
v[index + 1] = i_left;
let i_right =
serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into();
let i_right = U256::from(serialize(
smt,
key_right,
v,
cur_bits.add_bit(true),
keys_to_include,
));
v[index + 2] = i_right;
index
}
Expand All @@ -507,15 +517,16 @@ pub fn hash_serialize_u256(v: &[U256]) -> U256 {
}

fn _hash_serialize(v: &[U256], ptr: usize) -> HashOut {
assert!(v[ptr] <= u8::MAX.into());
match v[ptr].as_u64() as u8 {
let byte: u8 = v[ptr].try_into().unwrap();
match byte {
HASH_TYPE => u2h(v[ptr + 1]),

INTERNAL_TYPE => {
let mut node = Node([F::ZERO; 12]);
for b in 0..2 {
let child_index = v[ptr + 1 + b];
let child_hash = _hash_serialize(v, child_index.as_usize());
let child_index: usize = child_index.try_into().unwrap();
let child_hash = _hash_serialize(v, child_index);
node.0[b * 4..(b + 1) * 4].copy_from_slice(&child_hash.elements);
}
F::poseidon(node.0)[0..4].try_into().unwrap()
Expand Down
Loading

0 comments on commit 84e3271

Please sign in to comment.