Skip to content

Commit

Permalink
add support for intersection/union types with type parameters from me…
Browse files Browse the repository at this point in the history
…thod definition (#515)
  • Loading branch information
goshacodes authored Apr 6, 2024
1 parent 02c41bd commit 65df9e6
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 2 deletions.
19 changes: 17 additions & 2 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.scalamock.clazz

import scala.quoted.*
import org.scalamock.context.MockContext

import scala.annotation.{experimental, tailrec}
private[clazz] class Utils(using val quotes: Quotes):
Expand Down Expand Up @@ -53,6 +52,22 @@ private[clazz] class Utils(using val quotes: Quotes):
case _ =>
tpe

def resolveAndOrTypeParamRefs: TypeRepr =
tpe match {
case AndType(left: ParamRef, right: ParamRef) =>
TypeRepr.of[Any]
case AndType(left: ParamRef, right) =>
right.resolveAndOrTypeParamRefs
case AndType(left, right: ParamRef) =>
left.resolveAndOrTypeParamRefs
case OrType(_: ParamRef, _) =>
TypeRepr.of[Any]
case OrType(_, _: ParamRef) =>
TypeRepr.of[Any]
case other =>
other
}

@experimental
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
tpe match
Expand Down Expand Up @@ -117,7 +132,7 @@ private[clazz] class Utils(using val quotes: Quotes):
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
.map { typeRepr =>
val adjusted =
typeRepr.widen.mapParamRefWithWildcard match
typeRepr.widen.mapParamRefWithWildcard.resolveAndOrTypeParamRefs match
case TypeBounds(lower, upper) => upper
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
Expand Down
167 changes: 167 additions & 0 deletions shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,173 @@ class Scala3Spec extends AnyFunSpec with MockFactory with Matchers {
m.method(1, new A with B) shouldBe 0
}

it("mock intersection type with type parameter from trait") {

trait B

trait C

trait TraitWithGenericIntersection[A] {
def methodWithGenericIntersection(x: A & B): Unit
}

val m = mock[TraitWithGenericIntersection[C]]

val obj = new B with C {}

(m.methodWithGenericIntersection _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with left type parameter from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A](x: A & B): Unit

def methodWithGenericUnion[A](x: A | B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new C with B {}

(m.methodWithGenericIntersection[C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with right type parameter from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A](x: B & A): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C {}

(m.methodWithGenericIntersection[C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with both type parameters from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B](x: A & B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C {}

(m.methodWithGenericIntersection[B, C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}


it("mock intersection type with more then two types from method") {

trait B

trait C

trait D

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B, C](x: A & B & C): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C with D {}

(m.methodWithGenericIntersection[B, C, D] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with more then two types from method, one of witch is stable") {

trait B

trait C

trait D

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B](x: A & D & B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C with D {}

(m.methodWithGenericIntersection[B, C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock union type with left type parameter from method") {

trait B

trait C

trait TraitWithGenericUnion {

def methodWithGenericUnion[A](x: A | B): Unit
}

val m = mock[TraitWithGenericUnion]

val obj1 = new C {}
val obj2 = new B {}

(m.methodWithGenericUnion[C] _).expects(obj1).returns(())
(m.methodWithGenericUnion[C] _).expects(obj2).returns(())

m.methodWithGenericUnion(obj1)
m.methodWithGenericUnion(obj2)
}

it("mock union type with right type parameter from method") {

trait B

trait C

trait TraitWithGenericUnion {

def methodWithGenericUnion[A](x: B | A): Unit
}

val m = mock[TraitWithGenericUnion]

val obj1 = new C {}
val obj2 = new B {}

(m.methodWithGenericUnion[C] _).expects(obj1).returns(())
(m.methodWithGenericUnion[C] _).expects(obj2).returns(())

m.methodWithGenericUnion(obj1)
m.methodWithGenericUnion(obj2)
}

it("mock methods returning function") {
trait Test {
def method(x: Int): Int => String
Expand Down

0 comments on commit 65df9e6

Please sign in to comment.