Skip to content

Commit

Permalink
Partial support for ParserForClass
Browse files Browse the repository at this point in the history
  • Loading branch information
lolgab committed Feb 28, 2022
1 parent 58b6f17 commit 71c0dd1
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 268 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ import acyclic.skipped

import scala.language.experimental.macros

private [mainargs] trait ParserForClassCompanionVersionSpecific {
private[mainargs] trait ParserForClassCompanionVersionSpecific {
def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T]
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ import acyclic.skipped

import scala.language.experimental.macros

private [mainargs] trait ParserForMethodsCompanionVersionSpecific {
private[mainargs] trait ParserForMethodsCompanionVersionSpecific {
def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B]
}
100 changes: 64 additions & 36 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,17 @@ package mainargs
import scala.quoted._

object Macros {
private def mainAnnotation(using Quotes) = quotes.reflect.TypeRepr.of[mainargs.main].typeSymbol
private def argAnnotation(using Quotes) = quotes.reflect.TypeRepr.of[mainargs.arg].typeSymbol
def parserForMethods[B](base: Expr[B])(using Quotes, Type[B]): Expr[ParserForMethods[B]] = {
import quotes.reflect._
val allMethods = TypeRepr.of[B].typeSymbol.memberMethods
val mainAnnotation = TypeRepr.of[mainargs.main].typeSymbol
val argAnnotation = TypeRepr.of[mainargs.arg].typeSymbol
val annotatedMethodsWithMainAnnotations = allMethods.flatMap { methodSymbol =>
methodSymbol.getAnnotation(mainAnnotation).map(methodSymbol -> _)
}.sortBy(_._1.pos.map(_.start))
val mainDatasExprs: Seq[Expr[MainData[Any, B]]] = annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotation) =>
val params = annotatedMethod.paramSymss.headOption.getOrElse(throw new Exception("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(annotatedMethod)
val argSigs = Expr.ofList(params.map { param =>
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val arg = param.getAnnotation(argAnnotation).map(_.asExpr.asInstanceOf[Expr[mainargs.arg]]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
case '[t] =>
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
case Some(v) => '{ Some(((_: B) => $v).asInstanceOf[B => t]) }
case None => '{ None }
}
val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse{
report.error(
s"No mainargs.ArgReader of ${paramTpe.typeSymbol.fullName} found for parameter ${param.name}",
param.pos.get
)
'{ ??? }
}
'{ ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader }).asInstanceOf[mainargs.ArgSig[Any, B]] }
})

val invokeRaw: Expr[(B, Seq[Any]) => Any] = {
def callOf(args: Expr[Seq[Any]]) = call(annotatedMethod, '{ Seq( ${ args }) })
'{ (b: B, params: Seq[Any]) =>
${ callOf('{ params }) }
}
}

'{ MainData.create[Any, B](${ Expr(annotatedMethod.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
}
val mainDatas = Expr.ofList(mainDatasExprs)
val mainDatas = Expr.ofList(annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotationInstance) =>
createMainData[Any, B](annotatedMethod, mainAnnotationInstance)
})

'{
new ParserForMethods[B](
Expand All @@ -53,6 +22,65 @@ object Macros {
}
}

def parserForClass[B](using Quotes, Type[B]): Expr[ParserForClass[B]] = {
import quotes.reflect._
val typeReprOfB = TypeRepr.of[B]
val companionModule = typeReprOfB match {
case TypeRef(a,b) => TermRef(a,b)
}
val typeSymbolOfB = typeReprOfB.typeSymbol
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType
val companionModuleExpr = Ident(companionModule).asExpr
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
report.error(
s"cannot find @main annotation on ${companionModule.name}",
typeSymbolOfB.pos.get
)
???
}
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
companionModuleType match
case '[bCompanion] =>
val mainData = createMainData[B, bCompanion](annotatedMethod, mainAnnotationInstance)
'{
new ParserForClass[B](
ClassMains[B](${ mainData }.asInstanceOf[MainData[B, Any]], () => ${ Ident(companionModule).asExpr })
)
}
}

def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
import quotes.reflect.*
val params = method.paramSymss.headOption.getOrElse(throw new Exception("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(method)
val argSigs = Expr.ofList(params.map { param =>
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val arg = param.getAnnotation(argAnnotation).map(_.asExpr.asInstanceOf[Expr[mainargs.arg]]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
case '[t] =>
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
case Some(v) => '{ Some(((_: B) => $v).asInstanceOf[B => t]) }
case None => '{ None }
}
val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse{
report.error(
s"No mainargs.ArgReader of ###companionModule### found for parameter ${param.name}",
param.pos.get
)
'{ ??? }
}
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader })).asInstanceOf[ArgSig[Any, B]] }
})

val invokeRaw: Expr[(B, Seq[Any]) => T] = {
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) })
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }).asInstanceOf[(B, Seq[Any]) => T] }
}
'{ MainData.create[T, B](${ Expr(method.name) }, ${ annotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
}

/** Call a method given by its symbol.
*
* E.g.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package mainargs
import scala.language.experimental.macros

private [mainargs] trait ParserForClassCompanionVersionSpecific {
inline def apply[T]: ParserForClass[T] = ???
inline def apply[T]: ParserForClass[T] = ${ Macros.parserForClass[T] }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package mainargs

private [mainargs] trait ParserForMethodsCompanionVersionSpecific {
inline def apply[B](base: B) : ParserForMethods[B] = ${ Macros.parserForMethods[B]('base) }
}
}
101 changes: 0 additions & 101 deletions mainargs/test/src-2/ClassTests.scala

This file was deleted.

7 changes: 4 additions & 3 deletions mainargs/test/src-2/OldVarargsTests.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package mainargs
import utest._

object OldVarargsTests extends VarargsTests{
object Base{
object OldVarargsTests extends VarargsTests {
object Base {

@main
def pureVariadic(nums: Int*) = nums.sum

@main
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) = first + args.mkString
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) =
first + args.mkString
}

val check = new Checker(ParserForMethods(Base), allowPositional = true)
Expand Down
53 changes: 0 additions & 53 deletions mainargs/test/src-2/ParserTests.scala

This file was deleted.

5 changes: 5 additions & 0 deletions mainargs/test/src-2/VersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package mainargs

object VersionSpecific {
val isScala3 = false
}
5 changes: 5 additions & 0 deletions mainargs/test/src-3/VersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package mainargs

object VersionSpecific {
val isScala3 = true
}
8 changes: 4 additions & 4 deletions mainargs/test/src-jvm-2/AmmoniteTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object AmmoniteConfig{
@arg(doc = "Print this message")
help: Flag
)
implicit val coreParser: ParserForClass[Core] = ParserForClass[Core]
implicit val coreParser = ParserForClass[Core]

@main
case class Predef(
Expand All @@ -86,7 +86,7 @@ object AmmoniteConfig{
"choose an additional predef to use using `--predef")
noHomePredef: Flag
)
implicit val predefParser: ParserForClass[Predef] = ParserForClass[Predef]
implicit val predefParser = ParserForClass[Predef]

@main
case class Repl(
Expand All @@ -105,12 +105,12 @@ object AmmoniteConfig{
"friendliness.")
classBased: Flag
)
implicit val replParser: ParserForClass[Repl] = ParserForClass[Repl]
implicit val replParser = ParserForClass[Repl]
}


object AmmoniteTests extends TestSuite{
val parser: ParserForClass[AmmoniteConfig] = ParserForClass[AmmoniteConfig]
val parser = ParserForClass[AmmoniteConfig]
val tests = Tests {


Expand Down
2 changes: 0 additions & 2 deletions mainargs/test/src-jvm-2/MillTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// package mainargs
// import utest._


// object MillTests extends TestSuite{

// implicit object PathRead extends TokensReader[os.Path]("path", strs => Right(os.Path(strs.head, os.pwd)))
Expand Down Expand Up @@ -107,4 +106,3 @@
// }
// }
// }

Loading

0 comments on commit 71c0dd1

Please sign in to comment.