diff --git a/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala b/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala new file mode 100644 index 00000000000..8a9cb541b1b --- /dev/null +++ b/scalajslib/test/src/mill/scalajslib/ScalaTestsErrorTests.scala @@ -0,0 +1,31 @@ +package mill.scalajslib + +import mill._ +import mill.define.Discover +import mill.scalalib.TestModule +import mill.util.TestUtil +import utest._ + +object ScalaTestsErrorTests extends TestSuite { + object ScalaTestsError extends TestUtil.BaseModule { + object scalaTestsError extends ScalaJSModule { + def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???) + def scalaJSVersion = sys.props.getOrElse("TEST_SCALAJS_VERSION", ???) + object test extends ScalaTests with TestModule.Utest + } + + override lazy val millDiscover = Discover[this.type] + } + + def tests: Tests = Tests { + test("extends-ScalaTests") { + val error = intercept[ExceptionInInitializerError] { + ScalaTestsError.scalaTestsError.test + } + val message = error.getCause.getMessage + assert( + message == s"scalaTestsError is a `ScalaJSModule`. scalaTestsError.test needs to extend `ScalaJSTests`." + ) + } + } +} diff --git a/scalalib/src/mill/scalalib/ScalaModule.scala b/scalalib/src/mill/scalalib/ScalaModule.scala index 2071117b7ce..5a391444cb8 100644 --- a/scalalib/src/mill/scalalib/ScalaModule.scala +++ b/scalalib/src/mill/scalalib/ScalaModule.scala @@ -1,7 +1,15 @@ package mill package scalalib -import mill.api.{DummyInputStream, JarManifest, PathRef, Result, SystemStreams, internal} +import mill.api.{ + DummyInputStream, + JarManifest, + MillException, + PathRef, + Result, + SystemStreams, + internal +} import mill.main.BuildInfo import mill.util.{Jvm, Util} import mill.util.Jvm.createJar @@ -20,6 +28,29 @@ trait ScalaModule extends JavaModule with TestModule.ScalaModuleBase { outer => type ScalaModuleTests = ScalaTests trait ScalaTests extends JavaModuleTests with ScalaModule { + try { + if ( + Class.forName("mill.scalajslib.ScalaJSModule").isInstance(outer) && !Class.forName( + "mill.scalajslib.ScalaJSModule$ScalaJSTests" + ).isInstance(this) + ) throw new MillException( + s"$outer is a `ScalaJSModule`. $this needs to extend `ScalaJSTests`." + ) + } catch { + case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaJSModule + } + try { + if ( + Class.forName("mill.scalanativelib.ScalaNativeModule").isInstance(outer) && !Class.forName( + "mill.scalanativelib.ScalaNativeModule$ScalaNativeTests" + ).isInstance(this) + ) throw new MillException( + s"$outer is a `ScalaNativeModule`. $this needs to extend `ScalaNativeTests`." + ) + } catch { + case _: ClassNotFoundException => // if we can't find the classes, we certainly are not in a ScalaNativeModule + } + override def scalaOrganization: Target[String] = outer.scalaOrganization() override def scalaVersion: Target[String] = outer.scalaVersion() override def scalacPluginIvyDeps: Target[Agg[Dep]] = outer.scalacPluginIvyDeps() diff --git a/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala b/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala new file mode 100644 index 00000000000..af843d2102c --- /dev/null +++ b/scalanativelib/test/src/mill/scalanativelib/ScalaTestsErrorTests.scala @@ -0,0 +1,31 @@ +package mill.scalanativelib + +import mill._ +import mill.define.Discover +import mill.scalalib.TestModule +import mill.util.TestUtil +import utest._ + +object ScalaTestsErrorTests extends TestSuite { + object ScalaTestsError extends TestUtil.BaseModule { + object scalaTestsError extends ScalaNativeModule { + def scalaVersion = sys.props.getOrElse("TEST_SCALA_3_3_VERSION", ???) + def scalaNativeVersion = sys.props.getOrElse("TEST_SCALANATIVE_VERSION", ???) + object test extends ScalaTests with TestModule.Utest + } + + override lazy val millDiscover = Discover[this.type] + } + + def tests: Tests = Tests { + test("extends-ScalaTests") { + val error = intercept[ExceptionInInitializerError] { + ScalaTestsError.scalaTestsError.test + } + val message = error.getCause.getMessage + assert( + message == s"scalaTestsError is a `ScalaNativeModule`. scalaTestsError.test needs to extend `ScalaNativeTests`." + ) + } + } +}