From d892f6df8f2b128b66c12726e1f38b4076f5ec57 Mon Sep 17 00:00:00 2001
From: Dan Gohman <dev@sunfishcode.online>
Date: Mon, 9 Nov 2020 06:43:16 -0800
Subject: [PATCH] imp: Make `clap_derive` call `FromStr::from_str` only once
 per argument.

Currently the way `clap_derive` uses a `from_str` function is to call
it once as a validator, discarding the parsed value, and then to call
it again to recompute the parsed value. This is unfortunate in
cases where `from_str` is expensive or has side effects.

This PR changes `clap_derive` to not register `from_str` as a validator
so that it doesn't do the first of these two calls. Then, instead of
doing `unwrap()` on the other call, it handles the error. This eliminates
the redundancy, and also avoids the small performance hit mentioned in
[the documentation about validators].

[the documentation about validators]: https://docs.rs/clap-v3/3.0.0-beta.1/clap_v3/struct.Arg.html#method.validator

This PR doesn't yet use colorized messages for errors generated during
parsing because the `ColorWhen` setting isn't currently available.
That's fixable with some refactoring, but I'm interested in getting
feedback on the overall approach here first.
---
 clap_derive/src/derives/dummies.rs         |   4 +-
 clap_derive/src/derives/from_argmatches.rs | 143 ++++++++++++++-------
 clap_derive/src/derives/into_app.rs        |  32 +----
 clap_derive/src/derives/subcommand.rs      |  18 ++-
 clap_derive/tests/arguments.rs             |   2 +-
 clap_derive/tests/effectful.rs             |  28 ++++
 src/derive.rs                              |  22 +++-
 7 files changed, 159 insertions(+), 90 deletions(-)
 create mode 100644 clap_derive/tests/effectful.rs

diff --git a/clap_derive/src/derives/dummies.rs b/clap_derive/src/derives/dummies.rs
index 582a7c0f401f..0f5303d0adbf 100644
--- a/clap_derive/src/derives/dummies.rs
+++ b/clap_derive/src/derives/dummies.rs
@@ -33,7 +33,7 @@ pub fn into_app(name: &Ident) {
 pub fn from_arg_matches(name: &Ident) {
     append_dummy(quote! {
         impl ::clap::FromArgMatches for #name {
-            fn from_arg_matches(_m: &::clap::ArgMatches) -> Self {
+            fn try_from_arg_matches(_m: &::clap::ArgMatches) -> Result<Self, ::clap::Error> {
                 unimplemented!()
             }
         }
@@ -43,7 +43,7 @@ pub fn from_arg_matches(name: &Ident) {
 pub fn subcommand(name: &Ident) {
     append_dummy(quote! {
         impl ::clap::Subcommand for #name {
-            fn from_subcommand(_name: &str, _matches: Option<&::clap::ArgMatches>) -> Option<Self> {
+            fn from_subcommand(_name: &str, _matches: Option<&::clap::ArgMatches>) -> Result<Self, ::clap::Error> {
                 unimplemented!()
             }
             fn augment_subcommands(_app: ::clap::App<'_>) -> ::clap::App<'_> {
diff --git a/clap_derive/src/derives/from_argmatches.rs b/clap_derive/src/derives/from_argmatches.rs
index 40113029cee0..3bc6bf4f663b 100644
--- a/clap_derive/src/derives/from_argmatches.rs
+++ b/clap_derive/src/derives/from_argmatches.rs
@@ -37,8 +37,8 @@ pub fn gen_for_struct(
         )]
         #[deny(clippy::correctness)]
         impl ::clap::FromArgMatches for #struct_name {
-            fn from_arg_matches(matches: &::clap::ArgMatches) -> Self {
-                #struct_name #constructor
+            fn try_from_arg_matches(matches: &::clap::ArgMatches) -> Result<Self, ::clap::Error> {
+                Ok(#struct_name #constructor)
             }
         }
     }
@@ -59,10 +59,9 @@ pub fn gen_for_enum(name: &syn::Ident) -> proc_macro2::TokenStream {
         )]
         #[deny(clippy::correctness)]
         impl ::clap::FromArgMatches for #name {
-            fn from_arg_matches(matches: &::clap::ArgMatches) -> Self {
+            fn try_from_arg_matches(matches: &::clap::ArgMatches) -> Result<Self, ::clap::Error> {
                 let (name, subcmd) = matches.subcommand();
                 <#name as ::clap::Subcommand>::from_subcommand(name, subcmd)
-                    .unwrap()
             }
         }
     }
@@ -86,24 +85,39 @@ pub fn gen_constructor(
                     (Ty::Option, Some(sub_type)) => sub_type,
                     _ => &field.ty,
                 };
-                let unwrapper = match **ty {
-                    Ty::Option => quote!(),
-                    _ => quote_spanned!( ty.span()=> .unwrap() ),
-                };
-                quote_spanned! { kind.span()=>
-                    #field_name: {
-                        let (name, subcmd) = matches.subcommand();
-                        <#subcmd_type as ::clap::Subcommand>::from_subcommand(
-                            name,
-                            subcmd
-                        )
-                        #unwrapper
+                match **ty {
+                    Ty::Option => {
+                        quote_spanned! { kind.span()=>
+                            #field_name: {
+                                let (name, subcmd) = matches.subcommand();
+                                match <#subcmd_type as ::clap::Subcommand>::from_subcommand(
+                                    name,
+                                    subcmd
+                                ) {
+                                    Ok(cmd) => Some(cmd),
+                                    Err(::clap::Error {
+                                        kind: ::clap::ErrorKind::UnrecognizedSubcommand,
+                                        ..
+                                    }) => None,
+                                    Err(e) => return Err(e),
+                                }
+                            }
+                        }
                     }
+                    _ => quote_spanned! { kind.span()=>
+                        #field_name: {
+                            let (name, subcmd) = matches.subcommand();
+                            <#subcmd_type as ::clap::Subcommand>::from_subcommand(
+                                name,
+                                subcmd
+                            )?
+                        }
+                    },
                 }
             }
 
             Kind::Flatten => quote_spanned! { kind.span()=>
-                #field_name: ::clap::FromArgMatches::from_arg_matches(matches)
+                #field_name: ::clap::FromArgMatches::try_from_arg_matches(matches)?
             },
 
             Kind::Skip(val) => match val {
@@ -126,7 +140,7 @@ pub fn gen_constructor(
                     TryFromStr => (
                         quote_spanned!(span=> value_of),
                         quote_spanned!(span=> values_of),
-                        quote_spanned!(func.span()=> |s| #func(s).unwrap()),
+                        quote_spanned!(func.span()=> #func),
                     ),
                     FromOsStr => (
                         quote_spanned!(span=> value_of_os),
@@ -136,7 +150,7 @@ pub fn gen_constructor(
                     TryFromOsStr => (
                         quote_spanned!(span=> value_of_os),
                         quote_spanned!(span=> values_of_os),
-                        quote_spanned!(func.span()=> |s| #func(s).unwrap()),
+                        quote_spanned!(func.span()=> #func),
                     ),
                     FromOccurrences => (
                         quote_spanned!(span=> occurrences_of),
@@ -149,39 +163,74 @@ pub fn gen_constructor(
                 let flag = *attrs.parser().kind == ParserKind::FromFlag;
                 let occurrences = *attrs.parser().kind == ParserKind::FromOccurrences;
                 let name = attrs.cased_name();
+                let map_err = match *parser.kind {
+                    TryFromStr | TryFromOsStr => quote_spanned! { ty.span() =>
+                        .map_err(|e| ::clap::Error::with_description(
+                                format!("Invalid value for '<{}>': {}", #name, e),
+                                ::clap::ErrorKind::InvalidValue))?
+                    },
+                    _ => quote!(),
+                };
                 let field_value = match **ty {
                     Ty::Bool => quote_spanned! { ty.span()=>
                         matches.is_present(#name)
                     },
 
-                    Ty::Option => quote_spanned! { ty.span()=>
-                        matches.#value_of(#name)
-                            .map(#parse)
-                    },
+                    Ty::Option => {
+                        quote_spanned! { ty.span()=>
+                            match matches.#value_of(#name) {
+                                Some(value) => Some(#parse(value)#map_err),
+                                None => None
+                            }
+                        }
+                    }
 
-                    Ty::OptionOption => quote_spanned! { ty.span()=>
-                        if matches.is_present(#name) {
-                            Some(matches.#value_of(#name).map(#parse))
-                        } else {
-                            None
+                    Ty::OptionOption => {
+                        quote_spanned! { ty.span()=>
+                            if matches.is_present(#name) {
+                                Some(match matches.#value_of(#name) {
+                                    Some(value) => Some(#parse(value)#map_err),
+                                    None => None
+                                })
+                            } else {
+                                None
+                            }
                         }
-                    },
+                    }
 
-                    Ty::OptionVec => quote_spanned! { ty.span()=>
-                        if matches.is_present(#name) {
-                            Some(matches.#values_of(#name)
-                                 .map(|v| v.map(#parse).collect())
-                                 .unwrap_or_else(Vec::new))
-                        } else {
-                            None
+                    Ty::OptionVec => {
+                        quote_spanned! { ty.span()=>
+                            if matches.is_present(#name) {
+                                Some(match matches.#values_of(#name) {
+                                    Some(values) => {
+                                        let mut parsed = Vec::with_capacity(values.len());
+                                        for v in values {
+                                            parsed.push(#parse(v)#map_err);
+                                        }
+                                        parsed
+                                    }
+                                    None => Vec::new()
+                                })
+                            } else {
+                                None
+                            }
                         }
-                    },
+                    }
 
-                    Ty::Vec => quote_spanned! { ty.span()=>
-                        matches.#values_of(#name)
-                            .map(|v| v.map(#parse).collect())
-                            .unwrap_or_else(Vec::new)
-                    },
+                    Ty::Vec => {
+                        quote_spanned! { ty.span()=>
+                            match matches.#values_of(#name) {
+                                Some(values) => {
+                                    let mut parsed = Vec::with_capacity(values.len());
+                                    for v in values {
+                                        parsed.push(#parse(v)#map_err)
+                                    }
+                                    parsed
+                                }
+                                None => Vec::new()
+                            }
+                        }
+                    }
 
                     Ty::Other if occurrences => quote_spanned! { ty.span()=>
                         #parse(matches.#value_of(#name))
@@ -191,11 +240,11 @@ pub fn gen_constructor(
                         #parse(matches.is_present(#name))
                     },
 
-                    Ty::Other => quote_spanned! { ty.span()=>
-                        matches.#value_of(#name)
-                            .map(#parse)
-                            .unwrap()
-                    },
+                    Ty::Other => {
+                        quote_spanned! {
+                            ty.span()=> #parse(matches.#value_of(#name).unwrap())#map_err
+                        }
+                    }
                 };
 
                 quote_spanned!(field.span()=> #field_name: #field_value )
diff --git a/clap_derive/src/derives/into_app.rs b/clap_derive/src/derives/into_app.rs
index da4c8c1afc93..787b2e2dc937 100644
--- a/clap_derive/src/derives/into_app.rs
+++ b/clap_derive/src/derives/into_app.rs
@@ -182,39 +182,15 @@ pub fn gen_app_augmentation(
                 })
             }
             Kind::Arg(ty) => {
-                let convert_type = match **ty {
-                    Ty::Vec | Ty::Option => sub_type(&field.ty).unwrap_or(&field.ty),
-                    Ty::OptionOption | Ty::OptionVec => {
-                        sub_type(&field.ty).and_then(sub_type).unwrap_or(&field.ty)
-                    }
-                    _ => &field.ty,
-                };
-
-                let occurrences = *attrs.parser().kind == ParserKind::FromOccurrences;
-                let flag = *attrs.parser().kind == ParserKind::FromFlag;
-
                 let parser = attrs.parser();
-                let func = &parser.func;
-                let validator = match *parser.kind {
-                    ParserKind::TryFromStr => quote_spanned! { func.span()=>
-                        .validator(|s| {
-                            #func(s.as_str())
-                            .map(|_: #convert_type| ())
-                            .map_err(|e| e.to_string())
-                        })
-                    },
-                    ParserKind::TryFromOsStr => quote_spanned! { func.span()=>
-                        .validator_os(|s| #func(&s).map(|_: #convert_type| ()))
-                    },
-                    _ => quote!(),
-                };
+                let occurrences = parser.kind == ParserKind::FromOccurrences;
+                let flag = parser.kind == ParserKind::FromFlag;
 
                 let modifier = match **ty {
                     Ty::Bool => quote!(),
 
                     Ty::Option => quote_spanned! { ty.span()=>
                         .takes_value(true)
-                        #validator
                     },
 
                     Ty::OptionOption => quote_spanned! { ty.span()=>
@@ -222,20 +198,17 @@ pub fn gen_app_augmentation(
                         .multiple(false)
                         .min_values(0)
                         .max_values(1)
-                        #validator
                     },
 
                     Ty::OptionVec => quote_spanned! { ty.span()=>
                         .takes_value(true)
                         .multiple(true)
                         .min_values(0)
-                        #validator
                     },
 
                     Ty::Vec => quote_spanned! { ty.span()=>
                         .takes_value(true)
                         .multiple(true)
-                        #validator
                     },
 
                     Ty::Other if occurrences => quote_spanned! { ty.span()=>
@@ -252,7 +225,6 @@ pub fn gen_app_augmentation(
                         quote_spanned! { ty.span()=>
                             .takes_value(true)
                             .required(#required)
-                            #validator
                         }
                     }
                 };
diff --git a/clap_derive/src/derives/subcommand.rs b/clap_derive/src/derives/subcommand.rs
index 7b04eb7a0795..7998a8ccdb29 100644
--- a/clap_derive/src/derives/subcommand.rs
+++ b/clap_derive/src/derives/subcommand.rs
@@ -155,7 +155,7 @@ fn gen_from_subcommand(
 
         quote! {
             (#sub_name, Some(matches)) => {
-                Some(#name :: #variant_name #constructor_block)
+                Ok(#name :: #variant_name #constructor_block)
             }
         }
     });
@@ -165,8 +165,13 @@ fn gen_from_subcommand(
             Unnamed(ref fields) if fields.unnamed.len() == 1 => {
                 let ty = &fields.unnamed[0];
                 quote! {
-                    if let Some(res) = <#ty as ::clap::Subcommand>::from_subcommand(other.0, other.1) {
-                        return Some(#name :: #variant_name (res));
+                    match <#ty as ::clap::Subcommand>::from_subcommand(other.0, other.1) {
+                        Ok(res) => return Ok(#name :: #variant_name (res)),
+                        Err(::clap::Error {
+                            kind: ::clap::ErrorKind::UnrecognizedSubcommand,
+                            ..
+                        }) => {}
+                        Err(e) => return Err(e),
                     }
                 }
             }
@@ -180,13 +185,16 @@ fn gen_from_subcommand(
     quote! {
         fn from_subcommand<'b>(
             name: &'b str,
-            sub: Option<&'b ::clap::ArgMatches>) -> Option<Self>
+            sub: Option<&'b ::clap::ArgMatches>) -> Result<Self, ::clap::Error>
         {
             match (name, sub) {
                 #( #match_arms ),*,
                 other => {
                     #( #child_subcommands )*;
-                    None
+                    Err(::clap::Error::with_description(
+                        format!("The subcommand '{}' wasn't recognized", name),
+                        ::clap::ErrorKind::UnrecognizedSubcommand
+                    ))
                 }
             }
         }
diff --git a/clap_derive/tests/arguments.rs b/clap_derive/tests/arguments.rs
index e9740a4037c3..b7d97c564d0f 100644
--- a/clap_derive/tests/arguments.rs
+++ b/clap_derive/tests/arguments.rs
@@ -79,7 +79,7 @@ fn arguments_safe() {
     );
 
     assert_eq!(
-        clap::ErrorKind::ValueValidation,
+        clap::ErrorKind::InvalidValue,
         Opt::try_parse_from(&["test", "NOPE"]).err().unwrap().kind
     );
 }
diff --git a/clap_derive/tests/effectful.rs b/clap_derive/tests/effectful.rs
new file mode 100644
index 000000000000..90262f2a27f2
--- /dev/null
+++ b/clap_derive/tests/effectful.rs
@@ -0,0 +1,28 @@
+use clap::Clap;
+use std::str::FromStr;
+use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
+
+static NUM_CALLS: AtomicU32 = AtomicU32::new(0);
+
+#[derive(Debug)]
+struct Effectful {}
+
+impl FromStr for Effectful {
+    type Err = String;
+
+    fn from_str(_s: &str) -> Result<Self, Self::Err> {
+        NUM_CALLS.fetch_add(1, SeqCst);
+        Ok(Self {})
+    }
+}
+
+#[derive(Clap, Debug)]
+struct Opt {
+    effectful: Effectful,
+}
+
+#[test]
+fn effectful() {
+    let _opt = Opt::parse_from(&["test", "arg"]);
+    assert_eq!(NUM_CALLS.load(SeqCst), 1);
+}
diff --git a/src/derive.rs b/src/derive.rs
index eba8c1e4b06f..7f6af53e9cc6 100644
--- a/src/derive.rs
+++ b/src/derive.rs
@@ -14,7 +14,7 @@ pub trait Clap: FromArgMatches + IntoApp + Sized {
     /// Parse from `std::env::args()`, return Err on error.
     fn try_parse() -> Result<Self, Error> {
         let matches = <Self as IntoApp>::into_app().try_get_matches()?;
-        Ok(<Self as FromArgMatches>::from_arg_matches(&matches))
+        <Self as FromArgMatches>::try_from_arg_matches(&matches)
     }
 
     /// Parse from iterator, exit on error
@@ -36,7 +36,7 @@ pub trait Clap: FromArgMatches + IntoApp + Sized {
         T: Into<OsString> + Clone,
     {
         let matches = <Self as IntoApp>::into_app().try_get_matches_from(itr)?;
-        Ok(<Self as FromArgMatches>::from_arg_matches(&matches))
+        <Self as FromArgMatches>::try_from_arg_matches(&matches)
     }
 }
 
@@ -53,13 +53,21 @@ pub trait IntoApp: Sized {
 /// Extract values from ArgMatches into the struct.
 pub trait FromArgMatches: Sized {
     /// @TODO @release @docs
-    fn from_arg_matches(matches: &ArgMatches) -> Self;
+    fn from_arg_matches(matches: &ArgMatches) -> Self {
+        match Self::try_from_arg_matches(matches) {
+            Ok(me) => me,
+            Err(e) => e.exit(),
+        }
+    }
+
+    /// @TODO @release @docs
+    fn try_from_arg_matches(matches: &ArgMatches) -> Result<Self, Error>;
 }
 
 /// @TODO @release @docs
 pub trait Subcommand: Sized {
     /// @TODO @release @docs
-    fn from_subcommand(name: &str, matches: Option<&ArgMatches>) -> Option<Self>;
+    fn from_subcommand(name: &str, matches: Option<&ArgMatches>) -> Result<Self, Error>;
     /// @TODO @release @docs
     fn augment_subcommands(app: App<'_>) -> App<'_>;
 }
@@ -108,10 +116,14 @@ impl<T: FromArgMatches> FromArgMatches for Box<T> {
     fn from_arg_matches(matches: &ArgMatches) -> Self {
         Box::new(<T as FromArgMatches>::from_arg_matches(matches))
     }
+
+    fn try_from_arg_matches(matches: &ArgMatches) -> Result<Self, Error> {
+        Ok(Box::new(<T as FromArgMatches>::try_from_arg_matches(matches)?))
+    }
 }
 
 impl<T: Subcommand> Subcommand for Box<T> {
-    fn from_subcommand(name: &str, matches: Option<&ArgMatches>) -> Option<Self> {
+    fn from_subcommand(name: &str, matches: Option<&ArgMatches>) -> Result<Self, Error> {
         <T as Subcommand>::from_subcommand(name, matches).map(Box::new)
     }
     fn augment_subcommands(app: App<'_>) -> App<'_> {