Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

log scalafix --check diffs as errors via a sbt logger #106

Merged
merged 3 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/main/scala/scalafix/internal/sbt/LoggingOutputStream.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package scalafix.internal.sbt

import java.io.{ByteArrayOutputStream, OutputStream}

import sbt.{Level, Logger}

/** Split an OutputStream into messages and feed them to a given logger at a specified level. Not thread-safe. */
class LoggingOutputStream(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprisingly, I couldn't find any helper class in sbt... The best I found was https://github.com/sbt/util/blob/d31b9c509384ecaf37f2ffaec8731250e3c60532/internal/util-logging/src/main/scala/sbt/internal/util/LoggerWriter.scala, which is internal and doesn't match the types we need in that case.

logger: Logger,
level: Level.Value,
separator: String
) extends OutputStream {

private val baos = new ByteArrayOutputStream {
def maybeStripSuffix(suffix: Array[Byte]): Option[String] = {
def endsWithSuffix: Boolean =
count >= suffix.length && suffix.zipWithIndex.forall {
case (b: Byte, i: Int) =>
b == buf(count - separatorBytes.length + i)
}

if (endsWithSuffix)
Some(new String(buf, 0, count - separatorBytes.length))
else None
}
}

private val separatorBytes = separator.getBytes
require(separatorBytes.length > 0)

override def write(b: Int): Unit = {
baos.write(b)
baos.maybeStripSuffix(separatorBytes).foreach { message =>
logger.log(level, message)
baos.reset()
}
}
}

object LoggingOutputStream {
def apply(logger: Logger, level: Level.Value): OutputStream =
new LoggingOutputStream(logger, level, System.lineSeparator)
}
15 changes: 8 additions & 7 deletions src/main/scala/scalafix/internal/sbt/ScalafixInterface.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package scalafix.internal.sbt

import java.io.PrintStream
import java.net.URLClassLoader
import java.nio.file.Path
import java.{util => jutil}
Expand Down Expand Up @@ -67,6 +66,11 @@ object Arg {
override def apply(sa: ScalafixArguments): ScalafixArguments =
sa // caching is currently implemented in sbt-scalafix itself
}

case class PrintStream(printStream: java.io.PrintStream) extends Arg {
override def apply(sa: ScalafixArguments): ScalafixArguments =
sa.withPrintStream(printStream)
}
}

class ScalafixInterface private (
Expand All @@ -86,13 +90,11 @@ class ScalafixInterface private (
private def this(
api: ScalafixAPI,
toolClasspath: URLClassLoader,
mainCallback: ScalafixMainCallback,
printStream: PrintStream
mainCallback: ScalafixMainCallback
) = this(
api
.newArguments()
.withMainCallback(mainCallback)
.withPrintStream(printStream)
.withToolClasspath(toolClasspath),
Seq(Arg.ToolClasspath(toolClasspath))
)
Expand Down Expand Up @@ -164,8 +166,7 @@ object ScalafixInterface {
def fromToolClasspath(
scalafixDependencies: Seq[ModuleID],
scalafixCustomResolvers: Seq[Repository],
logger: Logger = Compat.ConsoleLogger(System.out),
printStream: PrintStream = System.out
logger: Logger = Compat.ConsoleLogger(System.out)
): () => ScalafixInterface =
new LazyValue({ () =>
val jars = ScalafixCoursier.scalafixCliJars(scalafixCustomResolvers)
Expand All @@ -175,7 +176,7 @@ object ScalafixInterface {
val classloader = new URLClassLoader(urls, interfacesParent)
val api = ScalafixAPI.classloadInstance(classloader)
val callback = new ScalafixLogger(logger)
new ScalafixInterface(api, classloader, callback, printStream)
new ScalafixInterface(api, classloader, callback)
.addToolClasspath(
scalafixDependencies,
scalafixCustomResolvers,
Expand Down
5 changes: 4 additions & 1 deletion src/main/scala/scalafix/sbt/ScalafixPlugin.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package scalafix.sbt

import java.io.PrintStream
import java.nio.file.{Path, Paths}

import com.geirsson.coursiersmall.Repository
Expand Down Expand Up @@ -168,7 +169,8 @@ object ScalafixPlugin extends AutoPlugin {
loadedRules = () => scalafixInterface().availableRules(),
terminalWidth = Some(JLineAccess.terminalWidth)
).parser.parsed

val errorLogger =
new PrintStream(LoggingOutputStream(streams.value.log, Level.Error))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about deferring the logging until the point we know the status code, to adjust the level accordingly, but looking at the usage of that printStream, I think it's reasonnable to always log as error, except maybe for --stdout, but that's a lot of custom code for something that I don't see being used through sbt

val projectDepsInternal = products.in(ScalafixConfig).value ++
internalDependencyClasspath.in(ScalafixConfig).value.map(_.data)
val projectDepsExternal =
Expand All @@ -192,6 +194,7 @@ object ScalafixPlugin extends AutoPlugin {
val mainInterface = mainInterface0
.withArgs(maybeNoCache: _*)
.withArgs(
Arg.PrintStream(errorLogger),
Arg.Config(scalafixConf),
Arg.Rules(shell.rules),
Arg.ParsedArgs(shell.extra)
Expand Down
16 changes: 16 additions & 0 deletions src/sbt-test/sbt-scalafix/basic/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,19 @@ lazy val example = project
)

lazy val tests = project

lazy val checkLogs = taskKey[Unit]("Check that diffs are logged as errors")

checkLogs := {
val taskStreams = streams.in(scalafix).in(Compile).in(example).value
val reader = taskStreams.readText(taskStreams.key)
val logLines = Stream
.continually(reader.readLine())
.takeWhile(_ != null)
.map(_.replaceAll("\u001B\\[[;\\d]*m", "")) // remove control chars (colors)
.force
assert(
logLines.exists(_ == "[error] -import scala.concurrent.Future"),
"diff should be logged as error"
)
}
2 changes: 2 additions & 0 deletions src/sbt-test/sbt-scalafix/basic/test
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
-> example/scalafix --test
$ sleep 2000
> checkLogs
> example/scalafix
> example/scalafix --test
> tests/test
147 changes: 147 additions & 0 deletions src/test/scala/scalafix/internal/sbt/LoggingOutputStreamSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package scalafix.internal.sbt

import java.io.PrintStream

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.must.Matchers
import sbt.{Level, Logger}

import scala.collection.mutable

class LoggingOutputStreamSuite extends AnyFunSuite with Matchers {

val sep = System.lineSeparator
val CRLF = "\r\n"

def withStubLogger(
level: Level.Value,
separator: String = sep
)(
testFun: (
LoggingOutputStream,
mutable.Seq[(Level.Value, String)]
) => Unit
): Unit = {
val logs = mutable.ListBuffer[(Level.Value, String)]()

val logger = new Logger {
override def log(level: Level.Value, message: => String): Unit =
logs += ((level, message))

override def success(message: => String): Unit = ???
override def trace(t: => Throwable): Unit = ???
}

testFun(new LoggingOutputStream(logger, level, separator), logs)
}

test("capture objects printed through PrintStream.println") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val data = 1234
new PrintStream(outputStream).println(data)
logs must be(mutable.Seq((Level.Warn, String.valueOf(data))))
}
}

test("capture messages of increasing length") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val word = "foo"
val printStream = new PrintStream(outputStream)
printStream.println(word)
printStream.println(word * 2)
printStream.println(word * 3)
logs.map(_._2) must be(mutable.Seq(word, word * 2, word * 3))
}
}

test("capture messages of decreasing length") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val word = "foo"
val printStream = new PrintStream(outputStream)
printStream.println(word * 3)
printStream.println(word * 2)
printStream.println(word)
logs.map(_._2) must be(mutable.Seq(word * 3, word * 2, word))
}
}

test("capture messages of non-monotonic length") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val word = "foo"
val printStream = new PrintStream(outputStream)
printStream.println(word * 3)
printStream.println(word)
printStream.println(word * 2)
logs.map(_._2) must be(mutable.Seq(word * 3, word, word * 2))
}
}

test("capture heading empty message") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val message = "hello world!"
outputStream.write(s"$sep$message$sep".getBytes)
logs.map(_._2) must be(mutable.Seq("", message))
}
}

test("capture in-between empty message") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val message1 = "hello world!"
val message2 = "here we are"
outputStream.write(s"$message1$sep$sep$message2$sep".stripMargin.getBytes)
logs.map(_._2) must be(mutable.Seq(message1, "", message2))
}
}

test("capture trailing empty message") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val message = "hello world!"
outputStream.write(s"$message$sep$sep".getBytes)
logs.map(_._2) must be(mutable.Seq(message, ""))
}
}

test("capture multi-byte characters") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val messageWithNonAsciiChar = "il était un petit navire"
messageWithNonAsciiChar.getBytes.length must be > messageWithNonAsciiChar.length //assert test is correct
outputStream.write(s"$messageWithNonAsciiChar$sep".getBytes)
logs.map(_._2) must be(mutable.Seq(messageWithNonAsciiChar))
}
}

test("handle multi-character separator") {
withStubLogger(Level.Warn, CRLF) { (outputStream, logs) =>
val message1 = "hello world!"
val message2 = "here we are"
outputStream.write(s"$message1$CRLF$CRLF$message2$CRLF".getBytes)
logs.map(_._2) must be(mutable.Seq(message1, "", message2))
}
}

test("capture very long messages") {
withStubLogger(Level.Warn) { (outputStream, logs) =>
val veryLongMessage = "a" * 1000000 // this would exhaust memory on quadratic implementations
outputStream.write(s"$veryLongMessage$sep".getBytes)
logs.map(_._2) must be(mutable.Seq(veryLongMessage))
}
}

test(
"capture very long messages containing a subset of the line separator"
) {
withStubLogger(Level.Warn, CRLF) { (outputStream, logs) =>
val veryLongMessage = CRLF.head.toString * 1000000
outputStream.write(s"$veryLongMessage$CRLF".getBytes)
logs.map(_._2) must be(mutable.Seq(veryLongMessage))
}
}

test("fail verbosely for invalid separator") {
an[IllegalArgumentException] should be thrownBy new LoggingOutputStream(
Logger.Null,
Level.Warn,
separator = ""
)
}
}
4 changes: 2 additions & 2 deletions src/test/scala/scalafix/internal/sbt/ScalafixAPISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class ScalafixAPISuite extends AnyFunSuite {
.fromToolClasspath(
List("com.geirsson" %% "example-scalafix-rule" % "1.3.0"),
ScalafixCoursier.defaultResolvers,
logger,
new PrintStream(baos)
logger
)()
.withArgs(Arg.PrintStream(new PrintStream(baos)))
val tmp = Files.createTempFile("scalafix", "Tmp.scala")
tmp.toFile.deleteOnExit()
Files.write(
Expand Down