diff --git a/src/derives/arg_enum.rs b/src/derives/arg_enum.rs index fb94afb..cc9e205 100644 --- a/src/derives/arg_enum.rs +++ b/src/derives/arg_enum.rs @@ -14,33 +14,33 @@ use syn::punctuated; use syn::token; pub fn derive_arg_enum(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { - unimplemented!() + let from_str_block = impl_from_str(ast); + let variants_block = impl_variants(ast); - // let from_str_block = impl_from_str(ast)?; - // let variants_block = impl_variants(ast)?; - - // quote! { - // #from_str_block - // #variants_block - // } + quote! { + #from_str_block + #variants_block + } } -/* fn impl_from_str(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { let ident = &ast.ident; - let is_case_sensitive = ast.attrs.iter().any(|v| v.name() == "case_sensitive"); - let variants = variants(ast)?; + let is_case_sensitive = ast + .attrs + .iter() + .any(|v| v.path.segments.iter().any(|s| s.ident == "case_sensitive")); + let variants = variants(ast); let strings = variants .iter() - .map(|ref variant| String::from(variant.ident.as_ref())) + .map(|variant| variant.ident.to_string()) .collect::>(); // All of these need to be iterators. let ident_slice = [ident.clone()]; let idents = ident_slice.iter().cycle(); - let for_error_message = strings.clone(); + let for_error_message = strings.join(", "); let condition_function_slice = [match is_case_sensitive { true => quote! { str::eq }, @@ -48,47 +48,50 @@ fn impl_from_str(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { }]; let condition_function = condition_function_slice.iter().cycle(); - Ok(quote! { + quote! { impl ::std::str::FromStr for #ident { type Err = String; fn from_str(input: &str) -> ::std::result::Result { match input { #(val if #condition_function(val, #strings) => Ok(#idents::#variants),)* - _ => Err({ - let v = #for_error_message; - format!("valid values: {}", - v.join(" ,")) - }), + _ => Err( + format!("valid values: {}", #for_error_message) + ), } } } - }) + + impl ::std::convert::Into<&'static str> for #ident { + fn into(self) -> &'static str { + match self { + #(val => stringify!(#variants),)* + } + } + } + } } +// See if we can return an array instead of a vec fn impl_variants(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { let ident = &ast.ident; - let variants = variants(ast)? - .iter() - .map(|ref variant| String::from(variant.ident.as_ref())) - .collect::>(); - let length = variants.len(); + let local_variants = variants(ast); - Ok(quote! { + quote! { impl #ident { - fn variants() -> [&'static str; #length] { - #variants + fn variants() -> ::std::vec::Vec<#ident> { + use #ident::*; + vec![#local_variants] } } - }) + } } fn variants(ast: &syn::DeriveInput) -> &punctuated::Punctuated { use syn::Data::*; match ast.data { - Enum(ref data) => data.variants, + Enum(ref data) => &data.variants, _ => panic!("Only enums are supported for deriving the ArgEnum trait"), } } -*/ diff --git a/src/derives/attrs.rs b/src/derives/attrs.rs index 2944c85..7ed7408 100644 --- a/src/derives/attrs.rs +++ b/src/derives/attrs.rs @@ -366,7 +366,11 @@ impl Attrs { pub fn methods(&self) -> proc_macro2::TokenStream { let methods = self.methods.iter().map(|&Method { ref name, ref args }| { let name = syn::Ident::new(&name, proc_macro2::Span::call_site()); - quote!( .#name(#args) ) + if name == "short" { + quote!( .#name(#args.chars().nth(0).unwrap()) ) + } else { + quote!( .#name(#args) ) + } }); quote!( #(#methods)* ) } diff --git a/tests/arg_enum_basic.rs b/tests/arg_enum_basic.rs index 53d8ac3..64b8997 100644 --- a/tests/arg_enum_basic.rs +++ b/tests/arg_enum_basic.rs @@ -14,7 +14,7 @@ extern crate clap_derive; use clap::{App, Arg}; -#[derive(ArgEnum, Debug, PartialEq)] +#[derive(ArgEnum, Debug, PartialEq, Copy, Clone)] enum ArgChoice { Foo, Bar, @@ -28,15 +28,24 @@ fn when_lowercase() { Arg::with_name("arg") .required(true) .takes_value(true) - .possible_values(&ArgChoice::variants()), - ) - .get_matches_from_safe(vec!["", "foo"]) + .possible_values(ArgChoice::variants()), + ).get_matches_from_safe(vec!["", "foo"]) .unwrap(); let t = value_t!(matches.value_of("arg"), ArgChoice); assert!(t.is_ok()); assert_eq!(t.unwrap(), ArgChoice::Foo); } +#[test] +fn when_lowercase_derive() { + #[derive(Clap)] + struct Opt { + choice: ArgChoice, + } + + assert_eq!(Opt::parse_from(&["opt", "foo"]).choice, ArgChoice::Foo); +} + #[test] fn when_capitalized() { let matches = App::new(env!("CARGO_PKG_NAME")) @@ -44,9 +53,8 @@ fn when_capitalized() { Arg::with_name("arg") .required(true) .takes_value(true) - .possible_values(&ArgChoice::variants()), - ) - .get_matches_from_safe(vec!["", "Foo"]) + .possible_values(ArgChoice::variants()), + ).get_matches_from_safe(vec!["", "Foo"]) .unwrap(); let t = value_t!(matches.value_of("arg"), ArgChoice); assert!(t.is_ok());