Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

Commit

Permalink
Launch multiple threads for worker processes
Browse files Browse the repository at this point in the history
  • Loading branch information
borkaehw authored and JaredNeil committed May 29, 2020
1 parent f24d0d6 commit b872dac
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 47 deletions.
2 changes: 1 addition & 1 deletion rules/private/phases/phase_zinc_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def phase_zinc_compile(ctx, g):
outputs = outputs,
executable = worker.files_to_run.executable,
input_manifests = input_manifests,
execution_requirements = _resolve_execution_reqs(ctx, {"no-sandbox": "1", "supports-workers": "1"}),
execution_requirements = _resolve_execution_reqs(ctx, {"no-sandbox": "1", "supports-multiplex-workers": "1"}),
arguments = [args],
)

Expand Down
4 changes: 2 additions & 2 deletions rules/scala_proto/private/ScalaProtoWorker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package annex.scala.proto

import higherkindness.rules_scala.common.args.implicits._
import higherkindness.rules_scala.common.worker.WorkerMain
import java.io.File
import java.io.{File, PrintStream}
import java.nio.file.{Files, Paths}
import java.util.Collections
import net.sourceforge.argparse4j.ArgumentParsers
Expand Down Expand Up @@ -33,7 +33,7 @@ object ScalaProtoWorker extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

protected[this] def work(ctx: Unit, args: Array[String]): Unit = {
protected[this] def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = {
val namespace = argParser.parseArgs(args)
val sources = namespace.getList[File]("sources").asScala.toList

Expand Down
2 changes: 1 addition & 1 deletion rules/scalafmt/private/test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def build_format(ctx):
input_manifests = runner_manifests,
inputs = [ctx.file.config, src],
tools = runner_inputs,
execution_requirements = _resolve_execution_reqs(ctx, {"supports-workers": "1"}),
execution_requirements = _resolve_execution_reqs(ctx, {"supports-multiplex-workers": "1"}),
mnemonic = "ScalaFmt",
)
manifest_content.append("{} {}".format(src.short_path, file.short_path))
Expand Down
4 changes: 2 additions & 2 deletions rules/scalafmt/scalafmt/ScalafmtRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package annex.scalafmt

import higherkindness.rules_scala.common.worker.WorkerMain
import higherkindness.rules_scala.workers.common.Color
import java.io.File
import java.io.{File, PrintStream}
import java.nio.file.Files
import net.sourceforge.argparse4j.ArgumentParsers
import net.sourceforge.argparse4j.impl.Arguments
Expand All @@ -16,7 +16,7 @@ object ScalafmtRunner extends WorkerMain[Unit] {

protected[this] def init(args: Option[Array[String]]): Unit = {}

protected[this] def work(worker: Unit, args: Array[String]): Unit = {
protected[this] def work(worker: Unit, args: Array[String], out: PrintStream): Unit = {

val parser = ArgumentParsers.newFor("scalafmt").addHelp(true).defaultFormatWidth(80).fromFilePrefix("@").build
parser.addArgument("--config").required(true).`type`(Arguments.fileType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@ import java.io.InputStream
import java.io.PrintStream
import java.lang.SecurityManager
import java.security.Permission
import java.util.concurrent.Executors
import scala.annotation.tailrec
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.util.control.NonFatal
import scala.util.Success
import scala.util.Failure

trait WorkerMain[S] {

private[this] final case class ExitTrapped(code: Int) extends Throwable

protected[this] def init(args: Option[Array[String]]): S

protected[this] def work(ctx: S, args: Array[String]): Unit
protected[this] def work(ctx: S, args: Array[String], out: PrintStream): Unit

final def main(args: Array[String]): Unit = {
args.toList match {
Expand All @@ -36,39 +41,62 @@ trait WorkerMain[S] {
}
})

val outStream = new ByteArrayOutputStream
val out = new PrintStream(outStream)
val garbageOut = new PrintStream(new ByteArrayOutputStream)

System.setIn(new ByteArrayInputStream(Array.emptyByteArray))
System.setOut(out)
System.setErr(out)
System.setOut(garbageOut)
System.setErr(garbageOut)

try {
@tailrec
def process(ctx: S): S = {
val request = WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin)
val args = request.getArgumentsList.toArray(Array.empty[String])

val code =
val outStream = new ByteArrayOutputStream
val out = new PrintStream(outStream)
val requestId = request.getRequestId()

val f: Future[Int] = Future {
try {
work(ctx, args)
work(ctx, args, out)
0
} catch {
case ExitTrapped(code) => code
case NonFatal(e) =>
e.printStackTrace()
e.printStackTrace(out)
1
}
}

WorkerProtocol.WorkResponse.newBuilder
.setOutput(outStream.toString)
.setExitCode(code)
.build
.writeDelimitedTo(stdout)

out.flush()
outStream.reset()

f.onComplete {
case Success(code) =>
synchronized {

WorkerProtocol.WorkResponse.newBuilder
.setRequestId(requestId)
.setOutput(outStream.toString)
.setExitCode(code)
.build
.writeDelimitedTo(stdout)

out.flush()
outStream.reset()
}
case Failure(e) => {
e.printStackTrace(out)

WorkerProtocol.WorkResponse.newBuilder
.setRequestId(requestId)
.setOutput(outStream.toString)
.setExitCode(-1)
.build
.writeDelimitedTo(stdout)

out.flush()
outStream.reset()
}
}
process(ctx)
}
process(init(Some(args.toArray)))
Expand All @@ -78,8 +106,11 @@ trait WorkerMain[S] {
System.setErr(stderr)
}

case args => work(init(None), args.toArray)
case args => {
val outStream = new ByteArrayOutputStream
val out = new PrintStream(outStream)
work(init(None), args.toArray, out)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ package workers.bloop.compile
import common.worker.WorkerMain

import bloop.Bloop
import java.io.PrintStream

object BloopRunner extends WorkerMain[Unit] {
override def init(args: Option[Array[String]]): Unit = ()
override def work(ctx: Unit, args: Array[String]): Unit = Bloop
override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = Bloop
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,43 @@ package workers.common

import xsbti.Logger

import java.io.PrintWriter
import java.io.StringWriter
import java.io.{PrintStream, PrintWriter, StringWriter}
import java.nio.file.Paths
import java.util.function.Supplier

import CommonArguments.LogLevel

final class AnnexLogger(level: String) extends Logger {
final class AnnexLogger(level: String, out: PrintStream = System.err) extends Logger {

private[this] val root = s"${Paths.get("").toAbsolutePath}/"

private[this] def format(value: String): String = value.replace(root, "")

def debug(msg: Supplier[String]): Unit = level match {
case LogLevel.Debug => System.err.println(format(msg.get))
case LogLevel.Debug => out.println(format(msg.get))
case _ =>
}

def error(msg: Supplier[String]): Unit = level match {
case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn => System.err.println(format(msg.get))
case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn => out.println(format(msg.get))
case _ =>
}

def info(msg: Supplier[String]): Unit = level match {
case LogLevel.Debug | LogLevel.Info => System.err.println(format(msg.get))
case LogLevel.Debug | LogLevel.Info => out.println(format(msg.get))
case _ =>
}

def trace(err: Supplier[Throwable]): Unit = level match {
case LogLevel.Debug | LogLevel.Error | LogLevel.Info | LogLevel.Warn =>
val trace = new StringWriter();
err.get.printStackTrace(new PrintWriter(trace));
println(format(trace.toString))
out.println(format(trace.toString))
case _ =>
}

def warn(msg: Supplier[String]): Unit = level match {
case LogLevel.Debug | LogLevel.Info | LogLevel.Warn => System.err.println(format(msg.get))
case LogLevel.Debug | LogLevel.Info | LogLevel.Warn => out.println(format(msg.get))
case _ =>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package workers.deps
import common.args.implicits._
import common.worker.WorkerMain

import java.io.File
import java.io.{File, PrintStream}
import java.nio.file.{FileAlreadyExistsException, Files}
import java.util.Collections
import net.sourceforge.argparse4j.ArgumentParsers
Expand Down Expand Up @@ -48,7 +48,7 @@ object DepsRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String]): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = {
val namespace = argParser.parseArgs(args)

val label = namespace.getString("label").tail
Expand All @@ -64,9 +64,9 @@ object DepsRunner extends WorkerMain[Unit] {
(directLabels -- usedWhitelist).filterNot(labelToPaths(_).exists(usedPaths))
} else Nil
remove.foreach { depLabel =>
println(s"Target '$depLabel' not used, please remove it from the deps.")
println(s"You can use the following buildozer command:")
println(s"buildozer 'remove deps $depLabel' $label")
out.println(s"Target '$depLabel' not used, please remove it from the deps.")
out.println(s"You can use the following buildozer command:")
out.println(s"buildozer 'remove deps $depLabel' $label")
}

val add = if (namespace.getBoolean("check_direct") == true) {
Expand All @@ -80,9 +80,9 @@ object DepsRunner extends WorkerMain[Unit] {
)
} else Nil
add.foreach { depLabel =>
println(s"Target '$depLabel' is used but isn't explicitly declared, please add it to the deps.")
println(s"You can use the following buildozer command:")
println(s"buildozer 'add deps $depLabel' $label")
out.println(s"Target '$depLabel' is used but isn't explicitly declared, please add it to the deps.")
out.println(s"You can use the following buildozer command:")
out.println(s"buildozer 'add deps $depLabel' $label")
}

if (add.isEmpty && remove.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import workers.common.FileUtil
import workers.common.LoggedReporter
import common.worker.WorkerMain
import com.google.devtools.build.buildjar.jarhelper.JarCreator
import java.io.{File, PrintWriter}
import java.io.{File, PrintStream, PrintWriter}
import java.net.URLClassLoader
import java.nio.file.{Files, NoSuchFileException, Path, Paths}
import java.text.SimpleDateFormat
Expand Down Expand Up @@ -80,7 +80,7 @@ object ZincRunner extends WorkerMain[Namespace] {
Paths.get(dir.replace("~", sys.props.getOrElse("user.home", "")))
}

protected[this] def work(worker: Namespace, args: Array[String]) = {
protected[this] def work(worker: Namespace, args: Array[String], out: PrintStream) = {
val usePersistence: Boolean = worker.getBoolean("use_persistence") match {
case p: java.lang.Boolean => p
case _ => true
Expand All @@ -94,7 +94,7 @@ object ZincRunner extends WorkerMain[Namespace] {

val depsCache = pathFrom(worker, "extracted_file_cache")

val logger = new AnnexLogger(namespace.getString("log_level"))
val logger = new AnnexLogger(namespace.getString("log_level"), out)

val tmpDir = namespace.get[File]("tmp").toPath

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import workers.common.AnnexScalaInstance
import workers.common.CommonArguments.LogLevel
import workers.common.FileUtil

import java.io.File
import java.io.{File, PrintStream}
import java.net.URLClassLoader
import java.nio.file.{Files, NoSuchFileException}
import java.util.{Collections, Optional, Properties}
Expand Down Expand Up @@ -85,7 +85,7 @@ object DocRunner extends WorkerMain[Unit] {

override def init(args: Option[Array[String]]): Unit = ()

override def work(ctx: Unit, args: Array[String]): Unit = {
override def work(ctx: Unit, args: Array[String], out: PrintStream): Unit = {
val namespace = parser.parseArgsOrFail(args)

val tmpDir = namespace.get[File]("tmp").toPath
Expand Down
4 changes: 4 additions & 0 deletions third_party/bazel/src/main/protobuf/worker_protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ message WorkRequest {
// The inputs that the worker is allowed to read during execution of this
// request.
repeated Input inputs = 2;

int32 request_id = 3;
}

// The worker sends this message to Blaze when it finished its work on the WorkRequest message.
Expand All @@ -48,4 +50,6 @@ message WorkResponse {
// compiler warnings / errors etc. - thus we'll use a string type here, which gives us UTF-8
// encoding.
string output = 2;

int32 request_id = 3;
}

0 comments on commit b872dac

Please sign in to comment.