Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Zinc + Persistent Bazel Worker Processes #12

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions scala/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
load("@io_bazel_rules_scala//scala:scala.bzl", "scala_library")

java_binary(
name = "scala-worker",
main_class = "com.databricks.bazel.ScalaWorker",
visibility = ["//visibility:public"],
runtime_deps = [
":scala-worker-lib-import",
"@bazel_tools//src/main/protobuf:worker_protocol_proto",
"@zinc_0_3_10_SNAPSHOT_jar//jar",
"@scala_compiler_jar//jar",
"@incremental_compiler_0_13_9_jar//jar",
"@scala_library_jar//jar",
"@scala_reflect_jar//jar",
"@sbt_interface_0_13_9_jar//jar",
"@compiler_interface_0_13_9_sources_jar//jar",
"@nailgun_server_0_9_1_jar//jar",
],
)

java_import(
name = "scala-worker-lib-import",
jars = ["scala-worker-lib_deploy.jar"],
)

scala_library(
name = "scala-worker-lib",
srcs = glob(["ScalaWorker.scala"]),
deps = [
"@bazel_tools//src/main/protobuf:worker_protocol_proto",
"@zinc_0_3_10_SNAPSHOT_jar//jar",
"@scala_compiler_jar//jar",
"@incremental_compiler_0_13_9_jar//jar",
"@scala_library_jar//jar",
"@scala_reflect_jar//jar",
"@sbt_interface_0_13_9_jar//jar",
"@compiler_interface_0_13_9_sources_jar//jar",
"@nailgun_server_0_9_1_jar//jar",
],
)

220 changes: 220 additions & 0 deletions scala/ScalaWorker.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
package com.databricks.bazel

import java.nio.charset.StandardCharsets.UTF_8

import com.google.devtools.build.lib.worker.WorkerProtocol.Input
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse

import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
import java.io.PrintStream
import java.nio.file.Files
import java.nio.file.Paths
import java.net.ServerSocket
import java.util.ArrayList
import java.util.LinkedHashMap
import java.util.{List => JList}
import java.util.Map.Entry
import java.util.UUID

import scala.collection.JavaConverters._
import scala.sys.process._

import com.typesafe.zinc.{Main => ZincMain, Nailgun, ZincClient}


/**
* An example implementation of a worker process that is used for integration tests.
*/
object ScalaWorker {

// A UUID that uniquely identifies this running worker process.
private val workerUuid = UUID.randomUUID()

// A counter that increases with each work unit processed.
private var workUnitCounter = 1

// If true, returns corrupt responses instead of correct protobufs.
private var poisoned = false

// Keep state across multiple builds.
private val inputs = new LinkedHashMap[String, String]()

private var serverArgs = ""

private def getFreePort(): Int = {
val sock = new ServerSocket(0)
val port = sock.getLocalPort
sock.close()
port
}

private var zincClient: ZincClient = _

private var zincPort: Int = 0

private var nailgunProcess: Process = _

private def attachShutdownHook() {
Runtime.getRuntime().addShutdownHook(new Thread() {
override def run() {
if (nailgunProcess != null) {
nailgunProcess.destroy()
}
}
})
}

private val serverOutput = new StringBuilder()

private def startServer(classpath: String): Unit = {
attachShutdownHook()
zincPort = getFreePort()

val logger = new ProcessLogger {
def buffer[T](fn: => T): T = fn
def err(s: => String): Unit = serverOutput.append(s).append("\n")
def out(s: => String): Unit = serverOutput.append(s).append("\n")
}

// Options copied from Nailgun.scala in Zinc
val options = List("-cp", classpath, "-server", "-Xms1024m", "-Xmx3g", "-XX:MaxPermSize=384m",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can these be passed in somehow?

"-XX:ReservedCodeCacheSize=192m")
val cmd = "java" :: options ++ Seq(classOf[Nailgun].getName, s"$zincPort")
val builder = Process(cmd)
this.nailgunProcess = builder.run(logger)

serverArgs = cmd.mkString(" ")
zincClient = new ZincClient(port = zincPort)
}

private def awaitServer() {
var count = 0
while (!zincClient.serverAvailable && (count < 50)) {
try { Thread.sleep(100) } catch { case _: InterruptedException => }
count += 1
}
}

def main(args: Array[String]): Unit = {
if (args.contains("--persistent_worker")) {
startServer(args(0))
runPersistentWorker(args)
} else {
// This is a single invocation of the example that exits after it processed the request.
ZincMain.run(args, cwd = None)
}
}

private def listFiles(f: File): Seq[String] = {
val current = f.listFiles
val files = current.filter(_.isFile).map(_.getAbsolutePath)
val directories = current.filter(_.isDirectory)
files ++ directories.flatMap(listFiles)
}

// Extract a src jar to a temporary directory and return the list of extracted files
private def expandSrcJar(path: String): Seq[String] = {
val tempDir = Files.createTempDirectory(null).toFile
Seq("unzip", "-q", path, "-d", tempDir.getAbsolutePath).!!
listFiles(tempDir)
}

@throws[IOException]
private def runPersistentWorker(args: Array[String]) {
val originalStdOut = System.out
val originalStdErr = System.err

while (true) {
try {
val request = WorkRequest.parseDelimitedFrom(System.in)
if (request == null) {
return
}

inputs.clear()

for (input <- request.getInputsList().asScala) {
inputs.put(input.getPath(), input.getDigest().toStringUtf8())
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: idiomatic scala would be:

request.getInputsList.asScala.foreach { input =>
  inputs.put(input.getPath, input.getDigest.toStringUtf8)
}


val baos = new ByteArrayOutputStream()
var exitCode = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have val exitCode = try { ... below rather than a var? var adds cognitive load to reviewing, which IMO should only be used when necessary for perf. In this case val is just as good.


val ps = new PrintStream(baos)
try {
System.setOut(ps)
System.setErr(ps)

var clientArgs: Seq[String] = null

try {
clientArgs = request.getArgumentsList.asScala.flatMap { arg =>
// srcjars must be extracted before we can pass them to zinc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these aren't supported yet by the other rules.

if (arg.endsWith(".srcjar")) {
expandSrcJar(arg)
} else {
Seq(arg)
}
}
awaitServer()
exitCode = zincClient.run(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could just be the return value rather than changing the var.

args = clientArgs,
cwd = new File(System.getProperty("user.dir")),
out = ps,
err = ps
)
} catch {
case e: Exception =>
// We use System.out.println as not to accidentally write to real stdout
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow this comment. How is it not real stdout?

Can we not use stderr here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh,, I see, you change it above.

System.out.println("Startup Args:")
args.foreach(arg => System.out.println("Arg: " + arg))
System.out.println("Server args: " + serverArgs)
System.out.println("Server output: " + serverOutput.toString)
System.out.println("Unexpanded Client Args:")
request.getArgumentsList.asScala.foreach(arg => System.out.println("Arg: " + arg))
if (clientArgs != null) {
System.out.println("Expanded Client Args:")
clientArgs.foreach(arg => System.out.println("Arg: " + arg))
} else {
System.out.println("======== CLIENT ARG EXPANSION MAY HAVE FAILED =======")
}

e.printStackTrace()
exitCode = 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could just return this rather than using a var.

}
} finally {
System.setOut(originalStdOut)
System.setErr(originalStdErr)
}

if (poisoned) {
System.out.println("I'm a poisoned worker and this is not a protobuf.")
} else {
WorkResponse.newBuilder()
.setOutput(baos.toString())
.setExitCode(exitCode)
.build()
.writeDelimitedTo(System.out)
}
System.out.flush()

/*
if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
return
}

if (workerOptions.poisonAfter > 0 && workUnitCounter > workerOptions.poisonAfter) {
poisoned = true
}
*/
} finally {
// Be a good worker process and consume less memory when idle.
System.gc()
}
}
}
}

Loading