From 186e8a2136868343ce03be2a17d4685ea3ff8e3c Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Sun, 25 Feb 2024 00:43:33 -0800 Subject: [PATCH] surface (fix): Fixes #3355 Surface.methodsOf for methods with multiple argument lists (#3406) --- .../surface/CompileTimeSurfaceFactory.scala | 36 +++++++++++-------- .../airframe/surface/GenericMethodTest.scala | 12 +++---- .../surface/MultipleConstructorArgsTest.scala | 13 ++++--- .../scala/wvlet/airframe/surface/i3355.scala | 26 ++++++++++++++ 4 files changed, 58 insertions(+), 29 deletions(-) create mode 100644 airframe-surface/src/test/scala/wvlet/airframe/surface/i3355.scala diff --git a/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala b/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala index 0184f77a8c..ef61cfcc88 100644 --- a/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala +++ b/airframe-surface/src/main/scala-3/wvlet/airframe/surface/CompileTimeSurfaceFactory.scala @@ -2,6 +2,7 @@ package wvlet.airframe.surface import java.util.concurrent.atomic.AtomicInteger import scala.collection.immutable.ListMap import scala.quoted.* +import scala.reflect.ClassTag private[surface] object CompileTimeSurfaceFactory: @@ -505,12 +506,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q): // Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr) val typeArgTable: Map[String, TypeRepr] = typeMappingTable(t, method) - val origParamSymss = method.paramSymss - - val declaredTypes = t.typeSymbol.declaredTypes.filterNot(_.flags.is(Flags.Module)) - val paramss = - if origParamSymss.nonEmpty && declaredTypes.nonEmpty then origParamSymss.tail - else origParamSymss + val paramss: List[List[Symbol]] = method.paramSymss.filter { lst => + // Empty arg is allowed + lst.isEmpty || + // Remove type params or implicit ClassTag evidences as MethodSurface can't pass type parameters + !lst.forall(x => x.isTypeParam || (x.flags.is(Flags.Implicit) && x.typeRef <:< TypeRepr.of[ClassTag[_]])) + } paramss.map { params => params.zipWithIndex @@ -750,7 +751,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q): val ret = surfaceOf(df.returnTpt.tpe) // println(s"==== method of: def ${m.name}") val params = methodParametersOf(targetType, m) - val args = methodArgsOf(targetType, m).flatten + val args = methodArgsOf(targetType, m) val methodCaller = createMethodCaller(targetType, m, args) '{ ClassMethodSurface(${ mod }, ${ owner }, ${ name }, ${ ret }, ${ params }.toIndexedSeq, ${ methodCaller }) @@ -770,7 +771,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q): private def createMethodCaller( objectType: TypeRepr, m: Symbol, - methodArgs: Seq[MethodArg] + methodArgss: List[List[MethodArg]] ): Expr[Option[(Any, Seq[Any]) => Any]] = // Build { (x: Any, args: Seq[Any]) => x.asInstanceOf[t].(.. args) } val methodTypeParams: List[TypeParamClause] = m.tree match @@ -788,11 +789,16 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q): val x = params(0).asInstanceOf[Term] val args = params(1).asInstanceOf[Term] val expr = clsCast(x, objectType).select(m) - val argList = methodArgs.zipWithIndex.collect { - // If the arg is implicit, no need to explicitly bind it - case (arg, i) if !arg.isImplicit => - val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(i))) - clsCast(extracted, arg.tpe) + + var index = 0 + val argList: List[List[Term]] = methodArgss.map { lst => + lst.collect { + // If the arg is implicit, no need to explicitly bind it + case arg if !arg.isImplicit => + val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index))) + index += 1 + clsCast(extracted, arg.tpe) + } } if argList.isEmpty then val newExpr = m.tree match @@ -806,14 +812,14 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q): else // Bind to function arguments val newExpr = - if methodTypeParams.isEmpty then expr.appliedToArgs(argList.toList) + if methodTypeParams.isEmpty then expr.appliedToArgss(argList) else // For generic functions, type params also need to be applied val dummyTypeParams = methodTypeParams.map(x => TypeRepr.of[Any]) // println(s"---> ${m.name} type param count: ${methodTypeParams.size}, arg size: ${argList.size}") expr .appliedToTypes(dummyTypeParams) - .appliedToArgs(argList.toList) + .appliedToArgss(argList) newExpr.changeOwner(sym) ) '{ Some(${ lambda.asExprOf[(Any, Seq[Any]) => Any] }) } diff --git a/airframe-surface/src/test/scala/wvlet/airframe/surface/GenericMethodTest.scala b/airframe-surface/src/test/scala/wvlet/airframe/surface/GenericMethodTest.scala index baa61d8b15..68e0d3606d 100644 --- a/airframe-surface/src/test/scala/wvlet/airframe/surface/GenericMethodTest.scala +++ b/airframe-surface/src/test/scala/wvlet/airframe/surface/GenericMethodTest.scala @@ -13,22 +13,20 @@ */ package wvlet.airframe.surface -object GenericMethodTest { +import wvlet.airspec.AirSpec + +object GenericMethodTest extends AirSpec { class A { def helloX[X](v: X): String = "hello" } -} - -class GenericMethodTest extends SurfaceSpec { - import GenericMethodTest.* test("generic method") { val methods = Surface.methodsOf[A] - assertEquals(methods.size, 1) + methods.size shouldBe 1 val m = methods(0) val obj = new GenericMethodTest.A - assertEquals(m.call(obj, "dummy"), "hello") + m.call(obj, "dummy") shouldBe "hello" } } diff --git a/airframe-surface/src/test/scala/wvlet/airframe/surface/MultipleConstructorArgsTest.scala b/airframe-surface/src/test/scala/wvlet/airframe/surface/MultipleConstructorArgsTest.scala index 013b1a7739..7aa0f2c7d9 100644 --- a/airframe-surface/src/test/scala/wvlet/airframe/surface/MultipleConstructorArgsTest.scala +++ b/airframe-surface/src/test/scala/wvlet/airframe/surface/MultipleConstructorArgsTest.scala @@ -13,21 +13,20 @@ */ package wvlet.airframe.surface -object MultipleConstructorArgsTest { +import wvlet.airspec.AirSpec + +object MultipleConstructorArgsTest extends AirSpec { case class MultiC(a: Int)(implicit val s: String) { def msg: String = s"${a}:${s}" } -} -import MultipleConstructorArgsTest.* -class MultipleConstructorArgsTest extends SurfaceSpec { test("support muliple constructor args") { val s: Surface = Surface.of[MultiC] - assert(s.objectFactory.nonEmpty) + s.objectFactory shouldBe defined val f = s.objectFactory.get - assert(s.params.size == 2) + s.params.size shouldBe 2 val i = f.newInstance(Seq(1, "hello")) - assert(i.asInstanceOf[MultiC].msg == "1:hello") + i.asInstanceOf[MultiC].msg shouldBe "1:hello" } } diff --git a/airframe-surface/src/test/scala/wvlet/airframe/surface/i3355.scala b/airframe-surface/src/test/scala/wvlet/airframe/surface/i3355.scala new file mode 100644 index 0000000000..3576173679 --- /dev/null +++ b/airframe-surface/src/test/scala/wvlet/airframe/surface/i3355.scala @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.surface + +import wvlet.airspec.AirSpec + +object i3355 extends AirSpec { + case class ParamLists() { + def multiParam()(): Unit = () + } + + test("find methods of multiple method params") { + Surface.methodsOf[ParamLists] + } +}