diff --git a/docs/assertions.md b/docs/assertions.md index ccc63bb6..094a3896 100644 --- a/docs/assertions.md +++ b/docs/assertions.md @@ -1,4 +1,3 @@ - --- id: assertions title: Writing assertions @@ -135,7 +134,10 @@ Windows/Unix newlines and ANSI color codes. The "=> Obtained" section of `.stripMargin`. ## `intercept()` -Use `intercept()` when you expect a particular exception to be thrown by the test code (i.e. the test succeeds if the given Exception is thrown) + +Use `intercept()` when you expect a particular exception to be thrown by the +test code (i.e. the test succeeds if the given exception is thrown). + ```scala mdoc:crash intercept[java.lang.IllegalArgumentException]{ // code expected to throw exception here @@ -143,7 +145,10 @@ intercept[java.lang.IllegalArgumentException]{ ``` ## `interceptMessage()` -Like intercept() except you can also specify a specific message the given Exception must match. + +Like `intercept()` except additionally asserts that the thrown exception has a +specific error message. + ```scala mdoc:crash interceptMessage[java.lang.IllegalArgumentException]("argument type mismatch"){ // code expected to throw exception here @@ -163,3 +168,33 @@ Use `clues()` to include optional context why the test failed. ```scala mdoc:crash fail("test failed", clues(a + b)) ``` + +## `compileErrors()` + +Use `compileErrors()` to assert that an example code snippet fails with a +specific compile-time error message. + +```scala mdoc +assertNoDiff( + compileErrors("Set(2, 1).sorted"), + """|error: value sorted is not a member of scala.collection.immutable.Set[Int] + |Set(2, 1).sorted + | ^ + |""".stripMargin +) +``` + +The argument to `compileErrors` must be a string literal. It's not possible to +pass in more complicated expressions such as variables or string interpolators. + +```scala mdoc:fail +val code = """val x: String = 2""" +compileErrors(code) +compileErrors(s"/* code */ $code") +``` + +Inline the `code` variable to fix the compile error. + +```scala mdoc +compileErrors("val x: String = 2") +``` diff --git a/munit/shared/src/main/scala-2/munit/internal/MacroCompat.scala b/munit/shared/src/main/scala-2/munit/internal/MacroCompat.scala index 995f2c1e..1d239100 100644 --- a/munit/shared/src/main/scala-2/munit/internal/MacroCompat.scala +++ b/munit/shared/src/main/scala-2/munit/internal/MacroCompat.scala @@ -4,6 +4,8 @@ import munit.Clue import munit.Location import scala.language.experimental.macros import scala.reflect.macros.blackbox.Context +import scala.reflect.macros.TypecheckException +import scala.reflect.macros.ParseException object MacroCompat { @@ -57,4 +59,45 @@ object MacroCompat { valueType ) } + + trait CompileErrorMacro { + def compileErrors(code: String): String = macro compileErrorsImpl + } + + def compileErrorsImpl(c: Context)(code: c.Tree): c.Tree = { + import c.universe._ + val toParse: String = code match { + case Literal(Constant(literal: String)) => literal + case _ => + c.abort( + code.pos, + "cannot compile dynamic expressions, only constant literals.\n" + + "To fix this problem, pass in a string literal in double quotes \"...\"" + ) + } + + def formatError(message: String, pos: scala.reflect.api.Position): String = + new StringBuilder() + .append("error:") + .append(if (message.contains('\n')) "\n" else " ") + .append(message) + .append("\n") + .append(pos.lineContent) + .append("\n") + .append(" " * (pos.column - 1)) + .append("^") + .toString() + + val message: String = + try { + c.typecheck(c.parse(s"{\n$toParse\n}")) + "" + } catch { + case e: ParseException => + formatError(e.getMessage(), e.pos) + case e: TypecheckException => + formatError(e.getMessage(), e.pos) + } + Literal(Constant(message)) + } } diff --git a/munit/shared/src/main/scala-3/munit/internal/MacroCompat.scala b/munit/shared/src/main/scala-3/munit/internal/MacroCompat.scala index 83217426..3856cbae 100644 --- a/munit/shared/src/main/scala-3/munit/internal/MacroCompat.scala +++ b/munit/shared/src/main/scala-3/munit/internal/MacroCompat.scala @@ -28,4 +28,19 @@ object MacroCompat { '{ new Clue(${Expr(source)}, $value, ${Expr(valueType)}) } } + trait CompileErrorMacro { + inline def compileErrors(inline code: String): String = { + val errors = scala.compiletime.testing.typeCheckErrors(code) + errors.map { error => + val indent = " " * (error.column - 1) + val trimMessage = error.message.linesIterator.map { line => + if (line.matches(" +")) "" + else line + }.mkString("\n") + val separator = if (error.message.contains('\n')) "\n" else " " + s"error:${separator}${trimMessage}\n${error.lineContent}\n${indent}^" + }.mkString("\n") + } + } + } diff --git a/munit/shared/src/main/scala/munit/Assertions.scala b/munit/shared/src/main/scala/munit/Assertions.scala index ae91dc8b..a6397f65 100644 --- a/munit/shared/src/main/scala/munit/Assertions.scala +++ b/munit/shared/src/main/scala/munit/Assertions.scala @@ -8,9 +8,10 @@ import scala.util.control.NonFatal import scala.collection.mutable import munit.internal.console.AnsiColors import org.junit.AssumptionViolatedException +import munit.internal.MacroCompat object Assertions extends Assertions -trait Assertions { +trait Assertions extends MacroCompat.CompileErrorMacro { val munitLines = new Lines diff --git a/tests/shared/src/test/scala/munit/TypeCheckSuite.scala b/tests/shared/src/test/scala/munit/TypeCheckSuite.scala new file mode 100644 index 00000000..f2ac5aa7 --- /dev/null +++ b/tests/shared/src/test/scala/munit/TypeCheckSuite.scala @@ -0,0 +1,99 @@ +package munit + +class TypeCheckSuite extends FunSuite { + + def check( + options: TestOptions, + obtained: String, + compat: Map[String, String] + )(implicit loc: Location): Unit = { + test(options) { + val split = BuildInfo.scalaVersion.split("\\.") + val binaryVersion = split.take(2).mkString(".") + val majorVersion = split.head match { + case "0" => "3" + case n => n + } + val expected = compat + .get(BuildInfo.scalaVersion) + .orElse(compat.get(binaryVersion)) + .orElse(compat.get(majorVersion)) + .getOrElse { + compat(BuildInfo.scalaVersion) + } + assertNoDiff(obtained, expected)(loc) + } + } + + val msg = "Hello" + check( + "not a member", + compileErrors("msg.foobar"), + Map( + "2" -> + """|error: value foobar is not a member of String + |msg.foobar + | ^ + |""".stripMargin, + "3" -> + """|error: + |value foobar is not a member of String, but could be made available as an extension method. + | + |The following import might fix the problem: + | + | import munit.Clue.generate + | + |msg.foobar + | ^ + |""".stripMargin + ) + ) + + check( + "parse error", + compileErrors("val x: = 2"), + Map( + "2" -> """|error: identifier expected but '=' found. + |val x: = 2 + | ^ + |""".stripMargin, + "3" -> + // NOTE(olafur): I'm not sure what's going on with the second errors but + // that's what Dotty reports. + """|error: an identifier expected, but eof found + |val x: = 2 + | ^ + |error: Declaration of value x not allowed here: only classes can have declared but undefined members + |package munit + | ^ + |""".stripMargin + ) + ) + + check( + "type mismatch", + compileErrors("val n: Int = msg"), + Map( + "2" -> + """|error: + |type mismatch; + | found : String + | required: Int + |val n: Int = msg + | ^ + |""".stripMargin, + "3" -> + """|error: + |Found: (TypeCheckSuite.this.msg : String) + |Required: Int + | + |The following import might make progress towards fixing the problem: + | + | import munit.Clue.generate + | + |val n: Int = msg + | ^ + |""".stripMargin + ) + ) +} diff --git a/website/i18n/en.json b/website/i18n/en.json index 0f5192c2..39ecef7c 100644 --- a/website/i18n/en.json +++ b/website/i18n/en.json @@ -6,7 +6,7 @@ "tagline": "Scala testing library with actionable errors and extensible APIs", "docs": { "assertions": { - "title": "assertions" + "title": "Writing assertions" }, "filtering": { "title": "Filtering tests"