Skip to content

Commit

Permalink
surface (fix): Fixes #3355 Surface.methodsOf for methods with multipl…
Browse files Browse the repository at this point in the history
…e argument lists (#3406)
  • Loading branch information
xerial authored Feb 25, 2024
1 parent a2c69a1 commit 186e8a2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package wvlet.airframe.surface
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.immutable.ListMap
import scala.quoted.*
import scala.reflect.ClassTag

private[surface] object CompileTimeSurfaceFactory:

Expand Down Expand Up @@ -505,12 +506,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
val typeArgTable: Map[String, TypeRepr] = typeMappingTable(t, method)

val origParamSymss = method.paramSymss

val declaredTypes = t.typeSymbol.declaredTypes.filterNot(_.flags.is(Flags.Module))
val paramss =
if origParamSymss.nonEmpty && declaredTypes.nonEmpty then origParamSymss.tail
else origParamSymss
val paramss: List[List[Symbol]] = method.paramSymss.filter { lst =>
// Empty arg is allowed
lst.isEmpty ||
// Remove type params or implicit ClassTag evidences as MethodSurface can't pass type parameters
!lst.forall(x => x.isTypeParam || (x.flags.is(Flags.Implicit) && x.typeRef <:< TypeRepr.of[ClassTag[_]]))
}

paramss.map { params =>
params.zipWithIndex
Expand Down Expand Up @@ -750,7 +751,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
val ret = surfaceOf(df.returnTpt.tpe)
// println(s"==== method of: def ${m.name}")
val params = methodParametersOf(targetType, m)
val args = methodArgsOf(targetType, m).flatten
val args = methodArgsOf(targetType, m)
val methodCaller = createMethodCaller(targetType, m, args)
'{
ClassMethodSurface(${ mod }, ${ owner }, ${ name }, ${ ret }, ${ params }.toIndexedSeq, ${ methodCaller })
Expand All @@ -770,7 +771,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
private def createMethodCaller(
objectType: TypeRepr,
m: Symbol,
methodArgs: Seq[MethodArg]
methodArgss: List[List[MethodArg]]
): Expr[Option[(Any, Seq[Any]) => Any]] =
// Build { (x: Any, args: Seq[Any]) => x.asInstanceOf[t].<method>(.. args) }
val methodTypeParams: List[TypeParamClause] = m.tree match
Expand All @@ -788,11 +789,16 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
val x = params(0).asInstanceOf[Term]
val args = params(1).asInstanceOf[Term]
val expr = clsCast(x, objectType).select(m)
val argList = methodArgs.zipWithIndex.collect {
// If the arg is implicit, no need to explicitly bind it
case (arg, i) if !arg.isImplicit =>
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(i)))
clsCast(extracted, arg.tpe)

var index = 0
val argList: List[List[Term]] = methodArgss.map { lst =>
lst.collect {
// If the arg is implicit, no need to explicitly bind it
case arg if !arg.isImplicit =>
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index)))
index += 1
clsCast(extracted, arg.tpe)
}
}
if argList.isEmpty then
val newExpr = m.tree match
Expand All @@ -806,14 +812,14 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q):
else
// Bind to function arguments
val newExpr =
if methodTypeParams.isEmpty then expr.appliedToArgs(argList.toList)
if methodTypeParams.isEmpty then expr.appliedToArgss(argList)
else
// For generic functions, type params also need to be applied
val dummyTypeParams = methodTypeParams.map(x => TypeRepr.of[Any])
// println(s"---> ${m.name} type param count: ${methodTypeParams.size}, arg size: ${argList.size}")
expr
.appliedToTypes(dummyTypeParams)
.appliedToArgs(argList.toList)
.appliedToArgss(argList)
newExpr.changeOwner(sym)
)
'{ Some(${ lambda.asExprOf[(Any, Seq[Any]) => Any] }) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
*/
package wvlet.airframe.surface

object GenericMethodTest {
import wvlet.airspec.AirSpec

object GenericMethodTest extends AirSpec {
class A {
def helloX[X](v: X): String = "hello"
}
}

class GenericMethodTest extends SurfaceSpec {
import GenericMethodTest.*

test("generic method") {
val methods = Surface.methodsOf[A]
assertEquals(methods.size, 1)
methods.size shouldBe 1
val m = methods(0)

val obj = new GenericMethodTest.A
assertEquals(m.call(obj, "dummy"), "hello")
m.call(obj, "dummy") shouldBe "hello"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,20 @@
*/
package wvlet.airframe.surface

object MultipleConstructorArgsTest {
import wvlet.airspec.AirSpec

object MultipleConstructorArgsTest extends AirSpec {
case class MultiC(a: Int)(implicit val s: String) {
def msg: String = s"${a}:${s}"
}
}

import MultipleConstructorArgsTest.*
class MultipleConstructorArgsTest extends SurfaceSpec {
test("support muliple constructor args") {
val s: Surface = Surface.of[MultiC]
assert(s.objectFactory.nonEmpty)
s.objectFactory shouldBe defined

val f = s.objectFactory.get
assert(s.params.size == 2)
s.params.size shouldBe 2
val i = f.newInstance(Seq(1, "hello"))
assert(i.asInstanceOf[MultiC].msg == "1:hello")
i.asInstanceOf[MultiC].msg shouldBe "1:hello"
}
}
26 changes: 26 additions & 0 deletions airframe-surface/src/test/scala/wvlet/airframe/surface/i3355.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.surface

import wvlet.airspec.AirSpec

object i3355 extends AirSpec {
case class ParamLists() {
def multiParam()(): Unit = ()
}

test("find methods of multiple method params") {
Surface.methodsOf[ParamLists]
}
}

0 comments on commit 186e8a2

Please sign in to comment.