From 68923112d44ce0018fa523e8669ac37356f25bef Mon Sep 17 00:00:00 2001 From: Brice Jaglin Date: Wed, 20 May 2020 10:27:15 +0200 Subject: [PATCH] log `scalafix --check` diffs as errors via sbt logger Diffs were dumped on stdout. When scalafix is run in parallel with many other tasks across multiple sub-projects, this makes it much easier to find out which files are the culprits and why, as the logs stand out and are available via `sbt last scalafix` --- .../internal/sbt/LoggingOutputStream.scala | 43 +++++ .../scala/scalafix/sbt/ScalafixPlugin.scala | 5 +- src/sbt-test/sbt-scalafix/basic/build.sbt | 16 ++ src/sbt-test/sbt-scalafix/basic/test | 1 + .../sbt/LoggingOutputStreamSuite.scala | 147 ++++++++++++++++++ 5 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/scalafix/internal/sbt/LoggingOutputStream.scala create mode 100644 src/test/scala/scalafix/internal/sbt/LoggingOutputStreamSuite.scala diff --git a/src/main/scala/scalafix/internal/sbt/LoggingOutputStream.scala b/src/main/scala/scalafix/internal/sbt/LoggingOutputStream.scala new file mode 100644 index 00000000..32a031c7 --- /dev/null +++ b/src/main/scala/scalafix/internal/sbt/LoggingOutputStream.scala @@ -0,0 +1,43 @@ +package scalafix.internal.sbt + +import java.io.{ByteArrayOutputStream, OutputStream} + +import sbt.{Level, Logger} + +/** Split an OutputStream into messages and feed them to a given logger at a specified level. Not thread-safe. */ +class LoggingOutputStream( + logger: Logger, + level: Level.Value, + separator: String +) extends OutputStream { + + private val baos = new ByteArrayOutputStream { + def maybeStripSuffix(suffix: Array[Byte]): Option[String] = { + def endsWithSuffix: Boolean = + count >= suffix.length && suffix.zipWithIndex.forall { + case (b: Byte, i: Int) => + b == buf(count - separatorBytes.length + i) + } + + if (endsWithSuffix) + Some(new String(buf, 0, count - separatorBytes.length)) + else None + } + } + + private val separatorBytes = separator.getBytes + require(separatorBytes.length > 0) + + override def write(b: Int): Unit = { + baos.write(b) + baos.maybeStripSuffix(separatorBytes).foreach { message => + logger.log(level, message) + baos.reset() + } + } +} + +object LoggingOutputStream { + def apply(logger: Logger, level: Level.Value): OutputStream = + new LoggingOutputStream(logger, level, System.lineSeparator) +} diff --git a/src/main/scala/scalafix/sbt/ScalafixPlugin.scala b/src/main/scala/scalafix/sbt/ScalafixPlugin.scala index 069ad483..9dec5c20 100644 --- a/src/main/scala/scalafix/sbt/ScalafixPlugin.scala +++ b/src/main/scala/scalafix/sbt/ScalafixPlugin.scala @@ -1,5 +1,6 @@ package scalafix.sbt +import java.io.PrintStream import java.nio.file.{Path, Paths} import com.geirsson.coursiersmall.Repository @@ -168,7 +169,8 @@ object ScalafixPlugin extends AutoPlugin { loadedRules = () => scalafixInterface().availableRules(), terminalWidth = Some(JLineAccess.terminalWidth) ).parser.parsed - + val errorLogger = + new PrintStream(LoggingOutputStream(streams.value.log, Level.Error)) val projectDepsInternal = products.in(ScalafixConfig).value ++ internalDependencyClasspath.in(ScalafixConfig).value.map(_.data) val projectDepsExternal = @@ -192,6 +194,7 @@ object ScalafixPlugin extends AutoPlugin { val mainInterface = mainInterface0 .withArgs(maybeNoCache: _*) .withArgs( + Arg.PrintStream(errorLogger), Arg.Config(scalafixConf), Arg.Rules(shell.rules), Arg.ParsedArgs(shell.extra) diff --git a/src/sbt-test/sbt-scalafix/basic/build.sbt b/src/sbt-test/sbt-scalafix/basic/build.sbt index 1d30344c..2d5dbe29 100644 --- a/src/sbt-test/sbt-scalafix/basic/build.sbt +++ b/src/sbt-test/sbt-scalafix/basic/build.sbt @@ -25,3 +25,19 @@ lazy val example = project ) lazy val tests = project + +lazy val checkLogs = taskKey[Unit]("Check that diffs are logged as errors") + +checkLogs := { + val taskStreams = streams.in(scalafix).in(Compile).in(example).value + val reader = taskStreams.readText(taskStreams.key) + val logLines = Stream + .continually(reader.readLine()) + .takeWhile(_ != null) + .map(_.replaceAll("\u001B\\[[;\\d]*m", "")) // remove control chars (colors) + .force + assert( + logLines.exists(_ == "[error] -import scala.concurrent.Future"), + "diff should be logged as error" + ) +} diff --git a/src/sbt-test/sbt-scalafix/basic/test b/src/sbt-test/sbt-scalafix/basic/test index 2204f1af..6e5bdb0b 100644 --- a/src/sbt-test/sbt-scalafix/basic/test +++ b/src/sbt-test/sbt-scalafix/basic/test @@ -1,4 +1,5 @@ -> example/scalafix --test +> checkLogs > example/scalafix > example/scalafix --test > tests/test diff --git a/src/test/scala/scalafix/internal/sbt/LoggingOutputStreamSuite.scala b/src/test/scala/scalafix/internal/sbt/LoggingOutputStreamSuite.scala new file mode 100644 index 00000000..33662abc --- /dev/null +++ b/src/test/scala/scalafix/internal/sbt/LoggingOutputStreamSuite.scala @@ -0,0 +1,147 @@ +package scalafix.internal.sbt + +import java.io.PrintStream + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.must.Matchers +import sbt.{Level, Logger} + +import scala.collection.mutable + +class LoggingOutputStreamSuite extends AnyFunSuite with Matchers { + + val sep = System.lineSeparator + val CRLF = "\r\n" + + def withStubLogger( + level: Level.Value, + separator: String = sep + )( + testFun: ( + LoggingOutputStream, + mutable.Seq[(Level.Value, String)] + ) => Unit + ): Unit = { + val logs = mutable.ListBuffer[(Level.Value, String)]() + + val logger = new Logger { + override def log(level: Level.Value, message: => String): Unit = + logs += ((level, message)) + + override def success(message: => String): Unit = ??? + override def trace(t: => Throwable): Unit = ??? + } + + testFun(new LoggingOutputStream(logger, level, separator), logs) + } + + test("capture objects printed through PrintStream.println") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val data = 1234 + new PrintStream(outputStream).println(data) + logs must be(mutable.Seq((Level.Warn, String.valueOf(data)))) + } + } + + test("capture messages of increasing length") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val word = "foo" + val printStream = new PrintStream(outputStream) + printStream.println(word) + printStream.println(word * 2) + printStream.println(word * 3) + logs.map(_._2) must be(mutable.Seq(word, word * 2, word * 3)) + } + } + + test("capture messages of decreasing length") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val word = "foo" + val printStream = new PrintStream(outputStream) + printStream.println(word * 3) + printStream.println(word * 2) + printStream.println(word) + logs.map(_._2) must be(mutable.Seq(word * 3, word * 2, word)) + } + } + + test("capture messages of non-monotonic length") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val word = "foo" + val printStream = new PrintStream(outputStream) + printStream.println(word * 3) + printStream.println(word) + printStream.println(word * 2) + logs.map(_._2) must be(mutable.Seq(word * 3, word, word * 2)) + } + } + + test("capture heading empty message") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val message = "hello world!" + outputStream.write(s"$sep$message$sep".getBytes) + logs.map(_._2) must be(mutable.Seq("", message)) + } + } + + test("capture in-between empty message") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val message1 = "hello world!" + val message2 = "here we are" + outputStream.write(s"$message1$sep$sep$message2$sep".stripMargin.getBytes) + logs.map(_._2) must be(mutable.Seq(message1, "", message2)) + } + } + + test("capture trailing empty message") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val message = "hello world!" + outputStream.write(s"$message$sep$sep".getBytes) + logs.map(_._2) must be(mutable.Seq(message, "")) + } + } + + test("capture multi-byte characters") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val messageWithNonAsciiChar = "il était un petit navire" + messageWithNonAsciiChar.getBytes.length must be > messageWithNonAsciiChar.length //assert test is correct + outputStream.write(s"$messageWithNonAsciiChar$sep".getBytes) + logs.map(_._2) must be(mutable.Seq(messageWithNonAsciiChar)) + } + } + + test("handle multi-character separator") { + withStubLogger(Level.Warn, CRLF) { (outputStream, logs) => + val message1 = "hello world!" + val message2 = "here we are" + outputStream.write(s"$message1$CRLF$CRLF$message2$CRLF".getBytes) + logs.map(_._2) must be(mutable.Seq(message1, "", message2)) + } + } + + test("capture very long messages") { + withStubLogger(Level.Warn) { (outputStream, logs) => + val veryLongMessage = "a" * 1000000 // this would exhaust memory on quadratic implementations + outputStream.write(s"$veryLongMessage$sep".getBytes) + logs.map(_._2) must be(mutable.Seq(veryLongMessage)) + } + } + + test( + "capture very long messages containing a subset of the line separator" + ) { + withStubLogger(Level.Warn, CRLF) { (outputStream, logs) => + val veryLongMessage = CRLF.head.toString * 1000000 + outputStream.write(s"$veryLongMessage$CRLF".getBytes) + logs.map(_._2) must be(mutable.Seq(veryLongMessage)) + } + } + + test("fail verbosely for invalid separator") { + an[IllegalArgumentException] should be thrownBy new LoggingOutputStream( + Logger.Null, + Level.Warn, + separator = "" + ) + } +}