From b872dac59ef52a85ad05a9eb8f0bbddfdccf60a8 Mon Sep 17 00:00:00 2001 From: Bor Kae Hwang Date: Thu, 11 Oct 2018 15:21:05 -0600 Subject: [PATCH] Launch multiple threads for worker processes --- rules/private/phases/phase_zinc_compile.bzl | 2 +- .../private/ScalaProtoWorker.scala | 4 +- rules/scalafmt/private/test.bzl | 2 +- rules/scalafmt/scalafmt/ScalafmtRunner.scala | 4 +- .../common/worker/WorkerMain.scala | 69 ++++++++++++++----- .../workers/bloop/compile/BloopRunner.scala | 3 +- .../workers/common/AnnexLogger.scala | 15 ++-- .../rules_scala/workers/deps/DepsRunner.scala | 16 ++--- .../workers/zinc/compile/ZincRunner.scala | 6 +- .../workers/zinc/doc/DocRunner.scala | 4 +- .../src/main/protobuf/worker_protocol.proto | 4 ++ 11 files changed, 82 insertions(+), 47 deletions(-) diff --git a/rules/private/phases/phase_zinc_compile.bzl b/rules/private/phases/phase_zinc_compile.bzl index 8024fa34..1c5f709b 100644 --- a/rules/private/phases/phase_zinc_compile.bzl +++ b/rules/private/phases/phase_zinc_compile.bzl @@ -88,7 +88,7 @@ def phase_zinc_compile(ctx, g): outputs = outputs, executable = worker.files_to_run.executable, input_manifests = input_manifests, - execution_requirements = _resolve_execution_reqs(ctx, {"no-sandbox": "1", "supports-workers": "1"}), + execution_requirements = _resolve_execution_reqs(ctx, {"no-sandbox": "1", "supports-multiplex-workers": "1"}), arguments = [args], ) diff --git a/rules/scala_proto/private/ScalaProtoWorker.scala b/rules/scala_proto/private/ScalaProtoWorker.scala index 195e0f9a..07a261cc 100644 --- a/rules/scala_proto/private/ScalaProtoWorker.scala +++ b/rules/scala_proto/private/ScalaProtoWorker.scala @@ -2,7 +2,7 @@ package annex.scala.proto import higherkindness.rules_scala.common.args.implicits._ import higherkindness.rules_scala.common.worker.WorkerMain -import java.io.File +import java.io.{File, PrintStream} import java.nio.file.{Files, Paths} import java.util.Collections import net.sourceforge.argparse4j.ArgumentParsers @@ -33,7 +33,7 @@ object ScalaProtoWorker extends WorkerMain[Unit] { override def init(args: Option[Array[String]]): Unit = () - protected[this] def work(ctx: Unit, args: Array[String]): Unit = { + protected[this] def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = { val namespace = argParser.parseArgs(args) val sources = namespace.getList[File]("sources").asScala.toList diff --git a/rules/scalafmt/private/test.bzl b/rules/scalafmt/private/test.bzl index 684f9cb5..507d1a5f 100644 --- a/rules/scalafmt/private/test.bzl +++ b/rules/scalafmt/private/test.bzl @@ -52,7 +52,7 @@ def build_format(ctx): input_manifests = runner_manifests, inputs = [ctx.file.config, src], tools = runner_inputs, - execution_requirements = _resolve_execution_reqs(ctx, {"supports-workers": "1"}), + execution_requirements = _resolve_execution_reqs(ctx, {"supports-multiplex-workers": "1"}), mnemonic = "ScalaFmt", ) manifest_content.append("{} {}".format(src.short_path, file.short_path)) diff --git a/rules/scalafmt/scalafmt/ScalafmtRunner.scala b/rules/scalafmt/scalafmt/ScalafmtRunner.scala index ac33064b..a03c3363 100644 --- a/rules/scalafmt/scalafmt/ScalafmtRunner.scala +++ b/rules/scalafmt/scalafmt/ScalafmtRunner.scala @@ -2,7 +2,7 @@ package annex.scalafmt import higherkindness.rules_scala.common.worker.WorkerMain import higherkindness.rules_scala.workers.common.Color -import java.io.File +import java.io.{File, PrintStream} import java.nio.file.Files import net.sourceforge.argparse4j.ArgumentParsers import net.sourceforge.argparse4j.impl.Arguments @@ -16,7 +16,7 @@ object ScalafmtRunner extends WorkerMain[Unit] { protected[this] def init(args: Option[Array[String]]): Unit = {} - protected[this] def work(worker: Unit, args: Array[String]): Unit = { + protected[this] def work(worker: Unit, args: Array[String], out: PrintStream): Unit = { val parser = ArgumentParsers.newFor("scalafmt").addHelp(true).defaultFormatWidth(80).fromFilePrefix("@").build parser.addArgument("--config").required(true).`type`(Arguments.fileType) diff --git a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala index 16a1973c..6e94547a 100644 --- a/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala +++ b/src/main/scala/higherkindness/rules_scala/common/worker/WorkerMain.scala @@ -8,8 +8,13 @@ import java.io.InputStream import java.io.PrintStream import java.lang.SecurityManager import java.security.Permission +import java.util.concurrent.Executors import scala.annotation.tailrec +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import scala.util.control.NonFatal +import scala.util.Success +import scala.util.Failure trait WorkerMain[S] { @@ -17,7 +22,7 @@ trait WorkerMain[S] { protected[this] def init(args: Option[Array[String]]): S - protected[this] def work(ctx: S, args: Array[String]): Unit + protected[this] def work(ctx: S, args: Array[String], out: PrintStream): Unit final def main(args: Array[String]): Unit = { args.toList match { @@ -36,12 +41,11 @@ trait WorkerMain[S] { } }) - val outStream = new ByteArrayOutputStream - val out = new PrintStream(outStream) + val garbageOut = new PrintStream(new ByteArrayOutputStream) System.setIn(new ByteArrayInputStream(Array.emptyByteArray)) - System.setOut(out) - System.setErr(out) + System.setOut(garbageOut) + System.setErr(garbageOut) try { @tailrec @@ -49,26 +53,50 @@ trait WorkerMain[S] { val request = WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin) val args = request.getArgumentsList.toArray(Array.empty[String]) - val code = + val outStream = new ByteArrayOutputStream + val out = new PrintStream(outStream) + val requestId = request.getRequestId() + + val f: Future[Int] = Future { try { - work(ctx, args) + work(ctx, args, out) 0 } catch { case ExitTrapped(code) => code case NonFatal(e) => - e.printStackTrace() + e.printStackTrace(out) 1 } + } - WorkerProtocol.WorkResponse.newBuilder - .setOutput(outStream.toString) - .setExitCode(code) - .build - .writeDelimitedTo(stdout) - - out.flush() - outStream.reset() - + f.onComplete { + case Success(code) => + synchronized { + + WorkerProtocol.WorkResponse.newBuilder + .setRequestId(requestId) + .setOutput(outStream.toString) + .setExitCode(code) + .build + .writeDelimitedTo(stdout) + + out.flush() + outStream.reset() + } + case Failure(e) => { + e.printStackTrace(out) + + WorkerProtocol.WorkResponse.newBuilder + .setRequestId(requestId) + .setOutput(outStream.toString) + .setExitCode(-1) + .build + .writeDelimitedTo(stdout) + + out.flush() + outStream.reset() + } + } process(ctx) } process(init(Some(args.toArray))) @@ -78,8 +106,11 @@ trait WorkerMain[S] { System.setErr(stderr) } - case args => work(init(None), args.toArray) + case args => { + val outStream = new ByteArrayOutputStream + val out = new PrintStream(outStream) + work(init(None), args.toArray, out) + } } } - } diff --git a/src/main/scala/higherkindness/rules_scala/workers/bloop/compile/BloopRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/bloop/compile/BloopRunner.scala index 2f43c935..2f82080b 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/bloop/compile/BloopRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/bloop/compile/BloopRunner.scala @@ -4,8 +4,9 @@ package workers.bloop.compile import common.worker.WorkerMain import bloop.Bloop +import java.io.PrintStream object BloopRunner extends WorkerMain[Unit] { override def init(args: Option[Array[String]]): Unit = () - override def work(ctx: Unit, args: Array[String]): Unit = Bloop + override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = Bloop } diff --git a/src/main/scala/higherkindness/rules_scala/workers/common/AnnexLogger.scala b/src/main/scala/higherkindness/rules_scala/workers/common/AnnexLogger.scala index 5143c174..c41d3f21 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/common/AnnexLogger.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/common/AnnexLogger.scala @@ -3,31 +3,30 @@ package workers.common import xsbti.Logger -import java.io.PrintWriter -import java.io.StringWriter +import java.io.{PrintStream, PrintWriter, StringWriter} import java.nio.file.Paths import java.util.function.Supplier import CommonArguments.LogLevel -final class AnnexLogger(level: String) extends Logger { +final class AnnexLogger(level: String, out: PrintStream = System.err) extends Logger { private[this] val root = s"${Paths.get("").toAbsolutePath}/" private[this] def format(value: String): String = value.replace(root, "") def debug(msg: Supplier[String]): Unit = level match { - case LogLevel.Debug => System.err.println(format(msg.get)) + case LogLevel.Debug => out.println(format(msg.get)) case _ => } def error(msg: Supplier[String]): Unit = level match { - case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn => System.err.println(format(msg.get)) + case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn => out.println(format(msg.get)) case _ => } def info(msg: Supplier[String]): Unit = level match { - case LogLevel.Debug | LogLevel.Info => System.err.println(format(msg.get)) + case LogLevel.Debug | LogLevel.Info => out.println(format(msg.get)) case _ => } @@ -35,12 +34,12 @@ final class AnnexLogger(level: String) extends Logger { case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn => val trace = new StringWriter(); err.get.printStackTrace(new PrintWriter(trace)); - println(format(trace.toString)) + out.println(format(trace.toString)) case _ => } def warn(msg: Supplier[String]): Unit = level match { - case LogLevel.Debug | LogLevel.Info | LogLevel.Warn => System.err.println(format(msg.get)) + case LogLevel.Debug | LogLevel.Info | LogLevel.Warn => out.println(format(msg.get)) case _ => } diff --git a/src/main/scala/higherkindness/rules_scala/workers/deps/DepsRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/deps/DepsRunner.scala index 23b4319f..5f803e0a 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/deps/DepsRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/deps/DepsRunner.scala @@ -4,7 +4,7 @@ package workers.deps import common.args.implicits._ import common.worker.WorkerMain -import java.io.File +import java.io.{File, PrintStream} import java.nio.file.{FileAlreadyExistsException, Files} import java.util.Collections import net.sourceforge.argparse4j.ArgumentParsers @@ -48,7 +48,7 @@ object DepsRunner extends WorkerMain[Unit] { override def init(args: Option[Array[String]]): Unit = () - override def work(ctx: Unit, args: Array[String]): Unit = { + override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = { val namespace = argParser.parseArgs(args) val label = namespace.getString("label").tail @@ -64,9 +64,9 @@ object DepsRunner extends WorkerMain[Unit] { (directLabels -- usedWhitelist).filterNot(labelToPaths(_).exists(usedPaths)) } else Nil remove.foreach { depLabel => - println(s"Target '$depLabel' not used, please remove it from the deps.") - println(s"You can use the following buildozer command:") - println(s"buildozer 'remove deps $depLabel' $label") + out.println(s"Target '$depLabel' not used, please remove it from the deps.") + out.println(s"You can use the following buildozer command:") + out.println(s"buildozer 'remove deps $depLabel' $label") } val add = if (namespace.getBoolean("check_direct") == true) { @@ -80,9 +80,9 @@ object DepsRunner extends WorkerMain[Unit] { ) } else Nil add.foreach { depLabel => - println(s"Target '$depLabel' is used but isn't explicitly declared, please add it to the deps.") - println(s"You can use the following buildozer command:") - println(s"buildozer 'add deps $depLabel' $label") + out.println(s"Target '$depLabel' is used but isn't explicitly declared, please add it to the deps.") + out.println(s"You can use the following buildozer command:") + out.println(s"buildozer 'add deps $depLabel' $label") } if (add.isEmpty && remove.isEmpty) { diff --git a/src/main/scala/higherkindness/rules_scala/workers/zinc/compile/ZincRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/zinc/compile/ZincRunner.scala index 2f9f9d4c..9b5c87aa 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/zinc/compile/ZincRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/zinc/compile/ZincRunner.scala @@ -8,7 +8,7 @@ import workers.common.FileUtil import workers.common.LoggedReporter import common.worker.WorkerMain import com.google.devtools.build.buildjar.jarhelper.JarCreator -import java.io.{File, PrintWriter} +import java.io.{File, PrintStream, PrintWriter} import java.net.URLClassLoader import java.nio.file.{Files, NoSuchFileException, Path, Paths} import java.text.SimpleDateFormat @@ -80,7 +80,7 @@ object ZincRunner extends WorkerMain[Namespace] { Paths.get(dir.replace("~", sys.props.getOrElse("user.home", ""))) } - protected[this] def work(worker: Namespace, args: Array[String]) = { + protected[this] def work(worker: Namespace, args: Array[String], out: PrintStream) = { val usePersistence: Boolean = worker.getBoolean("use_persistence") match { case p: java.lang.Boolean => p case _ => true @@ -94,7 +94,7 @@ object ZincRunner extends WorkerMain[Namespace] { val depsCache = pathFrom(worker, "extracted_file_cache") - val logger = new AnnexLogger(namespace.getString("log_level")) + val logger = new AnnexLogger(namespace.getString("log_level"), out) val tmpDir = namespace.get[File]("tmp").toPath diff --git a/src/main/scala/higherkindness/rules_scala/workers/zinc/doc/DocRunner.scala b/src/main/scala/higherkindness/rules_scala/workers/zinc/doc/DocRunner.scala index d6b31e58..de6a95b9 100644 --- a/src/main/scala/higherkindness/rules_scala/workers/zinc/doc/DocRunner.scala +++ b/src/main/scala/higherkindness/rules_scala/workers/zinc/doc/DocRunner.scala @@ -8,7 +8,7 @@ import workers.common.AnnexScalaInstance import workers.common.CommonArguments.LogLevel import workers.common.FileUtil -import java.io.File +import java.io.{File, PrintStream} import java.net.URLClassLoader import java.nio.file.{Files, NoSuchFileException} import java.util.{Collections, Optional, Properties} @@ -85,7 +85,7 @@ object DocRunner extends WorkerMain[Unit] { override def init(args: Option[Array[String]]): Unit = () - override def work(ctx: Unit, args: Array[String]): Unit = { + override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = { val namespace = parser.parseArgsOrFail(args) val tmpDir = namespace.get[File]("tmp").toPath diff --git a/third_party/bazel/src/main/protobuf/worker_protocol.proto b/third_party/bazel/src/main/protobuf/worker_protocol.proto index 4706792f..a55ae858 100644 --- a/third_party/bazel/src/main/protobuf/worker_protocol.proto +++ b/third_party/bazel/src/main/protobuf/worker_protocol.proto @@ -38,6 +38,8 @@ message WorkRequest { // The inputs that the worker is allowed to read during execution of this // request. repeated Input inputs = 2; + + int32 request_id = 3; } // The worker sends this message to Blaze when it finished its work on the WorkRequest message. @@ -48,4 +50,6 @@ message WorkResponse { // compiler warnings / errors etc. - thus we'll use a string type here, which gives us UTF-8 // encoding. string output = 2; + + int32 request_id = 3; }