Skip to content

Commit

Permalink
Scalafmt: use ammonite parser to handle scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Feb 10, 2022
1 parent 4de41cc commit 7937dd4
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 37 deletions.
32 changes: 5 additions & 27 deletions scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ package org.scalafmt
import java.nio.file.Path

import metaconfig.Configured
import scala.annotation.tailrec
import scala.meta.Dialect
import scala.meta.Input
import scala.meta.parsers.ParseException
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.matching.Regex

import org.scalafmt.config.Config
import org.scalafmt.Error.PreciseIncomplete
Expand Down Expand Up @@ -100,36 +98,13 @@ object Scalafmt {
}
}

private def flatMapAll[A, B](xs: Iterator[A])(f: A => Try[B]): Try[Seq[B]] = {
val res = Seq.newBuilder[B]
@tailrec
def iter: Try[Seq[B]] =
if (!xs.hasNext) Success(res.result())
else
f(xs.next()) match {
case Success(x) => res += x; iter
case Failure(e) => Failure(e)
}
iter
}

// see: https://ammonite.io/#Save/LoadSession
private val ammonitePattern: Regex = "(?:\\s*\\n@(?=\\s))+".r

private def doFormat(
code: String,
style: ScalafmtConfig,
file: String,
range: Set[Range]
): Try[String] =
if (FileOps.isAmmonite(file)) {
// XXX: we won't support ranges as we don't keep track of lines
val chunks = ammonitePattern.split(code)
if (chunks.length <= 1) doFormatOne(code, style, file, range)
else
flatMapAll(chunks.iterator)(doFormatOne(_, style, file))
.map(_.mkString("\n@\n"))
} else if (FileOps.isMarkdown(file)) {
if (FileOps.isMarkdown(file)) {
val markdown = MarkdownFile.parse(Input.VirtualFile(file, code))

val resultIterator: Iterator[Try[String]] =
Expand Down Expand Up @@ -159,7 +134,10 @@ object Scalafmt {
if (code.matches("\\s*")) Try("\n")
else {
val runner = style.runner
def codeToInput(srcCode: String) = Input.VirtualFile(file, srcCode)
def codeToFile(srcCode: String) = Input.VirtualFile(file, srcCode)
val codeToInput: String => Input =
if (FileOps.isAmmonite(file)) x => Input.Ammonite(codeToFile(x))
else codeToFile
val parsed = runner.parse(Rewrite(codeToInput(code), style, codeToInput))
parsed.fold(
_.details match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ sealed class ScalafmtParser(val parse: Parse[_ <: Tree])
object ScalafmtParser {
case object Case extends ScalafmtParser(Parse.parseCase)
case object Stat extends ScalafmtParser(Parse.parseStat)
case object Source extends ScalafmtParser(Parse.parseSource)
case object Source extends ScalafmtParser(Parse.parseAmmonite)

implicit val codec = ReaderUtil.oneOf[ScalafmtParser](Case, Stat, Source)
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class Router(formatOps: FormatOps) {
val newlines = formatToken.newlinesBetween

formatToken match {
// between sources (EOF -> @ -> BOF)
case FormatToken(_: T.EOF, _, _) => Seq(Split(Newline, 0))
case ft @ FormatToken(_, _: T.BOF, _) =>
Seq(Split(NoSplit.orNL(next(ft).right.is[T.EOF]), 0))
case FormatToken(_: T.BOF, right, _) =>
val policy = right match {
case T.Ident(name) // shebang in .sc files
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import scala.collection.mutable

import metaconfig.ConfCodecEx
import scala.meta._
import scala.meta.Input.VirtualFile
import scala.meta.tokens.Token.LF
import scala.meta.transversers.SimpleTraverser

Expand All @@ -16,15 +15,15 @@ import org.scalafmt.util.{TokenOps, TokenTraverser, TreeOps, Trivia, Whitespace}

case class RewriteCtx(
style: ScalafmtConfig,
fileName: String,
input: Input,
tree: Tree
) {
implicit val dialect = style.dialect

private val patchBuilder = mutable.Map.empty[(Int, Int), TokenPatch]

val tokens = tree.tokens
val tokenTraverser = new TokenTraverser(tokens, fileName)
val tokenTraverser = new TokenTraverser(tokens, input)
val matchingParens = TreeOps.getMatchingParentheses(tokens)

@inline def getMatching(a: Token): Token =
Expand Down Expand Up @@ -142,7 +141,7 @@ object Rewrite {
val default: Seq[Rewrite] = name2rewrite.values.toSeq

def apply(
input: VirtualFile,
input: Input,
style: ScalafmtConfig,
toInput: String => Input
): Input = {
Expand All @@ -152,7 +151,7 @@ object Rewrite {
} else {
style.runner.parse(input) match {
case Parsed.Success(ast) =>
val ctx = RewriteCtx(style, input.path, ast)
val ctx = RewriteCtx(style, input, ast)
val rewriteSessions = rewrites.map(_.create(ctx)).toList
val traverser = new SimpleTraverser {
override def apply(tree: Tree): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package org.scalafmt.util

import org.scalafmt.sysops.FileOps

import scala.annotation.tailrec
import scala.meta.Input
import scala.meta.tokens.Token
import scala.meta.tokens.Tokens

class TokenTraverser(tokens: Tokens, filename: String) {
class TokenTraverser(tokens: Tokens, input: Input) {
private[this] val (tok2idx, excludedTokens) = {
val map = Map.newBuilder[Token, Int]
val excluded = Set.newBuilder[TokenOps.TokenHash]
Expand All @@ -22,7 +21,7 @@ class TokenTraverser(tokens: Tokens, filename: String) {
map += (tok -> i)
i += 1
}
if (FileOps.isAmmonite(filename)) {
if (input.isInstanceOf[Input.Ammonite]) {
val realTokens = tokens.dropWhile(_.is[Token.BOF])
realTokens.headOption.foreach {
// shebang in .sc files
Expand Down
1 change: 1 addition & 0 deletions scalafmt-tests/src/test/resources/test/Dialect.source
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ import mill._
interp.repositories() =
interp.repositories() ++ Seq(coursier.MavenRepository("https://jitpack.io"))

@
@
import $ivy.`com.github.yyadavalli::mill-ensime:0.0.2`
<<< #2204 annotations
Expand Down

0 comments on commit 7937dd4

Please sign in to comment.