Skip to content

Commit

Permalink
Replace quoted type variables in signature of HOAS pattern result
Browse files Browse the repository at this point in the history
To be able to construct the lambda returned by the HOAS pattern we need:
first resolve the type variables and then use the result to construct the
signature of the lambdas.

To simplify this transformation, `QuoteMatcher` returns a `Seq[MatchResult]`
instead of an untyped `Tuple` containing `Expr[?]`. The tuple is created
once we have accumulated and processed all extracted values.

Fixes scala#15165
  • Loading branch information
nicolasstucki committed Feb 17, 2023
1 parent d298a3b commit ea2f748
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 39 deletions.
65 changes: 39 additions & 26 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ object QuoteMatcher {

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)

def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Tuple] =
def treeMatch(scrutineeTerm: Tree, patternTerm: Tree)(using Context): Option[Seq[MatchResult]] =
given Env = Map.empty
scrutineeTerm =?= patternTerm

Expand Down Expand Up @@ -203,31 +203,12 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
def hoasClosure = {
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(scrutinee)
TreeOps(body).changeNonLocalOwners(meth)
}
Closure(meth, bodyFn)
}
val env = summon[Env]
val capturedArgs = args.map(_.symbol)
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case _ => notMatched
}

Expand Down Expand Up @@ -452,20 +433,52 @@ object QuoteMatcher {
accumulator.apply(Set.empty, term)
}

enum MatchResult:
case ClosedTree(tree: Tree)
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)

def toExpr(mapTypeHoles: TypeMap)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, SpliceScope.getCurrent)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
def hoasClosure = {
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
Closure(meth, bodyFn)
}
new ExprImpl(hoasClosure, SpliceScope.getCurrent)

/** Result of matching a part of an expression */
private type Matching = Option[Tuple]
private type Matching = Option[Seq[MatchResult]]

private object Matching {

def notMatched: Matching = None

val matched: Matching = Some(Tuple())
val matched: Matching = Some(Seq())

def matched(tree: Tree)(using Context): Matching =
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent)))
Some(Seq(MatchResult.ClosedTree(tree)))

def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): Matching =
Some(Seq(MatchResult.OpenTree(tree, patternTpe, args, env)))

extension (self: Matching)
def asOptionOfTuple: Option[Tuple] = self

/** Concatenates the contents of two successful matchings or return a `notMatched` */
def &&& (that: => Matching): Matching = self match {
Expand Down
33 changes: 20 additions & 13 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
ctx1.gadtState.addToConstraint(typeHoles)
ctx1

val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)

if typeHoles.isEmpty then matchings
else {
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
reflect.TypeReprMethods.asType(tp)
matchings.map { tup =>
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
if fromAboveAnnot then fullBounds.hi else fullBounds.lo

QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
import QuoteMatcher.MatchResult.*
lazy val spliceScope = SpliceScope.getCurrent
val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
val typeHoleMap = new Types.TypeMap {
def apply(tp: Types.Type): Types.Type = tp match
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
case _ => mapOver(tp)
}
val matchedExprs = matchings.map(_.toExpr(typeHoleMap))
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
val results = matchedTypes ++ matchedExprs
Tuple.fromIArray(results.toArray.asInstanceOf[IArray[Object]])
}
}

Expand Down
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165a/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165a/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
16 changes: 16 additions & 0 deletions tests/pos-macros/i15165b/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{
{ (y: α) =>
${
val bound = '{ ${ rest }(y) }
Expr.betaReduce(bound)
}
}.apply($a)
}
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165b/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165c/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165c/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}

0 comments on commit ea2f748

Please sign in to comment.