From 7937dd453d2656e6c11107f8f1821e3cccd8bfa7 Mon Sep 17 00:00:00 2001 From: Albert Meltzer <7529386+kitbellew@users.noreply.github.com> Date: Wed, 9 Feb 2022 09:15:20 -0800 Subject: [PATCH] Scalafmt: use ammonite parser to handle scripts --- .../main/scala/org/scalafmt/Scalafmt.scala | 32 +++---------------- .../org/scalafmt/config/ScalafmtParser.scala | 2 +- .../scala/org/scalafmt/internal/Router.scala | 4 +++ .../scala/org/scalafmt/rewrite/Rewrite.scala | 9 +++--- .../org/scalafmt/util/TokenTraverser.scala | 7 ++-- .../src/test/resources/test/Dialect.source | 1 + 6 files changed, 18 insertions(+), 37 deletions(-) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala index 8f50863d5b..246123c6c6 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala @@ -3,14 +3,12 @@ package org.scalafmt import java.nio.file.Path import metaconfig.Configured -import scala.annotation.tailrec import scala.meta.Dialect import scala.meta.Input import scala.meta.parsers.ParseException import scala.util.Failure import scala.util.Success import scala.util.Try -import scala.util.matching.Regex import org.scalafmt.config.Config import org.scalafmt.Error.PreciseIncomplete @@ -100,36 +98,13 @@ object Scalafmt { } } - private def flatMapAll[A, B](xs: Iterator[A])(f: A => Try[B]): Try[Seq[B]] = { - val res = Seq.newBuilder[B] - @tailrec - def iter: Try[Seq[B]] = - if (!xs.hasNext) Success(res.result()) - else - f(xs.next()) match { - case Success(x) => res += x; iter - case Failure(e) => Failure(e) - } - iter - } - - // see: https://ammonite.io/#Save/LoadSession - private val ammonitePattern: Regex = "(?:\\s*\\n@(?=\\s))+".r - private def doFormat( code: String, style: ScalafmtConfig, file: String, range: Set[Range] ): Try[String] = - if (FileOps.isAmmonite(file)) { - // XXX: we won't support ranges as we don't keep track of lines - val chunks = ammonitePattern.split(code) - if (chunks.length <= 1) doFormatOne(code, style, file, range) - else - flatMapAll(chunks.iterator)(doFormatOne(_, style, file)) - .map(_.mkString("\n@\n")) - } else if (FileOps.isMarkdown(file)) { + if (FileOps.isMarkdown(file)) { val markdown = MarkdownFile.parse(Input.VirtualFile(file, code)) val resultIterator: Iterator[Try[String]] = @@ -159,7 +134,10 @@ object Scalafmt { if (code.matches("\\s*")) Try("\n") else { val runner = style.runner - def codeToInput(srcCode: String) = Input.VirtualFile(file, srcCode) + def codeToFile(srcCode: String) = Input.VirtualFile(file, srcCode) + val codeToInput: String => Input = + if (FileOps.isAmmonite(file)) x => Input.Ammonite(codeToFile(x)) + else codeToFile val parsed = runner.parse(Rewrite(codeToInput(code), style, codeToInput)) parsed.fold( _.details match { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala index 78df1892af..b0dfabc556 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/config/ScalafmtParser.scala @@ -8,7 +8,7 @@ sealed class ScalafmtParser(val parse: Parse[_ <: Tree]) object ScalafmtParser { case object Case extends ScalafmtParser(Parse.parseCase) case object Stat extends ScalafmtParser(Parse.parseStat) - case object Source extends ScalafmtParser(Parse.parseSource) + case object Source extends ScalafmtParser(Parse.parseAmmonite) implicit val codec = ReaderUtil.oneOf[ScalafmtParser](Case, Stat, Source) } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index 7c3b0ec427..1fb237d257 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -83,6 +83,10 @@ class Router(formatOps: FormatOps) { val newlines = formatToken.newlinesBetween formatToken match { + // between sources (EOF -> @ -> BOF) + case FormatToken(_: T.EOF, _, _) => Seq(Split(Newline, 0)) + case ft @ FormatToken(_, _: T.BOF, _) => + Seq(Split(NoSplit.orNL(next(ft).right.is[T.EOF]), 0)) case FormatToken(_: T.BOF, right, _) => val policy = right match { case T.Ident(name) // shebang in .sc files diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala index f6f9afcbac..fb25872c27 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/Rewrite.scala @@ -5,7 +5,6 @@ import scala.collection.mutable import metaconfig.ConfCodecEx import scala.meta._ -import scala.meta.Input.VirtualFile import scala.meta.tokens.Token.LF import scala.meta.transversers.SimpleTraverser @@ -16,7 +15,7 @@ import org.scalafmt.util.{TokenOps, TokenTraverser, TreeOps, Trivia, Whitespace} case class RewriteCtx( style: ScalafmtConfig, - fileName: String, + input: Input, tree: Tree ) { implicit val dialect = style.dialect @@ -24,7 +23,7 @@ case class RewriteCtx( private val patchBuilder = mutable.Map.empty[(Int, Int), TokenPatch] val tokens = tree.tokens - val tokenTraverser = new TokenTraverser(tokens, fileName) + val tokenTraverser = new TokenTraverser(tokens, input) val matchingParens = TreeOps.getMatchingParentheses(tokens) @inline def getMatching(a: Token): Token = @@ -142,7 +141,7 @@ object Rewrite { val default: Seq[Rewrite] = name2rewrite.values.toSeq def apply( - input: VirtualFile, + input: Input, style: ScalafmtConfig, toInput: String => Input ): Input = { @@ -152,7 +151,7 @@ object Rewrite { } else { style.runner.parse(input) match { case Parsed.Success(ast) => - val ctx = RewriteCtx(style, input.path, ast) + val ctx = RewriteCtx(style, input, ast) val rewriteSessions = rewrites.map(_.create(ctx)).toList val traverser = new SimpleTraverser { override def apply(tree: Tree): Unit = { diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala index 622e2794fc..eaa8d67447 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/util/TokenTraverser.scala @@ -1,12 +1,11 @@ package org.scalafmt.util -import org.scalafmt.sysops.FileOps - import scala.annotation.tailrec +import scala.meta.Input import scala.meta.tokens.Token import scala.meta.tokens.Tokens -class TokenTraverser(tokens: Tokens, filename: String) { +class TokenTraverser(tokens: Tokens, input: Input) { private[this] val (tok2idx, excludedTokens) = { val map = Map.newBuilder[Token, Int] val excluded = Set.newBuilder[TokenOps.TokenHash] @@ -22,7 +21,7 @@ class TokenTraverser(tokens: Tokens, filename: String) { map += (tok -> i) i += 1 } - if (FileOps.isAmmonite(filename)) { + if (input.isInstanceOf[Input.Ammonite]) { val realTokens = tokens.dropWhile(_.is[Token.BOF]) realTokens.headOption.foreach { // shebang in .sc files diff --git a/scalafmt-tests/src/test/resources/test/Dialect.source b/scalafmt-tests/src/test/resources/test/Dialect.source index 8517da37b1..fb085efaf6 100644 --- a/scalafmt-tests/src/test/resources/test/Dialect.source +++ b/scalafmt-tests/src/test/resources/test/Dialect.source @@ -74,6 +74,7 @@ import mill._ interp.repositories() = interp.repositories() ++ Seq(coursier.MavenRepository("https://jitpack.io")) +@ @ import $ivy.`com.github.yyadavalli::mill-ensime:0.0.2` <<< #2204 annotations