diff --git a/src/compiler/scala/tools/nsc/Global.scala b/src/compiler/scala/tools/nsc/Global.scala index fede7a4bb688..4f330bc83fcf 100644 --- a/src/compiler/scala/tools/nsc/Global.scala +++ b/src/compiler/scala/tools/nsc/Global.scala @@ -10,12 +10,15 @@ package nsc import java.io.{File, FileNotFoundException, IOException} import java.net.URL import java.nio.charset.{Charset, CharsetDecoder, IllegalCharsetNameException, UnsupportedCharsetException} +import java.util.concurrent.TimeUnit + import scala.collection.{immutable, mutable} import io.{AbstractFile, Path, SourceReader} -import reporters.Reporter +import reporters.{BufferedReporter, Reporter} import util.{ClassPath, returning} import scala.reflect.ClassTag -import scala.reflect.internal.util.{BatchSourceFile, NoSourceFile, ScalaClassLoader, ScriptSourceFile, SourceFile, StatisticsStatics} +import scala.reflect.internal.util.{BatchSourceFile, NoSourceFile, Parallel, ScalaClassLoader, ScriptSourceFile, SourceFile, StatisticsStatics} +import scala.reflect.internal.util.Parallel._ import scala.reflect.internal.pickling.PickleBuffer import symtab.{Flags, SymbolTable, SymbolTrackers} import symtab.classfile.Pickler @@ -26,12 +29,13 @@ import typechecker._ import transform.patmat.PatternMatching import transform._ import backend.{JavaPlatform, ScalaPrimitives} -import backend.jvm.{GenBCode, BackendStats} -import scala.concurrent.Future +import backend.jvm.{BackendStats, GenBCode} +import scala.concurrent.duration.Duration +import scala.concurrent._ import scala.language.postfixOps import scala.tools.nsc.ast.{TreeGen => AstTreeGen} import scala.tools.nsc.classpath._ -import scala.tools.nsc.profile.Profiler +import scala.tools.nsc.profile.{Profiler, ThreadPoolFactory} class Global(var currentSettings: Settings, reporter0: Reporter) extends SymbolTable @@ -75,16 +79,26 @@ class Global(var currentSettings: Settings, reporter0: Reporter) override def settings = currentSettings - private[this] var currentReporter: Reporter = { reporter = reporter0 ; currentReporter } + // Umad reported a violation on: `scala.reflect.internal.SymbolTable.scala$reflect$internal$Names$$nc_$eq(int)` + // To synchronize access we can reuse the mechanism from `scala.reflect.internal.Names` so far used for the runtime mirror. + // For efficiency we want to enable it only for phases which are parallelized. + override protected def synchronizeNames: Boolean = _synchronizeNames + private[this] var _synchronizeNames = false - def reporter: Reporter = currentReporter + /* `currentReporter` is used on the main thread as well as on worker threads. + * On worker threads we create a new `BufferedReporter` for every unit, + * which at the end of the unit's processing is dumped to the main reporter. + * We need to do this if we want to retain the same message order as in single-threaded execution.
+ */ + private[this] val currentReporter: WorkerOrMainThreadLocal[Reporter] = WorkerThreadLocal(reporter0, reporter0) + def reporter: Reporter = currentReporter.get def reporter_=(newReporter: Reporter): Unit = - currentReporter = newReporter match { + currentReporter.set(newReporter match { case _: reporters.ConsoleReporter | _: reporters.LimitingReporter => newReporter case _ if settings.maxerrs.isSetByUser && settings.maxerrs.value < settings.maxerrs.default => new reporters.LimitingReporter(settings, newReporter) case _ => newReporter - } + }) /** Switch to turn on detailed type logs */ var printTypings = settings.Ytyperdebug.value @@ -385,11 +399,6 @@ class Global(var currentSettings: Settings, reporter0: Reporter) abstract class GlobalPhase(prev: Phase) extends Phase(prev) { phaseWithId(id) = this - def run(): Unit = { - echoPhaseSummary(this) - currentRun.units foreach applyPhase - } - def apply(unit: CompilationUnit): Unit /** Is current phase cancelled on this unit? */ @@ -399,31 +408,96 @@ class Global(var currentSettings: Settings, reporter0: Reporter) reporter.cancelled || unit.isJava && this.id > maxJavaPhase } - final def withCurrentUnit(unit: CompilationUnit)(task: => Unit): Unit = { - if ((unit ne null) && unit.exists) - lastSeenSourceFile = unit.source + // Cleanup method needed by some phases, for example the typer + def afterUnit(unit: CompilationUnit): Unit = {} + + def run(): Unit = { + assertOnMain() - if (settings.debug && (settings.verbose || currentRun.size < 5)) - inform("[running phase " + name + " on " + unit + "]") - if (!cancelled(unit)) { - currentRun.informUnitStarting(this, unit) - try withCurrentUnitNoLog(unit)(task) - finally currentRun.advanceUnit() + if (isDebugPrintEnabled) inform("[running phase " + name + " on " + currentRun.size + " compilation units]") + implicit val ec: ExecutionContextExecutor = createExecutionContext() + + try { + _synchronizeNames = isParallel + + /* Every unit is now run in a separate `Future`. If the given phase is not run as a parallel one + * (which is indicated by `isParallel`), it will still run on the main thread. This is accomplished by + * the suitably modified `ExecutionContext` returned by `createExecutionContext`. + */ + val futures = currentRun.units.toList.collect { + case unit if !cancelled(unit) => + Future { + asWorkerThread { + processUnit(unit) + afterUnit(unit) + reporter + } + } + } + + /* Dump messages from each unit's `BufferedReporter` to the main reporter. + * Since we await the previous units first, this allows us to retain the message order. + */ + futures.foreach { future => + val workerReporter = Await.result(future, Duration.Inf) + if (isParallel) workerReporter.asInstanceOf[BufferedReporter].flushTo(reporter) + } + } finally { + _synchronizeNames = false + + ec match { + case ecxs: ExecutionContextExecutorService => + ecxs.shutdown() + assert(ecxs.awaitTermination(1, TimeUnit.MINUTES)) + case _ => + } } } - final def withCurrentUnitNoLog(unit: CompilationUnit)(task: => Unit): Unit = { + // Used in methods like `compileLate` + final def applyPhase(unit: CompilationUnit): Unit = { + assertOnWorker() + if (!cancelled(unit)) processUnit(unit) + } + + private def processUnit(unit: CompilationUnit): Unit = { + assertOnWorker() + + /* On worker threads, if we are processing units in parallel, we want to use a temporary `BufferedReporter` for every unit. + * We then keep it until all previous units are processed and dump all its messages to the main reporter.
+ */ + if (isParallel) reporter = new BufferedReporter + + if (isDebugPrintEnabled) inform("[running phase " + name + " on " + unit + "]") + val unit0 = currentUnit + try { + if ((unit ne null) && unit.exists) lastSeenSourceFile = unit.source currentRun.currentUnit = unit - task + apply(unit) } finally { - //assert(currentUnit == unit) currentRun.currentUnit = unit0 + currentRun.advanceUnit() } } - final def applyPhase(unit: CompilationUnit) = withCurrentUnit(unit)(apply(unit)) + /* Only output a summary message under debug if we aren't echoing each file. */ + private def isDebugPrintEnabled: Boolean = settings.debug && !(settings.verbose || currentRun.size < 5) + + private def isParallel = settings.YparallelPhases.containsPhase(this) + + /* Depending on whether we are in a parallel phase or not, this creates either an executor with a fixed-size thread pool or + * an executor which runs everything on the current thread. + */ + private def createExecutionContext(): ExecutionContextExecutor = { + if (isParallel) { + val parallelThreads = settings.YparallelThreads.value + val threadPoolFactory = ThreadPoolFactory(Global.this, this) + val javaExecutor = threadPoolFactory.newUnboundedQueueFixedThreadPool(parallelThreads, "worker") + ExecutionContext.fromExecutorService(javaExecutor, _ => ()) + } else ExecutionContext.fromExecutor((task: Runnable) => task.run()) + } } // phaseName = "parser" @@ -949,11 +1023,18 @@ class Global(var currentSettings: Settings, reporter0: Reporter) } with typechecker.StructuredTypeStrings /** There are common error conditions where when the exception hits - * here, currentRun.currentUnit is null. This robs us of the knowledge - * of what file was being compiled when it broke. Since I really - * really want to know, this hack. + * here, currentRun.currentUnit is null. This robs us of the knowledge + * of what file was being compiled when it broke. Since I really + * really want to know, this hack. + * + * pkukielka: I'm not sure if I believe the above comment, but nevertheless + * if we want to keep that design we need to keep one `lastSeenSourceFile` per thread. + * Otherwise, in case of failure, we would report the last unit/file whose processing started, + * which may or may not be the same as the one which failed. */ - protected var lastSeenSourceFile: SourceFile = NoSourceFile + private[this] final val _lastSeenSourceFile: WorkerThreadLocal[SourceFile] = WorkerThreadLocal(NoSourceFile) + @inline protected def lastSeenSourceFile: SourceFile = _lastSeenSourceFile.get + @inline protected def lastSeenSourceFile_=(source: SourceFile): Unit = _lastSeenSourceFile.set(source) /** Let's share a lot more about why we crash all over the place. * People will be very grateful. @@ -1058,12 +1139,6 @@ class Global(var currentSettings: Settings, reporter0: Reporter) */ override def currentRunId = curRunId - def echoPhaseSummary(ph: Phase) = { - /* Only output a summary message under debug if we aren't echoing each file.
*/ - if (settings.debug && !(settings.verbose || currentRun.size < 5)) - inform("[running phase " + ph.name + " on " + currentRun.size + " compilation units]") - } - def newSourceFile(code: String, filename: String = "") = new BatchSourceFile(filename, code) @@ -1090,7 +1165,9 @@ class Global(var currentSettings: Settings, reporter0: Reporter) */ var isDefined = false /** The currently compiled unit; set from GlobalPhase */ - var currentUnit: CompilationUnit = NoCompilationUnit + private[this] final val _currentUnit: WorkerOrMainThreadLocal[CompilationUnit] = WorkerThreadLocal(NoCompilationUnit) + def currentUnit: CompilationUnit = _currentUnit.get + def currentUnit_=(unit: CompilationUnit): Unit = _currentUnit.set(unit) val profiler: Profiler = Profiler(settings) keepPhaseStack = settings.log.isSetByUser @@ -1128,8 +1205,8 @@ class Global(var currentSettings: Settings, reporter0: Reporter) /** A map from compiled top-level symbols to their picklers */ val symData = new mutable.AnyRefMap[Symbol, PickleBuffer] - private var phasec: Int = 0 // phases completed - private var unitc: Int = 0 // units completed this phase + private var phasec: Int = 0 // phases completed + private final val unitc: Counter = new Counter // units completed this phase def size = unitbuf.size override def toString = "scalac Run for:\n " + compiledFiles.toList.sorted.mkString("\n ") @@ -1250,7 +1327,7 @@ class Global(var currentSettings: Settings, reporter0: Reporter) * (for progress reporting) */ def advancePhase(): Unit = { - unitc = 0 + unitc.reset() phasec += 1 refreshProgress() } @@ -1258,14 +1335,14 @@ class Global(var currentSettings: Settings, reporter0: Reporter) * (for progress reporting) */ def advanceUnit(): Unit = { - unitc += 1 + unitc.incrementAndGet() refreshProgress() } // for sbt def cancel(): Unit = { reporter.cancelled = true } - private def currentProgress = (phasec * size) + unitc + private def currentProgress = (phasec * size) + unitc.get private def totalProgress = (phaseDescriptors.size - 1) * size // -1: drops terminal phase private def refreshProgress() = if (size > 0) progress(currentProgress, totalProgress) @@ -1429,8 +1506,16 @@ class Global(var currentSettings: Settings, reporter0: Reporter) private final val GlobalPhaseName = "global (synthetic)" protected final val totalCompileTime = statistics.newTimer("#total compile time", GlobalPhaseName) - def compileUnits(units: List[CompilationUnit], fromPhase: Phase = firstPhase): Unit = + /* We assume that everything done while compiling a unit happens on a worker thread. + * Most work is done (and most code is run) on the worker threads, compared to the main thread. + * Because of that we made the design decision that every thread is marked as a worker by default. + * This makes our life easier when dealing with tests, the toolbox, and virtually everything which does not go through `Global`. + * But now we need to somehow mark the main thread. That is done by `Parallel.asMainThread`.
+ */ + def compileUnits(units: List[CompilationUnit], fromPhase: Phase = firstPhase): Unit = Parallel.asMainThread { compileUnitsInternal(units, fromPhase) + } + private def compileUnitsInternal(units: List[CompilationUnit], fromPhase: Phase): Unit = { units foreach addUnit reporter.reset() diff --git a/src/compiler/scala/tools/nsc/ast/Positions.scala b/src/compiler/scala/tools/nsc/ast/Positions.scala index 36a9f371edcd..16f6cda0f76a 100644 --- a/src/compiler/scala/tools/nsc/ast/Positions.scala +++ b/src/compiler/scala/tools/nsc/ast/Positions.scala @@ -1,6 +1,8 @@ package scala.tools.nsc package ast +import scala.reflect.internal.util.Parallel.WorkerThreadLocal + trait Positions extends scala.reflect.internal.Positions { self: Global => @@ -24,7 +26,8 @@ trait Positions extends scala.reflect.internal.Positions { } } - override protected[this] lazy val posAssigner: PosAssigner = + override protected[this] final val _posAssigner: WorkerThreadLocal[PosAssigner] = WorkerThreadLocal { if (settings.Yrangepos && settings.debug || settings.Yposdebug) new ValidatingPosAssigner else new DefaultPosAssigner + } } diff --git a/src/compiler/scala/tools/nsc/profile/Profiler.scala b/src/compiler/scala/tools/nsc/profile/Profiler.scala index d0931071b3a1..8d3b10e781c3 100644 --- a/src/compiler/scala/tools/nsc/profile/Profiler.scala +++ b/src/compiler/scala/tools/nsc/profile/Profiler.scala @@ -107,18 +107,18 @@ private [profile] class RealProfiler(reporter : ProfileReporter, val settings: S private val mainThread = Thread.currentThread() - private[profile] def snapThread( idleTimeNanos:Long): ProfileSnap = { + private[profile] def snapThread(idleTimeNanos:Long, thread: Thread = Thread.currentThread()): ProfileSnap = { import RealProfiler._ - val current = Thread.currentThread() + val threadId = thread.getId ProfileSnap( - threadId = current.getId, - threadName = current.getName, + threadId = threadId, + threadName = thread.getName, snapTimeNanos = System.nanoTime(), idleTimeNanos = idleTimeNanos, - cpuTimeNanos = threadMx.getCurrentThreadCpuTime, - userTimeNanos = threadMx.getCurrentThreadUserTime, - allocatedBytes = threadMx.getThreadAllocatedBytes(Thread.currentThread().getId), + cpuTimeNanos = threadMx.getThreadCpuTime(threadId), + userTimeNanos = threadMx.getThreadUserTime(threadId), + allocatedBytes = threadMx.getThreadAllocatedBytes(threadId), heapBytes = readHeapUsage() ) } diff --git a/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala b/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala index 33d8cefde10b..ed1a907e3afd 100644 --- a/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala +++ b/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala @@ -35,7 +35,7 @@ object ThreadPoolFactory { protected def wrapWorker(worker: Runnable, shortId: String): Runnable = worker protected final class CommonThreadFactory( - shortId: String, + val shortId: String, daemon: Boolean = true, priority: Int) extends ThreadFactory { private val group: ThreadGroup = childGroup(shortId) @@ -73,31 +73,28 @@ object ThreadPoolFactory { override def newUnboundedQueueFixedThreadPool(nThreads: Int, shortId: String, priority: Int): ThreadPoolExecutor = { val threadFactory = new CommonThreadFactory(shortId, priority = priority) //like Executors.newFixedThreadPool - new SinglePhaseInstrumentedThreadPoolExecutor(nThreads, nThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable], threadFactory, new AbortPolicy) + new SinglePhaseInstrumentedThreadPoolExecutor(nThreads, nThreads, 0L, 
TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable], threadFactory, new AbortPolicy, baseGroup) } override def newBoundedQueueFixedThreadPool(nThreads: Int, maxQueueSize: Int, rejectHandler: RejectedExecutionHandler, shortId: String, priority: Int): ThreadPoolExecutor = { val threadFactory = new CommonThreadFactory(shortId, priority = priority) //like Executors.newFixedThreadPool - new SinglePhaseInstrumentedThreadPoolExecutor(nThreads, nThreads, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue[Runnable](maxQueueSize), threadFactory, rejectHandler) + new SinglePhaseInstrumentedThreadPoolExecutor(nThreads, nThreads, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue[Runnable](maxQueueSize), threadFactory, rejectHandler, baseGroup) } override protected def wrapWorker(worker: Runnable, shortId: String): Runnable = () => { val data = new ThreadProfileData - localData.set(data) - - val profileStart = profiler.snapThread(0) - try worker.run finally { - val snap = profiler.snapThread(data.idleNs) - val threadRange = ProfileRange(profileStart, snap, phase, shortId, data.taskCount, Thread.currentThread()) - profiler.completeBackground(threadRange) - } + data.profileStart = profiler.snapThread(0) + localData.put(Thread.currentThread().getId, data) + worker.run() } /** * data for thread run. Not threadsafe, only written from a single thread */ final class ThreadProfileData { + var profileStart: ProfileSnap = _ + var firstStartNs = 0L var taskCount = 0 @@ -108,15 +105,28 @@ object ThreadPoolFactory { var lastEndNs = 0L } - val localData = new ThreadLocal[ThreadProfileData] + val localData = new ConcurrentHashMap[Long, ThreadProfileData] private class SinglePhaseInstrumentedThreadPoolExecutor( corePoolSize: Int, maximumPoolSize: Int, keepAliveTime: Long, unit: TimeUnit, - workQueue: BlockingQueue[Runnable], threadFactory: ThreadFactory, handler: RejectedExecutionHandler) + workQueue: BlockingQueue[Runnable], threadFactory: CommonThreadFactory, handler: RejectedExecutionHandler, group: ThreadGroup) extends ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler) { + override def shutdown(): Unit = { + val arr = new Array[Thread](group.activeCount()) + group.enumerate(arr) + arr.foreach { thread => + val data = localData.get(thread.getId) + val snap = profiler.snapThread(data.idleNs, thread) + val threadRange = ProfileRange(data.profileStart, snap, phase, threadFactory.shortId, data.taskCount, thread) + profiler.completeBackground(threadRange) + } + + super.shutdown() + } + override def beforeExecute(t: Thread, r: Runnable): Unit = { - val data = localData.get + val data = localData.get(t.getId) data.taskCount += 1 val now = System.nanoTime() @@ -130,7 +140,7 @@ object ThreadPoolFactory { override def afterExecute(r: Runnable, t: Throwable): Unit = { val now = System.nanoTime() - val data = localData.get + val data = localData.get(Thread.currentThread().getId) data.lastEndNs = now data.runningNs += now - data.lastStartNs diff --git a/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala b/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala new file mode 100644 index 000000000000..37ae5d0ea184 --- /dev/null +++ b/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala @@ -0,0 +1,29 @@ +package scala.tools.nsc.reporters + +import scala.reflect.internal.util.Parallel.{assertOnMain, assertOnWorker} +import scala.reflect.internal.util.Position + +/* Simple Reporter which allows us to accumulate messages over time + * and then at 
a suitable time forward them to another reporter using the `flushTo` method + */ +final class BufferedReporter extends Reporter { + private[this] var buffered = List.empty[BufferedMessage] + + protected def info0(pos: Position, msg: String, severity: Severity, force: Boolean): Unit = { + assertOnWorker() + buffered = BufferedMessage(pos, msg, severity, force) :: buffered + severity.count += 1 + } + + def flushTo(reporter: Reporter): Unit = { + assertOnMain() + val sev = Array(reporter.INFO, reporter.WARNING, reporter.ERROR) + buffered.reverse.foreach { + msg => + reporter.info(msg.pos, msg.msg, sev(msg.severity.id), msg.force) + } + buffered = Nil + } + + private case class BufferedMessage(pos: Position, msg: String, severity: Severity, force: Boolean) +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/reporters/Reporter.scala b/src/compiler/scala/tools/nsc/reporters/Reporter.scala index 91a28f61f970..ef8cd74ab809 100644 --- a/src/compiler/scala/tools/nsc/reporters/Reporter.scala +++ b/src/compiler/scala/tools/nsc/reporters/Reporter.scala @@ -19,6 +19,12 @@ abstract class Reporter extends scala.reflect.internal.Reporter { /** Informational messages. If `!force`, they may be suppressed. */ final def info(pos: Position, msg: String, force: Boolean): Unit = info0(pos, msg, INFO, force) + /* Ugly hack, as we need access to the internal info0 method from another reporter. + * Unluckily, the base `Reporter` class is in a different package: `scala.reflect.internal`. + */ + protected[reporters] def info(pos: Position, msg: String, severity: Severity, force: Boolean): Unit = + info0(pos, msg, severity, force) + /** For sending a message which should not be labelled as a warning/error, * but also shouldn't require -verbose to be visible. */ diff --git a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala index 2c4ef991add3..76371588081e 100644 --- a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala +++ b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala @@ -143,7 +143,10 @@ trait ScalaSettings extends AbsScalaSettings val reporter = StringSetting ("-Xreporter", "classname", "Specify a custom reporter for compiler messages.", "scala.tools.nsc.reporters.ConsoleReporter") val source = ScalaVersionSetting ("-Xsource", "version", "Treat compiler input as Scala source for the specified version, see scala/bug#8126.", initial = ScalaVersion("2.13")) - val XnoPatmatAnalysis = BooleanSetting ("-Xno-patmat-analysis", "Don't perform exhaustivity/unreachability analysis. Also, ignore @switch annotation.") + val XnoPatmatAnalysis = BooleanSetting ("-Xno-patmat-analysis", "Don't perform exhaustivity/unreachability analysis. Also, ignore @switch annotation.") + + val YparallelPhases = PhasesSetting ("-Yparallel-phases", "Which phases to run in parallel.") + val YparallelThreads = IntSetting ("-Yparallel-threads", "Worker threads for parallel compilation.", 4, Some((0,64)), _ => None ) val XmixinForceForwarders = ChoiceSetting( name = "-Xmixin-force-forwarders", diff --git a/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala b/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala index b25119d6ba30..b943551a7272 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala @@ -87,17 +87,19 @@ trait Analyzer extends AnyRef // Lacking a better fix, we clear it here (before the phase is created, meaning for each // compiler run).
This is good enough for the resident compiler, which was the most affected. undoLog.clear() + + override def afterUnit(unit: CompilationUnit): Unit = undoLog.clear() + + override def run(): Unit = { val start = if (StatisticsStatics.areSomeColdStatsEnabled) statistics.startTimer(statistics.typerNanos) else null - global.echoPhaseSummary(this) - for (unit <- currentRun.units) { - applyPhase(unit) - undoLog.clear() - } + // We never want to completely override `run` without calling `super.run()` inside. + // `run` is now more complicated than a simple loop over the units, and there is no point in duplicating that logic. + super.run() // defensive measure in case the bookkeeping in deferred macro expansion is buggy clearDelayed() if (StatisticsStatics.areSomeColdStatsEnabled) statistics.stopTimer(statistics.typerNanos, start) } + def apply(unit: CompilationUnit): Unit = { try { val typer = newTyper(rootContext(unit)) diff --git a/src/reflect/scala/reflect/api/Trees.scala b/src/reflect/scala/reflect/api/Trees.scala index 355cc65b118c..4734a2e52c92 100644 --- a/src/reflect/scala/reflect/api/Trees.scala +++ b/src/reflect/scala/reflect/api/Trees.scala @@ -6,6 +6,8 @@ package scala package reflect package api +import scala.reflect.internal.util.Parallel.WorkerThreadLocal + /** * EXPERIMENTAL * @@ -2463,7 +2465,13 @@ trait Trees { self: Universe => * @group Traversal */ class Traverser { - protected[scala] var currentOwner: Symbol = rootMirror.RootClass + /** Access from multiple threads was reported by umad. + * That could possibly be solved by ensuring that every unit operates on its own copy of the tree, + * but it would require much bigger refactoring and would consume more memory. + */ + @inline final protected[scala] def currentOwner: Symbol = _currentOwner.get + @inline final protected[scala] def currentOwner_=(sym: Symbol): Unit = _currentOwner.set(sym) + private final val _currentOwner: WorkerThreadLocal[Symbol] = WorkerThreadLocal(rootMirror.RootClass) /** Traverse something which Trees contain, but which isn't a Tree itself. */ def traverseName(name: Name): Unit = () @@ -2534,8 +2542,15 @@ trait Trees { self: Universe => /** The underlying tree copier. */ val treeCopy: TreeCopier = newLazyTreeCopier - /** The current owner symbol. */ - protected[scala] var currentOwner: Symbol = rootMirror.RootClass + /** The current owner symbol. + * + * Access from multiple threads was reported by umad. + * That could possibly be solved by ensuring that every unit operates on its own copy of the tree, + * but it would require much bigger refactoring and would consume more memory. + */ + @inline protected[scala] final def currentOwner: Symbol = _currentOwner.get + @inline protected[scala] final def currentOwner_=(sym: Symbol): Unit = _currentOwner.set(sym) + private final val _currentOwner: WorkerThreadLocal[Symbol] = WorkerThreadLocal(rootMirror.RootClass) /** The enclosing method of the currently transformed tree.
*/ protected def currentMethod = { diff --git a/src/reflect/scala/reflect/internal/Positions.scala b/src/reflect/scala/reflect/internal/Positions.scala index 66a3d72796a9..8ea16fddb256 100644 --- a/src/reflect/scala/reflect/internal/Positions.scala +++ b/src/reflect/scala/reflect/internal/Positions.scala @@ -5,6 +5,7 @@ package internal import scala.collection.mutable import util._ import scala.collection.mutable.ListBuffer +import scala.reflect.internal.util.Parallel.WorkerThreadLocal /** Handling range positions * atPos, the main method in this trait, will add positions to a tree, @@ -280,7 +281,10 @@ trait Positions extends api.Positions { self: SymbolTable => trait PosAssigner extends InternalTraverser { var pos: Position } - protected[this] lazy val posAssigner: PosAssigner = new DefaultPosAssigner + + // Reported by umad + protected[this] val _posAssigner: WorkerThreadLocal[PosAssigner] = WorkerThreadLocal(new DefaultPosAssigner) + @inline protected[this] final def posAssigner: PosAssigner = _posAssigner.get protected class DefaultPosAssigner extends PosAssigner { var pos: Position = _ diff --git a/src/reflect/scala/reflect/internal/SymbolTable.scala b/src/reflect/scala/reflect/internal/SymbolTable.scala index f667005320e8..af59c1fedff8 100644 --- a/src/reflect/scala/reflect/internal/SymbolTable.scala +++ b/src/reflect/scala/reflect/internal/SymbolTable.scala @@ -13,6 +13,7 @@ import util._ import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer +import scala.reflect.internal.util.Parallel.synchronizeAccess import scala.reflect.internal.{TreeGen => InternalTreeGen} abstract class SymbolTable extends macros.Universe @@ -53,6 +54,12 @@ abstract class SymbolTable extends macros.Universe val gen = new InternalTreeGen { val global: SymbolTable.this.type = SymbolTable.this } + // Wrapper for the `synchronized` method. In the future it could provide additional logging, safety checks, etc. + // We are locking on the `synchronizeSymbolsAccess` object, which is created per `SymbolTable` instance. + object synchronizeSymbolsAccess { + def apply[T](block: => T): T = synchronizeAccess(this)(block) + } + trait ReflectStats extends BaseTypeSeqsStats with TypesStats with SymbolTableStats diff --git a/src/reflect/scala/reflect/internal/Symbols.scala b/src/reflect/scala/reflect/internal/Symbols.scala index efaaaf107331..574e167f7df5 100644 --- a/src/reflect/scala/reflect/internal/Symbols.scala +++ b/src/reflect/scala/reflect/internal/Symbols.scala @@ -3116,13 +3116,19 @@ trait Symbols extends api.Symbols { self: SymbolTable => * type arguments. */ override def tpe_* : Type = { - maybeUpdateTypeCache() - tpeCache + // We are simply locking all completers on the current `SymbolTable` for now. + // We should see whether that will be efficient enough.
+ synchronizeSymbolsAccess { + maybeUpdateTypeCache() + tpeCache + } } override def typeConstructor: Type = { - if (tyconCacheNeedsUpdate) - setTyconCache(newTypeRef(Nil)) - tyconCache + synchronizeSymbolsAccess { + if (tyconCacheNeedsUpdate) + setTyconCache(newTypeRef(Nil)) + tyconCache + } } override def tpeHK: Type = typeConstructor @@ -3141,6 +3147,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => updateTypeCache() // perform the actual update } } + private def updateTypeCache() { if (tpeCache eq NoType) throw CyclicReference(this, typeConstructor) diff --git a/src/reflect/scala/reflect/internal/Trees.scala b/src/reflect/scala/reflect/internal/Trees.scala index da8abfb1b593..9f0b551fc709 100644 --- a/src/reflect/scala/reflect/internal/Trees.scala +++ b/src/reflect/scala/reflect/internal/Trees.scala @@ -9,13 +9,14 @@ package internal import Flags._ import scala.collection.mutable +import scala.reflect.internal.util.Parallel.Counter import scala.reflect.macros.Attachments import util.{Statistics, StatisticsStatics} trait Trees extends api.Trees { self: SymbolTable => - private[scala] var nodeCount = 0 + private[scala] final val nodeCount: Counter = new Counter protected def treeLine(t: Tree): String = if (t.pos.isDefined && t.pos.isRange) t.pos.lineContent.drop(t.pos.column - 1).take(t.pos.end - t.pos.start + 1) @@ -35,8 +36,7 @@ trait Trees extends api.Trees { } abstract class Tree extends TreeContextApiImpl with Attachable with Product { - val id = nodeCount // TODO: add to attachment? - nodeCount += 1 + val id = nodeCount.getAndIncrement() // TODO: add to attachment? if (StatisticsStatics.areSomeHotStatsEnabled()) statistics.incCounter(statistics.nodeByType, getClass) diff --git a/src/reflect/scala/reflect/internal/util/Parallel.scala b/src/reflect/scala/reflect/internal/util/Parallel.scala new file mode 100644 index 000000000000..57833f2aa03c --- /dev/null +++ b/src/reflect/scala/reflect/internal/util/Parallel.scala @@ -0,0 +1,101 @@ +package scala.reflect.internal.util + +import java.util.concurrent.atomic.AtomicInteger + +object Parallel { + + class Counter { + private val count = new AtomicInteger + + @inline final def get: Int = count.get() + + @inline final def reset(): Unit = { + assertOnMain() + count.set(0) + } + + @inline final def incrementAndGet(): Int = count.incrementAndGet + + @inline final def getAndIncrement(): Int = count.getAndIncrement + + @inline final override def toString: String = s"Counter[$count]" + } + + // Wrapper for the `synchronized` method. In the future it could provide additional logging, safety checks, etc. + def synchronizeAccess[T <: Object, U](obj: T)(block: => U): U = { + obj.synchronized[U](block) + } + + def WorkerThreadLocal[T <: AnyRef](valueOnWorker: => T, valueOnMain: => T) = new WorkerOrMainThreadLocal[T](valueOnWorker, valueOnMain) + + def WorkerThreadLocal[T <: AnyRef](valueOnWorker: => T) = new WorkerThreadLocal[T](valueOnWorker) + + // `WorkerOrMainThreadLocal` allows us to have different values on the main thread and on worker threads.
// It's useful in cases like the reporter, where on workers we just want to store messages and on the main thread we want to print them. + class WorkerOrMainThreadLocal[T](valueOnWorker: => T, valueOnMain: => T) { + + private var main: T = null.asInstanceOf[T] + + private val worker: ThreadLocal[T] = new ThreadLocal[T] { + override def initialValue(): T = valueOnWorker + } + + @inline final def get: T = { + if (isWorker.get()) worker.get() + else { + if (main == null) main = valueOnMain + main + } + } + + @inline final def set(value: T): Unit = if (isWorker.get()) worker.set(value) else main = value + + @inline final def reset(): Unit = { + worker.remove() + main = valueOnMain + } + } + + // `WorkerThreadLocal` detects reads/writes of the given value on the main thread + // and reports such violations by throwing an exception. + class WorkerThreadLocal[T](valueOnWorker: => T) + extends WorkerOrMainThreadLocal(valueOnWorker, throw new IllegalStateException("not allowed on main thread")) + + // Asserts that the current execution happens on the main thread + @inline final def assertOnMain(): Unit = { + if (ParallelSettings.areAssertionsEnabled) assert(!isWorker.get()) + } + + // Asserts that the current execution happens on a worker thread + @inline final def assertOnWorker(): Unit = { + if (ParallelSettings.areAssertionsEnabled) assert(isWorker.get()) + } + + // Runs a block of code in 'worker thread' mode. + // All unit processing should always happen in a worker thread. + @inline final def asWorkerThread[T](fn: => T): T = { + val previous = isWorker.get() + isWorker.set(true) + try fn finally isWorker.set(previous) + } + + // Runs a block of code in 'main thread' mode. + // In 'main' mode we mostly set/reset global variables, initialize contexts, + // and orchestrate the processing of phases/units. + @inline final def asMainThread[T](fn: => T): T = { + val previous = isWorker.get() + isWorker.set(false) + try fn finally isWorker.set(previous) + } + + // ThreadLocal variable which allows us to mark the current thread as main or worker. + // This is important because the real main thread is not necessarily always running 'main' code. + // A good example is tests, which all run in one main thread, although they often process units + // (which conceptually should always happen in workers). + // Because there are many more entry points to unit processing than to Global, + // it's much easier to start by assuming every thread is initially a worker + // and just mark the main thread accordingly when needed. + private val isWorker: ThreadLocal[Boolean] = new ThreadLocal[Boolean] { + override def initialValue(): Boolean = true + } +} \ No newline at end of file diff --git a/src/reflect/scala/reflect/internal/util/ParallelSettings.java b/src/reflect/scala/reflect/internal/util/ParallelSettings.java new file mode 100644 index 000000000000..d8a407ea6119 --- --- a/dev/null +++ b/src/reflect/scala/reflect/internal/util/ParallelSettings.java @@ -0,0 +1,5 @@ +package scala.reflect.internal.util; + +public class ParallelSettings { + final static boolean areAssertionsEnabled = Boolean.valueOf(System.getProperty("parallel-assertions", "true")); +}
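Illustrative usage sketch (reviewer note, not part of the patch): a minimal example of how the `Parallel` thread-marking helpers above are expected to behave, assuming the semantics implemented in `Parallel.scala`. Every thread starts out marked as a worker; `asMainThread` and `asWorkerThread` flip that flag for the duration of a block, and the two-argument `WorkerThreadLocal` keeps separate values for the two modes. The object and value names below are hypothetical.

import scala.reflect.internal.util.Parallel._

object ParallelSketch {
  // Workers each get their own lazily initialized "worker" value; the main thread sees "main".
  private val label = WorkerThreadLocal("worker", "main")

  def main(args: Array[String]): Unit = asMainThread {
    println(label.get)   // prints "main": this block is explicitly marked as running on the main thread
    asWorkerThread {
      println(label.get) // prints "worker": unit-processing code is expected to run in worker mode
    }
  }
}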