Skip to content

Commit

Permalink
Fix error reporting of incomplete short args (#162)
Browse files Browse the repository at this point in the history
This was accidentally broken in
#102 and
#112, and wasn't covered by
tests. Noticed when trying to update Ammonite to the latest version of
MainArgs in com-lihaoyi/Ammonite#1549

Restored the special casing for tracking/handling incomplete arguments
and added some unit test cases
  • Loading branch information
lihaoyi authored Sep 13, 2024
1 parent 80ef899 commit 936d89e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
10 changes: 7 additions & 3 deletions mainargs/src/TokenGrouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ object TokenGrouping {
var i = 0
var currentMap = current
var failure = false
var incomplete: Option[ArgSig] = None

while (i < chars.length) {
val c = chars(i)
Expand All @@ -79,6 +80,7 @@ object TokenGrouping {
rest2 match {
case Nil =>
// If there is no next token, it is an error
incomplete = Some(a)
failure = true
case next :: remaining =>
currentMap = Util.appendMap(currentMap, a, next)
Expand All @@ -95,7 +97,7 @@ object TokenGrouping {

}

if (failure) None else Some((rest2, currentMap))
if (failure) Left(incomplete) else Right((rest2, currentMap))
}

def lookupArgMap(k: String, m: Map[String, ArgSig]): Option[(ArgSig, mainargs.TokensReader[_])] = {
Expand All @@ -111,8 +113,10 @@ object TokenGrouping {
// special handling for combined short args of the style "-xvf" or "-j10"
if (head.startsWith("-") && head.lift(1).exists(c => c != '-')){
parseCombinedShortTokens(current, head, rest) match{
case None => complete(remaining, current)
case Some((rest2, currentMap)) => rec(rest2, currentMap)
case Left(Some(incompleteArg)) =>
Result.Failure.MismatchedArguments(Nil, Nil, Nil, incomplete = Some(incompleteArg))
case Left(None) => complete(remaining, current)
case Right((rest2, currentMap)) => rec(rest2, currentMap)
}

} else if (head.startsWith("-") && head.exists(_ != '-')) {
Expand Down
45 changes: 38 additions & 7 deletions mainargs/test/src/CoreTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ object CoreBase {
i: Int,
@arg(doc = "Pass in a custom `s` to override it")
s: String = "lols"
) = s * i
)
= s * i
@main
def baz(arg: Int) = arg

@main
def ex() = throw MyException

Expand Down Expand Up @@ -44,7 +48,10 @@ class CoreTests(allowPositional: Boolean) extends TestSuite {
| qux
| Qux is a function that does stuff
| -i <int>
| -s <str> Pass in a custom `s` to override it
| -s <str> Pass in a custom `s` to override it
|
| baz
| --arg <int>
|
| ex
|""".stripMargin
Expand All @@ -56,27 +63,31 @@ class CoreTests(allowPositional: Boolean) extends TestSuite {

assert(
names ==
List("foo", "bar", "qux", "ex")
List("foo", "bar", "qux", "baz", "ex")
)
val evaledArgs = check.mains.value.map(_.flattenedArgSigs.map {
case (ArgSig(name, s, docs, None, parser, _, _), _) => (s, docs, None, parser)
case (ArgSig(name, s, docs, None, parser, _, _), _) => (name, s, docs, None, parser)
case (ArgSig(name, s, docs, Some(default), parser, _, _), _) =>
(s, docs, Some(default(CoreBase)), parser)
(name, s, docs, Some(default(CoreBase)), parser)
})

assert(
evaledArgs == List(
List(),
List((Some('i'), None, None, TokensReader.IntRead)),
List((None, Some('i'), None, None, TokensReader.IntRead)),
List(
(Some('i'), None, None, TokensReader.IntRead),
(None, Some('i'), None, None, TokensReader.IntRead),
(
None,
Some('s'),
Some("Pass in a custom `s` to override it"),
Some("lols"),
TokensReader.StringRead
)
),
List(
(Some("arg"), None, None, None, TokensReader.IntRead),
),
List()
)
)
Expand Down Expand Up @@ -127,6 +138,26 @@ class CoreTests(allowPositional: Boolean) extends TestSuite {
None
) =>
}
test("incomplete") {
// Make sure both long args and short args properly report
// incomplete arguments as distinct from other mismatches
test - assertMatch(check.parseInvoke(List("qux", "-s"))) {
case Result.Failure.MismatchedArguments(
Nil,
Nil,
Nil,
Some(_)
) =>
}
test - assertMatch(check.parseInvoke(List("baz", "--arg"))) {
case Result.Failure.MismatchedArguments(
Nil,
Nil,
Nil,
Some(_)
) =>
}
}
}

test("tooManyParams") - check(
Expand Down
6 changes: 3 additions & 3 deletions mainargs/test/src/FlagTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ object FlagTests extends TestSuite {
test - check(
List("bool", "-ab"),
Result.Failure.MismatchedArguments(
Vector(new ArgSig(None, Some('b'), None, None, TokensReader.BooleanRead, false, false)),
List("-ab"),
Nil,
None
Nil,
Nil,
Some(new ArgSig(None, Some('b'), None, None, TokensReader.BooleanRead, false, false))
)
)

Expand Down

0 comments on commit 936d89e

Please sign in to comment.