diff --git a/core/pom.xml b/core/pom.xml
index 97a463abbefdd..0ddcdfe2e98c5 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -189,6 +189,10 @@
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.roaringbitmap</groupId>
       <artifactId>RoaringBitmap</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 2e991ce394c42..cdcc613af769a 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
 
 import java.io._
 
+import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream}
 import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
 import net.jpountz.lz4.LZ4BlockOutputStream
 import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
@@ -49,13 +50,14 @@ private[spark] object CompressionCodec {
 
   private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = {
     (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec]
-      || codec.isInstanceOf[LZ4CompressionCodec])
+      || codec.isInstanceOf[LZ4CompressionCodec] || codec.isInstanceOf[ZStandardCompressionCodec])
   }
 
   private val shortCompressionCodecNames = Map(
     "lz4" -> classOf[LZ4CompressionCodec].getName,
     "lzf" -> classOf[LZFCompressionCodec].getName,
-    "snappy" -> classOf[SnappyCompressionCodec].getName)
+    "snappy" -> classOf[SnappyCompressionCodec].getName,
+    "zstd" -> classOf[ZStandardCompressionCodec].getName)
 
   def getCodecName(conf: SparkConf): String = {
     conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -215,3 +217,22 @@ private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends Ou
     }
   }
 }
+
+/**
+ * :: DeveloperApi ::
+ * ZStandard implementation of [[org.apache.spark.io.CompressionCodec]].
+ *
+ * @note The wire protocol for this codec is not guaranteed to be compatible across versions
+ * of Spark. This is intended for use as an internal compression utility within a single Spark
+ * application.
+ */
+@DeveloperApi
+class ZStandardCompressionCodec(conf: SparkConf) extends CompressionCodec {
+
+  override def compressedOutputStream(s: OutputStream): OutputStream = {
+    // The compression level is a plain integer, so read it with getInt;
+    // getSizeAsBytes would accept size suffixes like "3k", which make no sense here.
+    val level = conf.getInt("spark.io.compression.zstandard.level", 3)
+    new ZstdOutputStream(s, level)
+  }
+
+  override def compressedInputStream(s: InputStream): InputStream = new ZstdInputStream(s)
+}
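For reference, a minimal usage sketch (not part of the patch; the app name is made up): once this change is in, an application selects the codec through the existing `spark.io.compression.codec` key, whose `"zstd"` short name resolves via `shortCompressionCodecNames` above.

```scala
import org.apache.spark.SparkConf

// Select zstd for Spark's internal compression (shuffle, broadcast, RDD blocks).
val conf = new SparkConf()
  .setAppName("zstd-codec-demo") // hypothetical app name
  .set("spark.io.compression.codec", "zstd")
  // Level key introduced by this patch; 3 is its default.
  .set("spark.io.compression.zstandard.level", "3")
```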
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 9e9c2b0165e13..ffac1e1e695d7 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -104,6 +104,24 @@ class CompressionCodecSuite extends SparkFunSuite {
     testConcatenationOfSerializedStreams(codec)
   }
 
+  test("zstd compression codec") {
+    val codec = CompressionCodec.createCodec(conf, classOf[ZStandardCompressionCodec].getName)
+    assert(codec.getClass === classOf[ZStandardCompressionCodec])
+    testCodec(codec)
+  }
+
+  test("zstd compression codec short form") {
+    val codec = CompressionCodec.createCodec(conf, "zstd")
+    assert(codec.getClass === classOf[ZStandardCompressionCodec])
+    testCodec(codec)
+  }
+
+  test("zstd supports concatenation of serialized zstd") {
+    val codec = CompressionCodec.createCodec(conf, classOf[ZStandardCompressionCodec].getName)
+    assert(codec.getClass === classOf[ZStandardCompressionCodec])
+    testConcatenationOfSerializedStreams(codec)
+  }
+
test("bad compression codec") {
intercept[IllegalArgumentException] {
CompressionCodec.createCodec(conf, "foobar")
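As a rough illustration of what the suite's existing `testCodec` helper exercises for each codec (a sketch, not the suite's actual code), the round trip looks like this, assuming `codec` comes from `CompressionCodec.createCodec(conf, "zstd")`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

// Compress 1000 bytes, then decompress and verify the stream round-trips.
val raw = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(raw)
(0 until 1000).foreach(i => out.write(i % 256))
out.close()

val in = codec.compressedInputStream(new ByteArrayInputStream(raw.toByteArray))
(0 until 1000).foreach(i => assert(in.read() == i % 256))
in.close()
```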
diff --git a/pom.xml b/pom.xml
index c1174593c1922..b6d18770042d5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -519,6 +519,11 @@
         <artifactId>lz4</artifactId>
         <version>1.3.0</version>
       </dependency>
+      <dependency>
+        <groupId>com.github.luben</groupId>
+        <artifactId>zstd-jni</artifactId>
+        <version>1.1.1</version>
+      </dependency>
       <dependency>
         <groupId>com.clearspring.analytics</groupId>
         <artifactId>stream</artifactId>
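The zstd-jni streams that the codec wraps can also be exercised on their own; here is a self-contained round trip using the `ZstdOutputStream(out, level)` and `ZstdInputStream(in)` constructors the patch itself calls:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream}

val buf = new ByteArrayOutputStream()
val zout = new ZstdOutputStream(buf, 3) // level 3, matching the codec's default
zout.write("hello zstd".getBytes("UTF-8"))
zout.close()

val zin = new ZstdInputStream(new ByteArrayInputStream(buf.toByteArray))
val decoded = scala.io.Source.fromInputStream(zin, "UTF-8").mkString
assert(decoded == "hello zstd")
zin.close()
```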