[SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel #20179

Closed

Changes from 3 commits
21 changes: 7 additions & 14 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv(

     val pipe = Pipe.open()
     val source = new FileDownloadChannel(pipe.source())
-    try {
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
       val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
       val callback = new FileDownloadCallback(pipe.sink(), source, client)
       client.stream(parsedUri.getPath(), callback)
-    } catch {
-      case e: Exception =>
-        pipe.sink().close()
-        source.close()
-        throw e
-    }
+    })(catchBlock = {
+      pipe.sink().close()
+      source.close()
+    })

     source
   }
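For context on the replacement: Utils.tryWithSafeFinallyAndFailureCallbacks is an existing Spark utility. A minimal sketch of its contract, simplified rather than Spark's actual implementation, looks like this:

    // Run `block`; if it throws, run `catchBlock`, and if the cleanup itself
    // also throws, attach that secondary failure as a suppressed exception
    // rather than letting it mask the original error as a bare try/catch can.
    def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(catchBlock: => Unit): T = {
      try {
        block
      } catch {
        case t: Throwable =>
          try {
            catchBlock
          } catch {
            case cleanupErr: Throwable => t.addSuppressed(cleanupErr)
          }
          throw t
      }
    }

In other words, a failure while closing the pipe in the cleanup path can no longer replace the original exception, which the old bare try/catch did not guarantee.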
@@ -376,18 +374,13 @@ private[netty] class NettyRpcEnv(

     def setError(e: Throwable): Unit = {
       error = e
-      source.close()
     }

     override def read(dst: ByteBuffer): Int = {
Contributor: So the caller of read would close the source channel?

Contributor Author: Yes. This currently happens in two places:

       Try(source.read(dst)) match {
+        case _ if error != null => throw error
Member: I think it is better to also add a short comment here. This bug is subtle and there is no test against it right now. Just from this code, it is hard to know why we check error even on success.

Contributor Author: I added a pair of comments to explain the flow of calls involving setError() and the pipe closes.

         case Success(bytesRead) => bytesRead
-        case Failure(readErr) =>
-          if (error != null) {
-            throw error
-          } else {
-            throw readErr
-          }
+        case Failure(readErr) => throw readErr
       }
     }
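This also answers the earlier question about who closes the channel: now that setError() no longer closes the pipe, the thread calling read() owns the close, typically in a finally block. A hypothetical caller sketch (the method name and buffer size are invented for illustration):

    import java.io.{File, FileOutputStream}
    import java.nio.ByteBuffer
    import java.nio.channels.ReadableByteChannel

    // Only the reader thread ever closes the source channel; the download
    // thread merely records failures via setError(), and read() rethrows them.
    def copyChannelToFile(source: ReadableByteChannel, dest: File): Unit = {
      val out = new FileOutputStream(dest).getChannel
      val buf = ByteBuffer.allocate(64 * 1024)
      try {
        while (source.read(buf) != -1) {  // throws if setError() was called
          buf.flip()
          while (buf.hasRemaining) {
            out.write(buf)
          }
          buf.clear()
        }
      } finally {
        out.close()
        source.close()  // single-threaded close: no race with the download thread
      }
    }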

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

@@ -18,8 +18,8 @@
 package org.apache.spark.shuffle

 import java.io._
-
-import com.google.common.io.ByteStreams
+import java.nio.channels.Channels
+import java.nio.file.Files

 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
     // find out the consolidated file, then the offset within that from our index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

-    val in = new DataInputStream(new FileInputStream(indexFile))
+    // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
+    // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
+    // (which may cause a reducer to be sent a different reducer's data). The explicit position
+    // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
+    // class of issue from re-occurring in the future which is why they are left here even though
+    // SPARK-22982 is fixed.
+    val channel = Files.newByteChannel(indexFile.toPath)
+    channel.position(blockId.reduceId * 8)
Contributor: cc @zsxwing I recall you mentioned a performance issue with skipping data in the file channel; do we have this problem here?

Contributor Author: I made sure to incorporate @zsxwing's changes here. The problem originally related to calling skip(), but this change follows his fix, which explicitly uses position() on a FileChannel instead.
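To illustrate the distinction (an illustrative sketch, not code from this patch): InputStream.skip() may skip fewer bytes than requested and may have to read and discard data, so callers must loop, whereas SeekableByteChannel.position() is a single seek that transfers no data.

    import java.io.{EOFException, FileInputStream}
    import java.nio.file.{Files, Paths}

    val reduceId = 3  // example reducer id

    // skip()-based seek: must loop because skip() may return a short count.
    val in = new FileInputStream("shuffle.index")
    var remaining = reduceId * 8L
    while (remaining > 0) {
      val skipped = in.skip(remaining)
      if (skipped <= 0) throw new EOFException("hit end of file while skipping")
      remaining -= skipped
    }

    // position()-based seek: one call, no bytes read or discarded.
    val channel = Files.newByteChannel(Paths.get("shuffle.index"))
    channel.position(reduceId * 8L)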

Contributor: Sorry, I'm not clear on how the change here is related to the "asynchronous close()" issue?

Contributor: It's there to detect bugs like the "asynchronous close()" one earlier, should they recur in the future.

Contributor: I see. Thanks!

Contributor Author: For some more background: the asynchronous close() bug can cause reads from a closed-and-subsequently-reassigned file descriptor number, and in principle this can affect almost any IO operation anywhere in the application. For example, if the closed file descriptor number is immediately recycled by opening a socket, then the invalid read can cause that socket read to miss data (since the data would have been consumed by the invalid reader and won't be delivered to the legitimate new user of the file descriptor).

Given this, I see how it might be puzzling that this patch is adding a check only here. There are two reasons for this:

  1. Many other IO operations have implicit checksumming, such that dropping data due to an invalid read will be detected and cause an exception. For example, many compression codecs have block-level checksumming (and magic numbers at the beginning of the stream), so dropping data (especially at the start of a read) will be detected. This particular shuffle index file, however, has no mechanism to detect corruption: skipping forward in the read by a multiple of 8 bytes will still read structurally valid data (but it will be the wrong data, causing the wrong output to be read from the shuffle data file).

  2. In the investigation which uncovered this bug, the invalid reads were predominantly impacting shuffle index lookups for reading local blocks. In a nutshell, there's a subtle race condition where Janino codegen compilation triggers attempted remote classloading of classes which don't exist, triggering the error-handling / error-propagation paths in FileDownloadChannel and causing the invalid asynchronous close() call to be performed. At the same time that this close() call was being performed, another task from the same stage attempted to read the shuffle index files of local blocks and experienced an invalid read due to the falsely-shared file descriptor.

     This is a very hard-to-trigger bug: we were only able to reproduce it on large clusters with very fast machines and shuffles that contain large numbers of map and reduce tasks (more shuffle blocks means more index file reads and more chances for the race to occur; faster machines increase the likelihood of the race occurring; larger clusters give us more chances for the error to occur). In our reproduction, this race occurred on a microsecond timescale (measured via kernel syscall tracing) and occurred relatively rarely, requiring many iterations until we could trigger a reproduction.

While investigating, I added these checks so that the index read fails fast when this issue occurs, which made it significantly easier to reproduce and diagnose the root cause (fixed by the other changes in this patch).

There are a number of interesting details in the story of how we worked from the original high-level data-corruption symptom to this low-level IO bug. I'll see about writing up the complete story in a blog post at some point.
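To make the "structurally valid but wrong data" failure mode from point 1 concrete: the index file is a flat array of longs in which reducer i's segment of the data file is delimited by the values at byte offsets i * 8 and (i + 1) * 8. A sketch of the lookup (layout as in the diff; the file name is illustrative):

    import java.io.DataInputStream
    import java.nio.channels.Channels
    import java.nio.file.{Files, Paths}

    val reduceId = 2
    val channel = Files.newByteChannel(Paths.get("shuffle.index"))
    channel.position(reduceId * 8L)  // offset(i) lives at byte i * 8
    val in = new DataInputStream(Channels.newInputStream(channel))
    val offset = in.readLong()       // start of reducer i's segment
    val nextOffset = in.readLong()   // end of reducer i's segment
    // If a stray read on a recycled descriptor had already advanced the
    // position by some multiple of 8, these would still parse as two valid
    // longs (just another reducer's boundaries), which is why the patch
    // verifies channel.position() == reduceId * 8 + 16 after the reads.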

+    val in = new DataInputStream(Channels.newInputStream(channel))
     try {
-      ByteStreams.skipFully(in, blockId.reduceId * 8)
       val offset = in.readLong()
       val nextOffset = in.readLong()
+      val actualPosition = channel.position()
+      val expectedPosition = blockId.reduceId * 8 + 16
+      if (actualPosition != expectedPosition) {

Maybe an assert, i.e. assert(actualPosition == expectedPosition, $msg), is better for things like this, so we may elide them using compiler flags if desired.

Contributor Author: I considered this, but I don't think there's ever a case where we want to elide this particular check: if we read an incorrect offset here then there's (potentially) no other mechanism to detect this error, leading to silent wrong answers.
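For reference, the elision being discussed is real: Scala's Predef.assert is annotated @elidable, so the scalac flag -Xdisable-assertions compiles such calls away entirely. A sketch of what the suggested form would look like:

    // Would be removed by -Xdisable-assertions, silently disabling the only
    // guard against reading the wrong reducer's offsets:
    assert(actualPosition == expectedPosition,
      s"SPARK-22982: expected position $expectedPosition but was $actualPosition")

Throwing a plain exception instead keeps the check unconditional in production builds, which is exactly the property the author wants to preserve.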

+        throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
Contributor: Maybe we'd better change this to a more specific exception type here.

Contributor Author: Any suggestions for a better exception subtype? I don't expect this to be a recoverable error and wanted to avoid the possibility that downstream code catches and handles this error. Maybe I should go further and make it a RuntimeException to make it even more fatal?

Contributor: I see. Thanks!

s"expected $expectedPosition but actual position was $actualPosition.")
}
     new FileSegmentManagedBuffer(
       transportConf,
       getDataFile(blockId.shuffleId, blockId.mapId),
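As a closing note on the exception-type discussion above: one way to satisfy the "more specific exception type" suggestion while keeping the failure non-recoverable would be a dedicated RuntimeException subclass. A hypothetical sketch (the class name is invented and not part of this patch):

    // Hypothetical: unchecked and self-describing, so generic error handling
    // is unlikely to swallow it and stack traces name the failure mode.
    class ShuffleIndexReadException(message: String) extends RuntimeException(message)

    throw new ShuffleIndexReadException(
      s"SPARK-22982: Incorrect channel position after index file reads: " +
      s"expected $expectedPosition but actual position was $actualPosition.")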