diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/CodecFactory.java b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/CodecFactory.java index 1998ea09dc..fdad50f0ce 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/CodecFactory.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/CodecFactory.java @@ -173,10 +173,10 @@ public BytesInput compress(BytesInput bytes) throws IOException { // null compressor for non-native gzip compressor.reset(); } - CompressionOutputStream cos = codec.createOutputStream(compressedOutBuffer, compressor); - bytes.writeAllTo(cos); - cos.finish(); - cos.close(); + try (CompressionOutputStream cos = codec.createOutputStream(compressedOutBuffer, compressor)) { + bytes.writeAllTo(cos); + cos.finish(); + } compressedBytes = BytesInput.from(compressedOutBuffer); } return compressedBytes; @@ -234,11 +234,11 @@ protected CompressionCodec getCodec(CompressionCodecName codecName) { if (codecClassName == null) { return null; } - CompressionCodec codec = CODEC_BY_NAME.get(codecClassName); + String codecCacheKey = this.cacheKey(codecName); + CompressionCodec codec = CODEC_BY_NAME.get(codecCacheKey); if (codec != null) { return codec; } - try { Class codecClass; try { @@ -248,13 +248,36 @@ protected CompressionCodec getCodec(CompressionCodecName codecName) { codecClass = configuration.getClassLoader().loadClass(codecClassName); } codec = (CompressionCodec) ReflectionUtils.newInstance(codecClass, configuration); - CODEC_BY_NAME.put(codecClassName, codec); + CODEC_BY_NAME.put(codecCacheKey, codec); return codec; } catch (ClassNotFoundException e) { throw new BadConfigurationException("Class " + codecClassName + " was not found", e); } } + private String cacheKey(CompressionCodecName codecName) { + String level = null; + switch (codecName) { + case GZIP: + level = configuration.get("zlib.compress.level"); + break; + case BROTLI: + level = configuration.get("compression.brotli.quality"); + break; + case ZSTD: + level = configuration.get("parquet.compression.codec.zstd.level"); + if (level == null) { + // keep "io.compression.codec.zstd.level" for backwards compatibility + level = configuration.get("io.compression.codec.zstd.level"); + } + break; + default: + // compression level is not supported; ignore it + } + String codecClass = codecName.getHadoopCompressionCodecClassName(); + return level == null ? codecClass : codecClass + ":" + level; + } + @Override public void release() { for (BytesCompressor compressor : compressors.values()) { diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java index 8fec515a4f..fa34d302cc 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestDirectCodecFactory.java @@ -174,5 +174,55 @@ public void compressionCodecs() { } } } + + static class PublicCodecFactory extends CodecFactory { + // To make getCodec public + + public PublicCodecFactory(Configuration configuration, int pageSize) { + super(configuration, pageSize); + } + + public org.apache.hadoop.io.compress.CompressionCodec getCodec(CompressionCodecName name) { + return super.getCodec(name); + } + } + + @Test + public void cachingKeysGzip() { + Configuration config_zlib_2 = new Configuration(); + config_zlib_2.set("zlib.compress.level", "2"); + + Configuration config_zlib_5 = new Configuration(); + config_zlib_5.set("zlib.compress.level", "5"); + + final CodecFactory codecFactory_2 = new PublicCodecFactory(config_zlib_2, pageSize); + final CodecFactory codecFactory_5 = new PublicCodecFactory(config_zlib_5, pageSize); + + CompressionCodec codec_2_1 = codecFactory_2.getCodec(CompressionCodecName.GZIP); + CompressionCodec codec_2_2 = codecFactory_2.getCodec(CompressionCodecName.GZIP); + CompressionCodec codec_5_1 = codecFactory_5.getCodec(CompressionCodecName.GZIP); + + Assert.assertEquals(codec_2_1, codec_2_2); + Assert.assertNotEquals(codec_2_1, codec_5_1); + } + + @Test + public void cachingKeysZstd() { + Configuration config_zstd_2 = new Configuration(); + config_zstd_2.set("io.compression.codec.zstd.level", "2"); + + Configuration config_zstd_5 = new Configuration(); + config_zstd_5.set("io.compression.codec.zstd.level", "5"); + + final CodecFactory codecFactory_2 = new PublicCodecFactory(config_zstd_2, pageSize); + final CodecFactory codecFactory_5 = new PublicCodecFactory(config_zstd_5, pageSize); + + CompressionCodec codec_2_1 = codecFactory_2.getCodec(CompressionCodecName.ZSTD); + CompressionCodec codec_2_2 = codecFactory_2.getCodec(CompressionCodecName.ZSTD); + CompressionCodec codec_5_1 = codecFactory_5.getCodec(CompressionCodecName.ZSTD); + + Assert.assertEquals(codec_2_1, codec_2_2); + Assert.assertNotEquals(codec_2_1, codec_5_1); + } }