Skip to content

Commit

Permalink
Add support for wrapping system streams in WorkRequestHandler (#17583)
Browse files Browse the repository at this point in the history
There are often [places](https://github.com/bazelbuild/bazel/blob/ea19c17075478092eb77580e6d3825d480126d3a/src/tools/android/java/com/google/devtools/build/android/ResourceProcessorBusyBox.java#L188) where persistent workers need to swap out the standard system streams to avoid tools poisoning the worker communication streams by writing logs/exceptions to it.

This pull request extracts that pattern into an optional WorkerIO wrapper can be used to swap in and out the standard streams without the added boilerplate.

PiperOrigin-RevId: 498983983
Change-Id: Iefb956d38a5887d9e5bbf0821551eb0efa14fce9

(cherry picked from commit e05345d)

Co-authored-by: Benjamin Lee <[email protected]>
Co-authored-by: kshyanashree <[email protected]>
  • Loading branch information
3 people authored Feb 24, 2023
1 parent 7a41105 commit 10e6251
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.sun.management.OperatingSystemMXBean;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -317,8 +321,17 @@ public WorkRequestHandler build() {
* then writing the corresponding {@link WorkResponse} to {@code out}. If there is an error
* reading or writing the requests or responses, it writes an error message on {@code err} and
* returns. If {@code in} reaches EOF, it also returns.
*
* <p>This function also wraps the system streams in a {@link WorkerIO} instance that prevents the
* underlying tool from writing to {@link System#out} or reading from {@link System#in}, which
* would corrupt the worker worker protocol. When the while loop exits, the original system
* streams will be swapped back into {@link System}.
*/
public void processRequests() throws IOException {
// Wrap the system streams into a WorkerIO instance to prevent unexpected reads and writes on
// stdin/stdout.
WorkerIO workerIO = WorkerIO.capture();

try {
while (!shutdownWorker.get()) {
WorkRequest request = messageProcessor.readWorkRequest();
Expand All @@ -328,31 +341,39 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
}
} catch (IOException e) {
stderr.println("Error reading next WorkRequest: " + e);
e.printStackTrace(stderr);
}
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
} finally {
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
}
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
stderr.println(e.getMessage());
}
}
}

/** Starts a thread for the given request. */
void startResponseThread(WorkRequest request) {
void startResponseThread(WorkerIO workerIO, WorkRequest request) {
Thread currentThread = Thread.currentThread();
String threadName =
request.getRequestId() > 0
Expand Down Expand Up @@ -381,7 +402,7 @@ void startResponseThread(WorkRequest request) {
return;
}
try {
respondToRequest(request, requestInfo);
respondToRequest(workerIO, request, requestInfo);
} catch (IOException e) {
// IOExceptions here means a problem talking to the server, so we must shut down.
if (!shutdownWorker.compareAndSet(false, true)) {
Expand Down Expand Up @@ -419,7 +440,8 @@ void startResponseThread(WorkRequest request) {
* #callback} are reported with exit code 1.
*/
@VisibleForTesting
void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOException {
void respondToRequest(WorkerIO workerIO, WorkRequest request, RequestInfo requestInfo)
throws IOException {
int exitCode;
StringWriter sw = new StringWriter();
try (PrintWriter pw = new PrintWriter(sw)) {
Expand All @@ -431,6 +453,16 @@ void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOExc
e.printStackTrace(pw);
exitCode = 1;
}

try {
// Read out the captured string for the final WorkResponse output
String captured = workerIO.readCapturedAsUtf8String().trim();
if (!captured.isEmpty()) {
pw.write(captured);
}
} catch (IOException e) {
stderr.println(e.getMessage());
}
}
Optional<WorkResponse.Builder> optBuilder = requestInfo.takeBuilder();
if (optBuilder.isPresent()) {
Expand Down Expand Up @@ -541,4 +573,104 @@ private void maybePerformGc() {
}
}
}

/**
* Class that wraps the standard {@link System#in}, {@link System#out}, and {@link System#err}
* with our own ByteArrayOutputStream that allows {@link WorkRequestHandler} to safely capture
* outputs that can't be directly captured by the PrintStream associated with the work request.
*
* <p>This is most useful when integrating JVM tools that write exceptions and logs directly to
* {@link System#out} and {@link System#err}, which would corrupt the persistent worker protocol.
* We also redirect {@link System#in}, just in case a tool should attempt to read it.
*
* <p>WorkerIO implements {@link AutoCloseable} and will swap the original streams back into
* {@link System} once close has been called.
*/
public static class WorkerIO implements AutoCloseable {
private final InputStream originalInputStream;
private final PrintStream originalOutputStream;
private final PrintStream originalErrorStream;
private final ByteArrayOutputStream capturedStream;
private final AutoCloseable restore;

/**
* Creates a new {@link WorkerIO} that allows {@link WorkRequestHandler} to capture standard
* output and error streams that can't be directly captured by the PrintStream associated with
* the work request.
*/
@VisibleForTesting
WorkerIO(
InputStream originalInputStream,
PrintStream originalOutputStream,
PrintStream originalErrorStream,
ByteArrayOutputStream capturedStream,
AutoCloseable restore) {
this.originalInputStream = originalInputStream;
this.originalOutputStream = originalOutputStream;
this.originalErrorStream = originalErrorStream;
this.capturedStream = capturedStream;
this.restore = restore;
}

/** Wraps the standard System streams and WorkerIO instance */
public static WorkerIO capture() {
// Save the original streams
InputStream originalInputStream = System.in;
PrintStream originalOutputStream = System.out;
PrintStream originalErrorStream = System.err;

// Replace the original streams with our own instances
ByteArrayOutputStream capturedStream = new ByteArrayOutputStream();
PrintStream outputBuffer = new PrintStream(capturedStream, true);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]);
System.setIn(byteArrayInputStream);
System.setOut(outputBuffer);
System.setErr(outputBuffer);

return new WorkerIO(
originalInputStream,
originalOutputStream,
originalErrorStream,
capturedStream,
() -> {
System.setIn(originalInputStream);
System.setOut(originalOutputStream);
System.setErr(originalErrorStream);
outputBuffer.close();
byteArrayInputStream.close();
});
}

/** Returns the original input stream most commonly provided by {@link System#in} */
@VisibleForTesting
InputStream getOriginalInputStream() {
return originalInputStream;
}

/** Returns the original output stream most commonly provided by {@link System#out} */
@VisibleForTesting
PrintStream getOriginalOutputStream() {
return originalOutputStream;
}

/** Returns the original error stream most commonly provided by {@link System#err} */
@VisibleForTesting
PrintStream getOriginalErrorStream() {
return originalErrorStream;
}

/** Returns the captured outputs as a UTF-8 string */
@VisibleForTesting
String readCapturedAsUtf8String() throws IOException {
capturedStream.flush();
String captureOutput = capturedStream.toString(StandardCharsets.UTF_8);
capturedStream.reset();
return captureOutput;
}

@Override
public void close() throws Exception {
restore.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public final class ExampleWorker {
static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)");

// A UUID that uniquely identifies this running worker process.
static final UUID workerUuid = UUID.randomUUID();
static final UUID WORKER_UUID = UUID.randomUUID();

// A counter that increases with each work unit processed.
static int workUnitCounter = 1;
Expand Down Expand Up @@ -83,6 +83,9 @@ private static class InterruptableWorkRequestHandler extends WorkRequestHandler
@Override
@SuppressWarnings("SystemExitOutsideMain")
public void processRequests() throws IOException {
ByteArrayOutputStream captured = new ByteArrayOutputStream();
WorkerIO workerIO = new WorkerIO(System.in, System.out, System.err, captured, captured);

while (true) {
WorkRequest request = messageProcessor.readWorkRequest();
if (request == null) {
Expand All @@ -100,12 +103,19 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
System.exit(0);
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
workerIO.getOriginalErrorStream().println(e.getMessage());
}
}
}

Expand Down Expand Up @@ -241,7 +251,7 @@ private static void parseOptionsAndLog(List<String> args) throws Exception {
List<String> outputs = new ArrayList<>();

if (options.writeUUID) {
outputs.add("UUID " + workerUuid);
outputs.add("UUID " + WORKER_UUID);
}

if (options.writeCounter) {
Expand Down
Loading

0 comments on commit 10e6251

Please sign in to comment.