Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put hierarchy checks for test trait behind an overridable def #2876

Merged
merged 5 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package mill.scalajslib

import mill._
import mill.define.Discover
import mill.scalalib.TestModule
import mill.util.TestUtil
Expand All @@ -12,8 +11,10 @@ object ScalaTestsErrorTests extends TestSuite {
def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???)
def scalaJSVersion = sys.props.getOrElse("TEST_SCALAJS_VERSION", ???)
object test extends ScalaTests with TestModule.Utest
object testDisabledError extends ScalaTests with TestModule.Utest {
override def hierarchyChecks(): Unit = {}
}
}

override lazy val millDiscover = Discover[this.type]
}

Expand All @@ -24,8 +25,12 @@ object ScalaTestsErrorTests extends TestSuite {
}
val message = error.getCause.getMessage
assert(
message == s"scalaTestsError is a `ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`."
message == s"scalaTestsError is a `mill.scalajslib.ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`."
)
}
test("extends-ScalaTests-disabled-hierarchy-check") {
// expect no throws exception
ScalaTestsError.scalaTestsError.testDisabledError
}
}
}
33 changes: 32 additions & 1 deletion scalalib/src/mill/scalalib/JavaModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import coursier.parse.ModuleParser
import coursier.util.ModuleMatcher
import mainargs.Flag
import mill.Agg
import mill.api.{Ctx, JarManifest, PathRef, Result, internal}
import mill.api.{Ctx, JarManifest, MillException, PathRef, Result, internal}
import mill.define.{Command, ModuleRef, Segment, Task, TaskModule}
import mill.scalalib.internal.ModuleUtils
import mill.scalalib.api.CompilationResult
Expand All @@ -34,6 +34,9 @@ trait JavaModule
def zincWorker: ModuleRef[ZincWorkerModule] = ModuleRef(mill.scalalib.ZincWorkerModule)

trait JavaModuleTests extends JavaModule with TestModule {
// Run some consistence checks
hierarchyChecks()

override def moduleDeps: Seq[JavaModule] = Seq(outer)
override def repositoriesTask: Task[Seq[Repository]] = outer.repositoriesTask
override def resolutionCustomizer: Task[Option[coursier.Resolution => coursier.Resolution]] =
Expand All @@ -47,6 +50,34 @@ trait JavaModule
PathRef(this.millSourcePath / src.path.relativeTo(outer.millSourcePath))
}
}

/**
* JavaModule and its derivates define inner test modules.
* To avoid unexpected misbehavior due to the use of the wrong inner test trait
* we apply some hierarchy consistency checks.
* If for some reasons, those are too restrictive to you, you can override this method.
* @throws MillException
*/
protected def hierarchyChecks(): Unit = {
val outerInnerSets = Seq(
("mill.scalajslib.ScalaJSModule", "ScalaJSTests"),
("mill.scalanativelib.ScalaNativeModule", "ScalaNativeTests"),
("mill.scalalib.SbtModule", "SbtModuleTests"),
("mill.scalalib.MavenModule", "MavenModuleTests")
)
for {
(mod, testModShort) <- outerInnerSets
testMod = s"${mod}$$${testModShort}"
}
try {
if (Class.forName(mod).isInstance(outer) && !Class.forName(testMod).isInstance(this))
throw new MillException(
s"$outer is a `${mod}`. $this needs to extend `${testModShort}`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
}
}
}

def defaultCommandName(): String = "run"
Expand Down
33 changes: 1 addition & 32 deletions scalalib/src/mill/scalalib/ScalaModule.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
package mill
package scalalib

import mill.api.{
DummyInputStream,
JarManifest,
MillException,
PathRef,
Result,
SystemStreams,
internal
}
import mill.api.{DummyInputStream, JarManifest, PathRef, Result, SystemStreams, internal}
import mill.main.BuildInfo
import mill.util.{Jvm, Util}
import mill.util.Jvm.createJar
Expand All @@ -28,29 +20,6 @@ trait ScalaModule extends JavaModule with TestModule.ScalaModuleBase { outer =>
type ScalaModuleTests = ScalaTests

trait ScalaTests extends JavaModuleTests with ScalaModule {
try {
if (
Class.forName("mill.scalajslib.ScalaJSModule").isInstance(outer) && !Class.forName(
"mill.scalajslib.ScalaJSModule$ScalaJSTests"
).isInstance(this)
) throw new MillException(
s"$outer is a `ScalaJSModule`. $this needs to extend `ScalaJSTests`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule
}
try {
if (
Class.forName("mill.scalanativelib.ScalaNativeModule").isInstance(outer) && !Class.forName(
"mill.scalanativelib.ScalaNativeModule$ScalaNativeTests"
).isInstance(this)
) throw new MillException(
s"$outer is a `ScalaNativeModule`. $this needs to extend `ScalaNativeTests`."
)
} catch {
case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaNativeModule
}

override def scalaOrganization: Target[String] = outer.scalaOrganization()
override def scalaVersion: Target[String] = outer.scalaVersion()
override def scalacPluginIvyDeps: Target[Agg[Dep]] = outer.scalacPluginIvyDeps()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ object ScalaTestsErrorTests extends TestSuite {
def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???)
def scalaNativeVersion = sys.props.getOrElse("TEST_SCALANATIVE_VERSION", ???)
object test extends ScalaTests with TestModule.Utest
object testDisabledError extends ScalaTests with TestModule.Utest {
override def hierarchyChecks(): Unit = {}
}
}

override lazy val millDiscover = Discover[this.type]
Expand All @@ -24,8 +27,12 @@ object ScalaTestsErrorTests extends TestSuite {
}
val message = error.getCause.getMessage
assert(
message == s"scalaTestsError is a `ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`."
message == s"scalaTestsError is a `mill.scalanativelib.ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`."
)
}
test("extends-ScalaTests-disabled-hierarchy-check") {
// expect no throws exception
ScalaTestsError.scalaTestsError.testDisabledError
}
}
}
Loading