Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

feat(abigen): support overloaded functions with different casing #650

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ethers-contract/ethers-contract-abigen/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use ethers_core::{
macros::{ethers_contract_crate, ethers_core_crate, ethers_providers_crate},
};

use crate::contract::methods::MethodAlias;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::quote;
use serde::Deserialize;
Expand Down Expand Up @@ -78,7 +79,7 @@ pub struct Context {
contract_name: Ident,

/// Manually specified method aliases.
method_aliases: BTreeMap<String, Ident>,
method_aliases: BTreeMap<String, MethodAlias>,

/// Derives added to event structs and enums.
event_derives: Vec<Path>,
Expand Down Expand Up @@ -204,7 +205,11 @@ impl Context {
// method will be re-defined.
let mut method_aliases = BTreeMap::new();
for (signature, alias) in args.method_aliases.into_iter() {
let alias = syn::parse_str(&alias)?;
let alias = MethodAlias {
function_name: util::safe_ident(&alias),
struct_name: util::safe_pascal_case_ident(&alias),
};

if method_aliases.insert(signature.clone(), alias).is_some() {
return Err(anyhow!("duplicate method signature '{}' in method aliases", signature,))
}
Expand Down
122 changes: 84 additions & 38 deletions ethers-contract/ethers-contract-abigen/src/contract/methods.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{btree_map::Entry, BTreeMap};
use std::collections::{btree_map::Entry, BTreeMap, HashMap};

use anyhow::{Context as _, Result};
use inflector::Inflector;
Expand Down Expand Up @@ -47,7 +47,7 @@ impl Context {
fn expand_call_struct(
&self,
function: &Function,
alias: Option<&Ident>,
alias: Option<&MethodAlias>,
) -> Result<TokenStream> {
let call_name = expand_call_struct_name(function, alias);
let fields = self.expand_input_pairs(function)?;
Expand Down Expand Up @@ -82,7 +82,7 @@ impl Context {
}

/// Expands all structs
fn expand_call_structs(&self, aliases: BTreeMap<String, Ident>) -> Result<TokenStream> {
fn expand_call_structs(&self, aliases: BTreeMap<String, MethodAlias>) -> Result<TokenStream> {
let mut struct_defs = Vec::new();
let mut struct_names = Vec::new();
let mut variant_names = Vec::new();
Expand Down Expand Up @@ -236,7 +236,11 @@ impl Context {
}

/// Expands a single function with the given alias
fn expand_function(&self, function: &Function, alias: Option<Ident>) -> Result<TokenStream> {
fn expand_function(
&self,
function: &Function,
alias: Option<MethodAlias>,
) -> Result<TokenStream> {
let name = expand_function_name(function, alias.as_ref());
let selector = expand_selector(function.selector());

Expand Down Expand Up @@ -275,10 +279,21 @@ impl Context {
// The first function or the function with the least amount of arguments should
// be named as in the ABI, the following functions suffixed with _with_ +
// additional_params[0].name + (_and_(additional_params[1+i].name))*
fn get_method_aliases(&self) -> Result<BTreeMap<String, Ident>> {
fn get_method_aliases(&self) -> Result<BTreeMap<String, MethodAlias>> {
let mut aliases = self.method_aliases.clone();

// it might be the case that there are functions with different capitalization so we sort
// them all by lc name first
let mut all_functions = HashMap::new();
for function in self.abi.functions() {
all_functions
.entry(function.name.to_lowercase())
.or_insert_with(Vec::new)
.push(function);
}

// find all duplicates, where no aliases where provided
for functions in self.abi.functions.values() {
for functions in all_functions.values() {
if functions.iter().filter(|f| !aliases.contains_key(&f.abi_signature())).count() <= 1 {
// no overloads, hence no conflicts
continue
Expand Down Expand Up @@ -318,7 +333,7 @@ impl Context {
let mut diffs = Vec::new();

/// helper function that checks if there are any conflicts due to parameter names
fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &Function)]) -> bool {
fn name_conflicts(idx: usize, diffs: &[(usize, Vec<&Param>, &&Function)]) -> bool {
let diff = &diffs.iter().find(|(i, _, _)| *i == idx).expect("diff exists").1;

for (_, other, _) in diffs.iter().filter(|(i, _, _)| *i != idx) {
Expand All @@ -333,7 +348,6 @@ impl Context {
}
false
}

// compare each overloaded function with the `first_fun`
for (idx, overloaded_fun) in functions.into_iter().skip(1) {
// attempt to find diff in the input arguments
Expand All @@ -357,12 +371,36 @@ impl Context {
for (idx, diff, overloaded_fun) in &diffs {
let alias = match diff.len() {
0 => {
// this should not happen since functions with same name and inputs are
// illegal
anyhow::bail!(
"Function with same name and parameter types defined twice: {}",
overloaded_fun.name
);
// this may happen if there are functions with different casing,
// like `INDEX`and `index`
if overloaded_fun.name != first_fun.name {
let overloaded_id = overloaded_fun.name.to_snake_case();
let first_fun_id = first_fun.name.to_snake_case();
if first_fun_id != overloaded_id {
// no conflict
overloaded_id
} else {
let overloaded_alias = MethodAlias {
function_name: util::safe_ident(&overloaded_fun.name),
struct_name: util::safe_ident(&overloaded_fun.name),
};
aliases.insert(overloaded_fun.abi_signature(), overloaded_alias);

let first_fun_alias = MethodAlias {
function_name: util::safe_ident(&first_fun.name),
struct_name: util::safe_ident(&first_fun.name),
};
aliases.insert(first_fun.abi_signature(), first_fun_alias);
continue
}
} else {
// this should not happen since functions with same name and inputs are
// illegal
anyhow::bail!(
"Function with same name and parameter types defined twice: {}",
overloaded_fun.name
);
}
}
1 => {
// single additional input params
Expand Down Expand Up @@ -404,13 +442,17 @@ impl Context {
}
}
};
aliases.insert(overloaded_fun.abi_signature(), util::safe_ident(&alias));
let alias = MethodAlias::new(&alias);
aliases.insert(overloaded_fun.abi_signature(), alias);
}

if needs_alias_for_first_fun_using_idx {
// insert an alias for the root duplicated call
let prev_alias = format!("{}{}", first_fun.name.to_snake_case(), first_fun_idx);
aliases.insert(first_fun.abi_signature(), util::safe_ident(&prev_alias));

let alias = MethodAlias::new(&prev_alias);

aliases.insert(first_fun.abi_signature(), alias);
}
}

Expand All @@ -426,7 +468,7 @@ impl Context {
for function in functions {
if let Entry::Vacant(entry) = aliases.entry(function.abi_signature()) {
// use the full name as alias
entry.insert(util::ident(name.as_str()));
entry.insert(MethodAlias::new(name.as_str()));
}
}
}
Expand Down Expand Up @@ -455,43 +497,47 @@ fn expand_selector(selector: Selector) -> TokenStream {
quote! { [#( #bytes ),*] }
}

fn expand_function_name(function: &Function, alias: Option<&Ident>) -> Ident {
/// Represents the aliases to use when generating method related elements
#[derive(Debug, Clone)]
pub struct MethodAlias {
pub function_name: Ident,
pub struct_name: Ident,
}

impl MethodAlias {
pub fn new(alias: &str) -> Self {
MethodAlias {
function_name: util::safe_snake_case_ident(alias),
struct_name: util::safe_pascal_case_ident(alias),
}
}
}

fn expand_function_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
if let Some(alias) = alias {
// snake_case strips leading and trailing underscores so we simply add them back if the
// alias starts/ends with underscores
let alias = alias.to_string();
let ident = alias.to_snake_case();
util::ident(&util::preserve_underscore_delim(&ident, &alias))
alias.function_name.clone()
} else {
util::safe_ident(&function.name.to_snake_case())
}
}

/// Expands to the name of the call struct
fn expand_call_struct_name(function: &Function, alias: Option<&Ident>) -> Ident {
fn expand_call_struct_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
let name = if let Some(alias) = alias {
// pascal_case strips leading and trailing underscores so we simply add them back if the
// alias starts/ends with underscores
let alias = alias.to_string();
let ident = alias.to_pascal_case();
let alias = util::preserve_underscore_delim(&ident, &alias);
format!("{}Call", alias)
format!("{}Call", alias.struct_name)
} else {
format!("{}Call", function.name.to_pascal_case())
};
util::ident(&name)
}

/// Expands to the name of the call struct
fn expand_call_struct_variant_name(function: &Function, alias: Option<&Ident>) -> Ident {
let name = if let Some(alias) = alias {
let alias = alias.to_string();
let ident = alias.to_pascal_case();
util::preserve_underscore_delim(&ident, &alias)
fn expand_call_struct_variant_name(function: &Function, alias: Option<&MethodAlias>) -> Ident {
if let Some(alias) = alias {
alias.struct_name.clone()
} else {
function.name.to_pascal_case()
};
util::ident(&name)
util::safe_ident(&function.name.to_pascal_case())
}
}

/// Expands to the tuple struct definition
Expand Down
14 changes: 13 additions & 1 deletion ethers-contract/ethers-contract-abigen/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use quote::quote;

use syn::{Ident as SynIdent, Path};

/// Expands a identifier string into an token.
/// Expands a identifier string into a token.
pub fn ident(name: &str) -> Ident {
Ident::new(name, Span::call_site())
}
Expand All @@ -22,6 +22,18 @@ pub fn safe_ident(name: &str) -> Ident {
syn::parse_str::<SynIdent>(name).unwrap_or_else(|_| ident(&format!("{}_", name)))
}

/// Expands an identifier as snakecase and preserve any leading or trailing underscores
pub fn safe_snake_case_ident(name: &str) -> Ident {
let i = name.to_snake_case();
ident(&preserve_underscore_delim(&i, name))
}

/// Expands an identifier as pascal case and preserve any leading or trailing underscores
pub fn safe_pascal_case_ident(name: &str) -> Ident {
let i = name.to_pascal_case();
ident(&preserve_underscore_delim(&i, name))
}

/// Reapplies leading and trailing underscore chars to the ident
/// Example `ident = "pascalCase"; alias = __pascalcase__` -> `__pascalCase__`
pub fn preserve_underscore_delim(ident: &str, alias: &str) -> String {
Expand Down
17 changes: 17 additions & 0 deletions ethers-contract/tests/abigen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,20 @@ fn can_generate_nested_types() {
let decoded_call = MyfunCall::decode(encoded_call.as_ref()).unwrap();
assert_eq!(call, decoded_call);
}

#[test]
fn can_handle_case_sensitive_calls() {
abigen!(
StakedOHM,
r#"[
index()
INDEX()
]"#,
);

let (client, _mock) = Provider::mocked();
let contract = StakedOHM::new(Address::default(), Arc::new(client));

let _ = contract.index();
let _ = contract.INDEX();
}