Skip to content

Commit

Permalink
refactor(snapshot): code split (#198)
Browse files Browse the repository at this point in the history
* refactor: extract logger functions

* refactor: extract shortned address to function

* chore: merge with origin branch

* chore: remove duplicated function declaration

* chore: move get_shortnet_target and set_logger_env to common

* refactor: extract function selector resolvers to resolve.rs

* refactor: move get_contract_bytecode to a new bytecode module

* refactor: move get_selectors to selectors module

* refactor: make bytecode disambled external from get_resolved_selectors

* refactor: remove logger from params of resolve_signature

* refactor: rework snapshot mod variable and args declaration order

* refactor: remove logger from params list and initialize internally

* fix: remove extrar logger param from get_contract_bytecode call

* chore: change rpc from llama to ankr

* fix(tests): remove unnecessary `0x` prefix constraint

* refactor: rename resolve_custom_event_signatures to resolve_event_signatures

* refactor: rename get_contract_bytecode to get_bytecode_from_target

* fix: update comments from get_bytecode_from_target to be module agnostic

* refactor: switch from logger::debug_max to debug_max

* fix: remove out of context log

* refactor: remove get_logger_and_trace function

* style: code format

* refactor: change target param from string reference to string slice

---------

Co-authored-by: Jon-Becker <[email protected]>
  • Loading branch information
iankressin and Jon-Becker authored Dec 11, 2023
1 parent cfbf6b1 commit 23addf5
Show file tree
Hide file tree
Showing 9 changed files with 572 additions and 336 deletions.
87 changes: 87 additions & 0 deletions common/src/ether/bytecode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use super::rpc::get_code;
use crate::{
constants::{ADDRESS_REGEX, BYTECODE_REGEX},
debug_max,
utils::io::logging::Logger,
};
use std::fs;

pub async fn get_bytecode_from_target(
target: &str,
rpc_url: &str,
) -> Result<String, Box<dyn std::error::Error>> {
let (logger, _) = Logger::new("");

if ADDRESS_REGEX.is_match(target)? {
// Target is a contract address, so we need to fetch the bytecode from the RPC provider.
get_code(target, rpc_url).await
} else if BYTECODE_REGEX.is_match(target)? {
debug_max!("using provided bytecode for snapshotting.");

// Target is already a bytecode, so we just need to remove 0x from the begining
Ok(target.replacen("0x", "", 1))
} else {
debug_max!("using provided file for snapshotting.");

// Target is a file path, so we need to read the bytecode from the file.
match fs::read_to_string(target) {
Ok(contents) => {
let _contents = contents.replace('\n', "");
if BYTECODE_REGEX.is_match(&_contents)? && _contents.len() % 2 == 0 {
Ok(_contents.replacen("0x", "", 1))
} else {
logger.error(&format!("file '{}' doesn't contain valid bytecode.", &target));
std::process::exit(1)
}
}
Err(_) => {
logger.error(&format!("failed to open file '{}' .", &target));
std::process::exit(1)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;

#[tokio::test]
async fn test_get_bytecode_when_target_is_address() {
let bytecode = get_bytecode_from_target(
"0x9f00c43700bc0000Ff91bE00841F8e04c0495000",
"https://rpc.ankr.com/eth",
)
.await
.unwrap();

assert!(BYTECODE_REGEX.is_match(&bytecode).unwrap());
}

#[tokio::test]
async fn test_get_bytecode_when_target_is_bytecode() {
let bytecode = get_bytecode_from_target(
"0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001",
"https://rpc.ankr.com/eth",
)
.await
.unwrap();

assert!(BYTECODE_REGEX.is_match(&bytecode).unwrap());
}

#[tokio::test]
async fn test_get_bytecode_when_target_is_file_path() {
let file_path = "./mock-file.txt";
let mock_bytecode = "0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001";

fs::write(file_path, mock_bytecode).unwrap();

let bytecode =
get_bytecode_from_target(file_path, "https://rpc.ankr.com/eth").await.unwrap();

assert!(BYTECODE_REGEX.is_match(&bytecode).unwrap());

fs::remove_file(file_path).unwrap();
}
}
1 change: 1 addition & 0 deletions common/src/ether/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod bytecode;
pub mod compiler;
pub mod evm;
pub mod lexers;
Expand Down
2 changes: 1 addition & 1 deletion common/src/ether/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ pub async fn get_code(
)
.map_err(|_| logger.error(&format!("failed to cache bytecode for contract: {:?}", &contract_address)));

Ok(bytecode_as_bytes.to_string())
Ok(bytecode_as_bytes.to_string().replacen("0x", "", 1))
})
.await
.map_err(|_| Box::from("failed to fetch bytecode"))
Expand Down
34 changes: 33 additions & 1 deletion common/src/ether/selectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,41 @@ use tokio::task;

use crate::utils::{io::logging::Logger, strings::decode_hex};

use super::{evm::core::vm::VM, signatures::ResolveSelector};
use super::{
evm::core::vm::VM,
signatures::{ResolveSelector, ResolvedFunction},
};
use crate::debug_max;

// Find all function selectors and all the data associated to this function, represented by
// [`ResolvedFunction`]
pub async fn get_resolved_selectors(
disassembled_bytecode: &str,
skip_resolving: &bool,
evm: &VM,
) -> Result<
(HashMap<String, u128>, HashMap<String, Vec<ResolvedFunction>>),
Box<dyn std::error::Error>,
> {
let selectors = find_function_selectors(evm, disassembled_bytecode);

let mut resolved_selectors = HashMap::new();
if !skip_resolving {
resolved_selectors =
resolve_selectors::<ResolvedFunction>(selectors.keys().cloned().collect()).await;

debug_max!(&format!(
"resolved {} possible functions from {} detected selectors.",
resolved_selectors.len(),
selectors.len()
));
} else {
debug_max!(&format!("found {} possible function selectors.", selectors.len()));
}

Ok((selectors, resolved_selectors))
}

/// find all function selectors in the given EVM assembly.
pub fn find_function_selectors(evm: &VM, assembly: &str) -> HashMap<String, u128> {
let mut function_selectors = HashMap::new();
Expand Down
33 changes: 33 additions & 0 deletions common/src/utils/io/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,11 +673,33 @@ impl Logger {
}
}

/// Set `RUST_LOG` variable to env if does not exist
///
/// ```
/// use heimdall_common::utils::io::logging::set_logger_env;
///
/// let verbosity = clap_verbosity_flag::Verbosity::new(-1, 0);
/// set_logger_env(&verbosity);
/// ```
pub fn set_logger_env(verbosity: &clap_verbosity_flag::Verbosity) {
let env_not_set = std::env::var("RUST_LOG").is_err();

if env_not_set {
let log_level = match verbosity.log_level() {
Some(level) => level.as_str(),
None => "SILENT",
};

std::env::set_var("RUST_LOG", log_level);
}
}

#[cfg(test)]
mod tests {
use std::time::Instant;

use super::*;
use std::env;

#[test]
fn test_raw_trace() {
Expand Down Expand Up @@ -956,4 +978,15 @@ mod tests {
let (_logger, _) = Logger::new("MAX");
debug_max!("log");
}

#[test]
fn test_set_logger_env_default() {
env::remove_var("RUST_LOG");

let verbosity = clap_verbosity_flag::Verbosity::new(-1, 0);

set_logger_env(&verbosity);

assert_eq!(env::var("RUST_LOG").unwrap(), "SILENT");
}
}
36 changes: 36 additions & 0 deletions common/src/utils/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,26 @@ pub fn classify_token(token: &str) -> TokenType {
TokenType::Function
}

/// Returns a collapsed version of a string if this string is greater than 66 characters in length.
/// The collapsed string consists of the first 66 characters, followed by an ellipsis ("..."), and
/// then the last 16 characters of the original string. ```
/// use heimdall_common::utils::strings::get_shortned_target;
///
/// let long_target = "0".repeat(80);
/// let shortened_target = get_shortned_target(&long_target);
/// ```
pub fn get_shortned_target(target: &str) -> String {
let mut shortened_target = target.to_string();

if shortened_target.len() > 66 {
shortened_target = shortened_target.chars().take(66).collect::<String>() +
"..." +
&shortened_target.chars().skip(shortened_target.len() - 16).collect::<String>();
}

shortened_target
}

#[cfg(test)]
mod tests {
use ethers::types::{I256, U256};
Expand Down Expand Up @@ -691,4 +711,20 @@ mod tests {
assert_eq!(classification, TokenType::Function);
}
}

#[test]
fn test_shorten_long_target() {
let long_target = "0".repeat(80);
let shortened_target = get_shortned_target(&long_target);

assert_eq!(shortened_target.len(), 85);
}

#[test]
fn test_shorten_short_target() {
let short_target = "0".repeat(66);
let shortened_target = get_shortned_target(&short_target);

assert_eq!(shortened_target.len(), 66);
}
}
Loading

0 comments on commit 23addf5

Please sign in to comment.