Skip to content

Commit

Permalink
Partial Support for Scala 3
Browse files Browse the repository at this point in the history
This implements the support for the `@main` annotated methods in Scala 3.
While `ParserForMethods` is implemented and passes all tests,
`ParserForClass` is not implemented yet.
  • Loading branch information
lolgab committed Feb 27, 2022
1 parent b1bdb8b commit 2c910b4
Show file tree
Hide file tree
Showing 20 changed files with 275 additions and 85 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ name: ci

on:
push:
branches:
- master
tags:
- '*'
pull_request:
branches:
- master
Expand Down
53 changes: 30 additions & 23 deletions build.sc
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import mill._, scalalib._, scalajslib._, scalanativelib._, publish._
import mill.scalalib.api.Util.isScala3
import scalalib._
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version_mill0.9:0.1.1`
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.1.1`
import de.tobiasroeser.mill.vcs.version.VcsVersion
import $ivy.`com.github.lolgab::mill-mima_mill0.9:0.0.4`
import $ivy.`com.github.lolgab::mill-mima::0.0.4`
import com.github.lolgab.mill.mima._

val scala212 = "2.12.13"
val scala213 = "2.13.4"
val scala3 = "3.0.2"

val scalaVersions = Seq(scala212, scala213, scala3)
val scala2Versions = scalaVersions.filter(_.startsWith("2."))

val scalaJSVersions = for {
scalaV <- Seq(scala213, scala212)
scalaV <- scalaVersions
scalaJSV <- Seq("1.4.0")
} yield (scalaV, scalaJSV)

val scalaNativeVersions = for {
scalaV <- Seq(scala213, scala212)
scalaV <- scala2Versions
scalaNativeV <- Seq("0.4.0")
} yield (scalaV, scalaNativeV)

Expand All @@ -39,48 +44,50 @@ trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mim
)
)

def scalacOptions = super.scalacOptions() ++ Seq("-P:acyclic:force")
def scalacOptions = super.scalacOptions() ++ (if (!isScala3(crossScalaVersion)) Seq("-P:acyclic:force") else Seq.empty)

def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++ Agg(ivy"com.lihaoyi::acyclic:0.2.0")
def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++ (if (!isScala3(crossScalaVersion)) Agg(ivy"com.lihaoyi::acyclic:0.2.0") else Agg.empty)

def compileIvyDeps = super.compileIvyDeps() ++ Agg(
ivy"com.lihaoyi::acyclic:0.2.0",
ivy"org.scala-lang:scala-reflect:$crossScalaVersion"
)
def compileIvyDeps = super.compileIvyDeps() ++ (if (!isScala3(crossScalaVersion)) Agg(
ivy"com.lihaoyi::acyclic:0.2.0",
ivy"org.scala-lang:scala-reflect:$crossScalaVersion"
) else Agg.empty)

def ivyDeps = Agg(
ivy"org.scala-lang.modules::scala-collection-compat::2.4.0"
)
ivy"org.scala-lang.modules::scala-collection-compat::2.4.4"
) ++ Agg(ivy"com.lihaoyi::pprint:0.6.6")
}

trait Common extends CrossScalaModule {
def millSourcePath = build.millSourcePath / "mainargs"
def sources = T.sources(
millSourcePath / "src",
millSourcePath / s"src-$platform"
super.sources() ++ Seq(PathRef(millSourcePath / s"src-$platform"))
)
def platform: String
}

trait CommonTestModule extends ScalaModule with TestModule {
def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.7.6")
def testFrameworks = Seq("utest.runner.Framework")
def sources = T.sources(
millSourcePath / "src",
millSourcePath / s"src-$platform"
)
trait CommonTestModule extends ScalaModule with TestModule.Utest {
def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.7.10")
def sources = T.sources {
val scalaMajor = if(isScala3(scalaVersion())) "3" else "2"
super.sources() ++ Seq(
millSourcePath / "src",
millSourcePath / s"src-$platform",
millSourcePath / s"src-$platform-$scalaMajor"
).map(PathRef(_))
}
def platform: String
}


object mainargs extends Module {
object jvm extends Cross[JvmMainArgsModule](scala212, scala213)
object jvm extends Cross[JvmMainArgsModule](scalaVersions: _*)
class JvmMainArgsModule(val crossScalaVersion: String)
extends Common with ScalaModule with MainArgsPublishModule {
def platform = "jvm"
object test extends Tests with CommonTestModule{
def platform = "jvm"
def ivyDeps = super.ivyDeps() ++ Agg(ivy"com.lihaoyi::os-lib:0.7.1")
def ivyDeps = super.ivyDeps() ++ Agg(ivy"com.lihaoyi::os-lib:0.7.8")
}
}

Expand Down
File renamed without changes.
9 changes: 9 additions & 0 deletions mainargs/src-2/ParserForClassCompanionVersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package mainargs

import acyclic.skipped

import scala.language.experimental.macros

private [mainargs] trait ParserForClassCompanionVersionSpecific {
def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T]
}
9 changes: 9 additions & 0 deletions mainargs/src-2/ParserForMethodsCompanionVersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package mainargs

import acyclic.skipped

import scala.language.experimental.macros

private [mainargs] trait ParserForMethodsCompanionVersionSpecific {
def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B]
}
142 changes: 142 additions & 0 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package mainargs

import scala.quoted._

object Macros {
def parserForMethods[B](base: Expr[B])(using Quotes, Type[B]): Expr[ParserForMethods[B]] = {
import quotes.reflect._
val allMethods = TypeRepr.of[B].typeSymbol.memberMethods
val mainAnnotation = TypeRepr.of[mainargs.main].typeSymbol
val argAnnotation = TypeRepr.of[mainargs.arg].typeSymbol
val annotatedMethodsWithMainAnnotations = allMethods.flatMap { methodSymbol =>
methodSymbol.getAnnotation(mainAnnotation).map(methodSymbol -> _)
}.sortBy(_._1.pos.map(_.start))
val mainDatasExprs: Seq[Expr[MainData[Any, B]]] = annotatedMethodsWithMainAnnotations.map { (annotatedMethod, mainAnnotation) =>
val doc: Option[String] = mainAnnotation match {
case Apply(_, args) => args.collectFirst {
case NamedArg("doc", Literal(constant)) => constant.value.asInstanceOf[String]
}
case _ => None
}
val params = annotatedMethod.paramSymss.headOption.getOrElse(throw new Exception("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(annotatedMethod)
val argSigs = Expr.ofList(params.map { param =>
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val arg = param.getAnnotation(argAnnotation).map(_.asExpr.asInstanceOf[Expr[mainargs.arg]]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
case '[t] =>
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
case Some(v) => '{ Some(((_: B) => $v).asInstanceOf[B => t]) }
case None => '{ None }
}
val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse{
report.error(
s"No mainargs.ArgReader of ${paramTpe.typeSymbol.fullName} found for parameter ${param.name}",
param.pos.get
)
'{ ??? }
}
'{ ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader }) }
case '[mainargs.Leftover[t]] =>
val argReader = Expr.summon[ArgReader.Leftover[t]].getOrElse{
report.error(
s"No mainargs.ArgReader of ${paramTpe.typeSymbol.fullName} found for parameter ${param.name}",
param.pos.get
)
'{ ??? }
}
'{ ArgSig.createVararg[t, B](${ Expr(param.name) }, ${ arg })(using ${ argReader }) }
})

val invokeRaw: Expr[(B, Seq[Any]) => Any] = {
def callOf(args: Expr[Seq[Any]]) = call(annotatedMethod, '{ Seq( ${ args }) })
'{ (b: B, params: Seq[Any]) =>
${ callOf('{ params }) }
}
}

'{ MainData[Any, B](${ Expr(annotatedMethod.name) }, ${ argSigs }, ${Expr(doc)}, ${ invokeRaw }) }
}
val mainDatas = Expr.ofList(mainDatasExprs)

'{
new ParserForMethods[B](
MethodMains[B](${ mainDatas }, () => ${ base })
)
}
}

/** Call a method given by its symbol.
*
* E.g.
*
* assuming:
*
* def foo(x: Int, y: String)(z: Int)
*
* val argss: List[List[Any]] = ???
*
* then:
*
* call(<symbol of foo>, '{argss})
*
* will expand to:
*
* foo(argss(0)(0), argss(0)(1))(argss(1)(0))
*
*/
private def call(using Quotes)(
method: quotes.reflect.Symbol,
argss: Expr[Seq[Seq[Any]]]
): Expr[_] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
import quotes.reflect._
val paramss = method.paramSymss

if (paramss.isEmpty) {
report.error("At least one parameter list must be declared.", method.pos.get)
return '{???}
}

val fct = Ref(method)

val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
for (j <- paramss(i).indices.toList) yield {
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
tpe.asType match
case '[t] => '{ $argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t] }.asTerm
}
}

val base = Apply(fct, accesses.head)
val application: Apply = accesses.tail.foldLeft(base)((lhs, args) => Apply(lhs, args))
val expr = application.asExpr
expr
}


/** Lookup default values for a method's parameters. */
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
import quotes.reflect._

val params = method.paramSymss.flatten
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]

val Name = (method.name + """\$default\$(\d+)""").r

val idents = method.owner.tree.asInstanceOf[ClassDef].body
idents.foreach{
case deff @ DefDef(Name(idx), _, _, _) =>
val expr = Ref(deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)
case _ =>
}

defaults.toMap
}
}
7 changes: 7 additions & 0 deletions mainargs/src-3/ParserForClassCompanionVersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package mainargs

import scala.language.experimental.macros

private [mainargs] trait ParserForClassCompanionVersionSpecific {
inline def apply[T]: ParserForClass[T] = ???
}
5 changes: 5 additions & 0 deletions mainargs/src-3/ParserForMethodsCompanionVersionSpecific.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package mainargs

private [mainargs] trait ParserForMethodsCompanionVersionSpecific {
inline def apply[B](base: B) : ParserForMethods[B] = ${ Macros.parserForMethods[B]('base) }
}
3 changes: 3 additions & 0 deletions mainargs/src-3/acyclic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package acyclic

def skipped = ???
12 changes: 6 additions & 6 deletions mainargs/src/Parser.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mainargs

import acyclic.skipped

import scala.language.experimental.macros
import java.io.PrintStream
object ParserForMethods{
def apply[B](base: B): ParserForMethods[B] = macro Macros.parserForMethods[B]
}

object ParserForMethods extends ParserForMethodsCompanionVersionSpecific
class ParserForMethods[B](val mains: MethodMains[B]){
def helpText(totalWidth: Int = 100,
docsOnNewLine: Boolean = false,
Expand Down Expand Up @@ -102,9 +104,7 @@ class ParserForMethods[B](val mains: MethodMains[B]){
}
}

object ParserForClass{
def apply[T]: ParserForClass[T] = macro Macros.parserForClass[T]
}
object ParserForClass extends ParserForClassCompanionVersionSpecific
class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T]{
def helpText(totalWidth: Int = 100,
docsOnNewLine: Boolean = false,
Expand Down
6 changes: 3 additions & 3 deletions mainargs/src/TokensReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object TokensReader{
implicit object FloatRead extends TokensReader[Float]("float", strs => tryEither(strs.last.toFloat))
implicit object DoubleRead extends TokensReader[Double]("double", strs => tryEither(strs.last.toDouble))

implicit def OptionRead[T: TokensReader] = new TokensReader[Option[T]](
implicit def OptionRead[T: TokensReader]: TokensReader[Option[T]] = new TokensReader[Option[T]](
implicitly[TokensReader[T]].shortName,
strs => {
strs.lastOption match{
Expand All @@ -31,7 +31,7 @@ object TokensReader{
},
allowEmpty = true
)
implicit def SeqRead[C[_] <: Iterable[_], T: TokensReader](implicit factory: Factory[T, C[T]]) = new TokensReader[C[T]](
implicit def SeqRead[C[_] <: Iterable[_], T: TokensReader](implicit factory: Factory[T, C[T]]): TokensReader[C[T]] = new TokensReader[C[T]](
implicitly[TokensReader[T]].shortName,
strs => {
strs
Expand All @@ -50,7 +50,7 @@ object TokensReader{
alwaysRepeatable = true,
allowEmpty = true
)
implicit def MapRead[K: TokensReader, V: TokensReader] = new TokensReader[Map[K, V]](
implicit def MapRead[K: TokensReader, V: TokensReader]: TokensReader[Map[K, V]] = new TokensReader[Map[K, V]](
"k=v",
strs => {
strs.foldLeft[Either[String, Map[K, V]]](Right(Map())){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ object ClassTests extends TestSuite{
@main
case class Qux(moo: String, b: Bar)

implicit val fooParser = ParserForClass[Foo]
implicit val barParser = ParserForClass[Bar]
implicit val quxParser = ParserForClass[Qux]
implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar]
implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux]

object Main{
@main
def run(bar: Bar,
bool: Boolean = false) = {
bar.w.value + " " + bar.f.x + " " + bar.f.y + " " + bar.zzzz + " " + bool
s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool"
}
}

Expand Down Expand Up @@ -99,4 +99,3 @@ object ClassTests extends TestSuite{
}
}
}

File renamed without changes.
16 changes: 16 additions & 0 deletions mainargs/test/src-2/OldVarargsTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package mainargs
import utest._

object OldVarargsTests extends VarargsTests{
object Base{

@main
def pureVariadic(nums: Int*) = nums.sum

@main
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) = first + args.mkString
}

val check = new Checker(ParserForMethods(Base), allowPositional = true)
val isNewVarargsTests = false
}
File renamed without changes.
Loading

0 comments on commit 2c910b4

Please sign in to comment.