diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala index 5abf8c4467..2dc5700b3d 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala @@ -3,14 +3,12 @@ package org.scalafmt import java.nio.file.Path import metaconfig.Configured -import scala.annotation.tailrec import scala.meta.Dialect import scala.meta.Input import scala.meta.parsers.ParseException import scala.util.Failure import scala.util.Success import scala.util.Try -import scala.util.matching.Regex import org.scalafmt.config.Config import org.scalafmt.Error.PreciseIncomplete @@ -100,36 +98,13 @@ object Scalafmt { } } - private def flatMapAll[A, B](xs: Iterator[A])(f: A => Try[B]): Try[Seq[B]] = { - val res = Seq.newBuilder[B] - @tailrec - def iter: Try[Seq[B]] = - if (!xs.hasNext) Success(res.result()) - else - f(xs.next()) match { - case Success(x) => res += x; iter - case Failure(e) => Failure(e) - } - iter - } - - // see: https://ammonite.io/#Save/LoadSession - private val ammonitePattern: Regex = "(?:\\s*\\n@(?=\\s))+".r - private def doFormat( code: String, style: ScalafmtConfig, file: String, range: Set[Range] ): Try[String] = - if (FileOps.isAmmonite(file)) { - // XXX: we won't support ranges as we don't keep track of lines - val chunks = ammonitePattern.split(code) - if (chunks.length <= 1) doFormatOne(code, style, file, range) - else - flatMapAll(chunks.iterator)(doFormatOne(_, style, file)) - .map(_.mkString("\n@\n")) - } else if (FileOps.isMarkdown(file)) { + if (FileOps.isMarkdown(file)) { val markdown = MarkdownFile.parse(Input.VirtualFile(file, code)) val resultIterator: Iterator[Try[String]] = @@ -150,8 +125,10 @@ object Scalafmt { } else doFormatOne(code, style, file, range) - private[scalafmt] def toInput(code: String, file: String): Input.VirtualFile = - Input.VirtualFile(file, code) + private[scalafmt] def toInput(code: String, file: String): Input = { + val fileInput = Input.VirtualFile(file, code) + if (FileOps.isAmmonite(file)) Input.Ammonite(fileInput) else fileInput + } private def doFormatOne( code: String, @@ -162,7 +139,7 @@ object Scalafmt { if (code.matches("\\s*")) Try("\n") else { val runner = style.runner - val codeToInput: String => Input.VirtualFile = toInput(_, file) + val codeToInput: String => Input = toInput(_, file) val parsed = runner.parse(Rewrite(codeToInput(code), style, codeToInput)) parsed.fold( _.details match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala index 78df1892af..6f26e75391 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala @@ -1,14 +1,23 @@ package org.scalafmt.config -import scala.meta.Tree +import scala.meta._ import scala.meta.parsers.Parse +import scala.meta.parsers.Parsed sealed class ScalafmtParser(val parse: Parse[_ <: Tree]) object ScalafmtParser { case object Case extends ScalafmtParser(Parse.parseCase) case object Stat extends ScalafmtParser(Parse.parseStat) - case object Source extends ScalafmtParser(Parse.parseSource) + case object Source extends ScalafmtParser(SourceParser) implicit val codec = ReaderUtil.oneOf[ScalafmtParser](Case, Stat, Source) + + private object SourceParser extends Parse[Tree] { + override def apply(input: Input, dialect: Dialect): Parsed[Tree] = { + val isAmmonite = input.isInstanceOf[Input.Ammonite] + val parser = if (isAmmonite) Parse.parseAmmonite else Parse.parseSource + parser(input, dialect) + } + } } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index 7c3b0ec427..1fb237d257 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -83,6 +83,10 @@ class Router(formatOps: FormatOps) { val newlines = formatToken.newlinesBetween formatToken match { + // between sources (EOF -> @ -> BOF) + case FormatToken(_: T.EOF, _, _) => Seq(Split(Newline, 0)) + case ft @ FormatToken(_, _: T.BOF, _) => + Seq(Split(NoSplit.orNL(next(ft).right.is[T.EOF]), 0)) case FormatToken(_: T.BOF, right, _) => val policy = right match { case T.Ident(name) // shebang in .sc files diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala index f6f9afcbac..fb25872c27 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala @@ -5,7 +5,6 @@ import scala.collection.mutable import metaconfig.ConfCodecEx import scala.meta._ -import scala.meta.Input.VirtualFile import scala.meta.tokens.Token.LF import scala.meta.transversers.SimpleTraverser @@ -16,7 +15,7 @@ import org.scalafmt.util.{TokenOps, TokenTraverser, TreeOps, Trivia, Whitespace} case class RewriteCtx( style: ScalafmtConfig, - fileName: String, + input: Input, tree: Tree ) { implicit val dialect = style.dialect @@ -24,7 +23,7 @@ case class RewriteCtx( private val patchBuilder = mutable.Map.empty[(Int, Int), TokenPatch] val tokens = tree.tokens - val tokenTraverser = new TokenTraverser(tokens, fileName) + val tokenTraverser = new TokenTraverser(tokens, input) val matchingParens = TreeOps.getMatchingParentheses(tokens) @inline def getMatching(a: Token): Token = @@ -142,7 +141,7 @@ object Rewrite { val default: Seq[Rewrite] = name2rewrite.values.toSeq def apply( - input: VirtualFile, + input: Input, style: ScalafmtConfig, toInput: String => Input ): Input = { @@ -152,7 +151,7 @@ object Rewrite { } else { style.runner.parse(input) match { case Parsed.Success(ast) => - val ctx = RewriteCtx(style, input.path, ast) + val ctx = RewriteCtx(style, input, ast) val rewriteSessions = rewrites.map(_.create(ctx)).toList val traverser = new SimpleTraverser { override def apply(tree: Tree): Unit = { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala index 622e2794fc..eaa8d67447 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala @@ -1,12 +1,11 @@ package org.scalafmt.util -import org.scalafmt.sysops.FileOps - import scala.annotation.tailrec +import scala.meta.Input import scala.meta.tokens.Token import scala.meta.tokens.Tokens -class TokenTraverser(tokens: Tokens, filename: String) { +class TokenTraverser(tokens: Tokens, input: Input) { private[this] val (tok2idx, excludedTokens) = { val map = Map.newBuilder[Token, Int] val excluded = Set.newBuilder[TokenOps.TokenHash] @@ -22,7 +21,7 @@ class TokenTraverser(tokens: Tokens, filename: String) { map += (tok -> i) i += 1 } - if (FileOps.isAmmonite(filename)) { + if (input.isInstanceOf[Input.Ammonite]) { val realTokens = tokens.dropWhile(_.is[Token.BOF]) realTokens.headOption.foreach { // shebang in .sc files diff --git a/scalafmt-tests/src/test/resources/test/Dialect.source b/scalafmt-tests/src/test/resources/test/Dialect.source index 8517da37b1..fb085efaf6 100644 --- a/scalafmt-tests/src/test/resources/test/Dialect.source +++ b/scalafmt-tests/src/test/resources/test/Dialect.source @@ -74,6 +74,7 @@ import mill._ interp.repositories() = interp.repositories() ++ Seq(coursier.MavenRepository("https://jitpack.io")) +@ @ import $ivy.`com.github.yyadavalli::mill-ensime:0.0.2` <<< #2204 annotations diff --git a/scalafmt-tests/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/src/test/scala/org/scalafmt/FormatTests.scala index 30c3885dad..bcde94dcbf 100644 --- a/scalafmt-tests/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/src/test/scala/org/scalafmt/FormatTests.scala @@ -61,7 +61,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { t.style.rewrite.rules.isEmpty && FormatTokensRewrite.getEnabledFactories(t.style).isEmpty && !t.style.assumeStandardLibraryStripMargin && - !FileOps.isAmmonite(t.filename) && !FileOps.isMarkdown(t.filename) && + !FileOps.isMarkdown(t.filename) && t.style.onTestFailure.isEmpty ) assertFormatPreservesAst(