Skip to content

Commit

Permalink
[Issue-#89] Fix call of main methods (refer to the actual method owne…
Browse files Browse the repository at this point in the history
…r not just the method) (#142)

Fixes #89
  • Loading branch information
arainko authored Jul 26, 2024
1 parent 6634d17 commit 9660dd3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
21 changes: 10 additions & 11 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,11 @@ object Macros {
val argSigs = Expr.ofList(argSigsExprs)

val invokeRaw: Expr[(B, Seq[Any]) => T] = {
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }).asExprOf[T]
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }) }

def callOf(methodOwner: Expr[Any], args: Expr[Seq[Any]]) =
call(methodOwner, method, '{ Seq($args) }).asExprOf[T]

'{ (b: B, params: Seq[Any]) => ${ callOf('b, 'params) } }
}
'{ MainData.create[T, B](${ Expr(method.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
}
Expand All @@ -115,8 +118,9 @@ object Macros {
*
*/
private def call(using Quotes)(
method: quotes.reflect.Symbol,
argss: Expr[Seq[Seq[Any]]]
methodOwner: Expr[Any],
method: quotes.reflect.Symbol,
argss: Expr[Seq[Seq[Any]]]
): Expr[_] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
Expand All @@ -127,8 +131,6 @@ object Macros {
report.throwError("At least one parameter list must be declared.", method.pos.get)
}

val fct = Ref(method)

val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
for (j <- paramss(i).indices.toList) yield {
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
Expand All @@ -137,12 +139,9 @@ object Macros {
}
}

val base = Apply(fct, accesses.head)
val application: Apply = accesses.tail.foldLeft(base)((lhs, args) => Apply(lhs, args))
val expr = application.asExpr
expr
methodOwner.asTerm.select(method).appliedToArgss(accesses).asExpr
}


/** Lookup default values for a method's parameters. */
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
Expand Down
26 changes: 26 additions & 0 deletions mainargs/test/src/MultiTraitTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package mainargs
import utest._

trait CommandList {
@main
def list(@arg v: String): String = v
}

trait CommandCopy {
@main
def copy(@arg from: String, @arg to: String): (String, String) = (from, to)
}

object Joined extends CommandCopy with CommandList {
@main
def test(@arg from: String, @arg to: String): (String, String) = (from, to)
}

object MultiTraitTests extends TestSuite {
val check = new Checker(ParserForMethods(Joined), allowPositional = true)
val tests = Tests {
test - check(List("copy", "fromArg", "toArg"), Result.Success(("fromArg", "toArg")))
test - check(List("test", "fromArg", "toArg"), Result.Success(("fromArg", "toArg")))
test - check(List("list", "vArg"), Result.Success("vArg"))
}
}

0 comments on commit 9660dd3

Please sign in to comment.