diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 17adfec32192f..b5dddb9f11b22 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.List; +import org.apache.commons.lang.NotImplementedException; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; @@ -41,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException { int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { - switch (columnReaders[i].descriptor.getType()) { - case INT32: - columnReaders[i].readIntBatch(num, columnarBatch.column(i)); - break; - default: - throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType()); - } + columnReaders[i].readBatch(num, columnarBatch.column(i)); } rowsReturned += num; columnarBatch.setNumRows(num); @@ -237,7 +233,8 @@ private void initializeInternal() throws IOException { // TODO: Be extremely cautious in what is supported. Expand this. if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && - originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE && + originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && @@ -464,6 +461,11 @@ private final class ColumnReader { */ private boolean useDictionary; + /** + * If useDictionary is true, the staging vector used to decode the ids. + */ + private ColumnVector dictionaryIds; + /** * Maximum definition level for this column. */ @@ -587,9 +589,8 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. - * TODO: implement the other encodings. */ - private void readIntBatch(int total, ColumnVector column) throws IOException { + private void readBatch(int total, ColumnVector column) throws IOException { int rowId = 0; while (total > 0) { // Compute the number of values we want to read in this page. @@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException { leftInPage = (int)(endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - defColumn.readIntegers( - num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0); - - // Remap the values if it is dictionary encoded. if (useDictionary) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(column.getInt(i))); + // Data is dictionary encoded. 
We will vector decode the ids and then resolve the values. + if (dictionaryIds == null) { + dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(total); + } + // Read and decode dictionary ids. + readIntBatch(rowId, num, dictionaryIds); + decodeDictionaryIds(rowId, num, column); + } else { + switch (descriptor.getType()) { + case INT32: + readIntBatch(rowId, num, column); + break; + case INT64: + readLongBatch(rowId, num, column); + break; + case BINARY: + readBinaryBatch(rowId, num, column); + break; + default: + throw new IOException("Unsupported type: " + descriptor.getType()); } } + valuesRead += num; rowId += num; total -= num; } } + /** + * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. + */ + private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + switch (descriptor.getType()) { + case INT32: + if (column.dataType() == DataTypes.IntegerType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + + case INT64: + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + break; + + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; + + default: + throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + } + + if (dictionaryIds.numNulls() > 0) { + // Copy the NULLs over. + // TODO: we can improve this by decoding the NULLs directly into column. This would + // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then + // just do the ID remapping as above. + for (int i = 0; i < num; ++i) { + if (dictionaryIds.getIsNull(rowId + i)) { + column.putNull(rowId + i); + } + } + } + } + + /** + * For all the read*Batch functions, reads `num` values from this columnReader into column. It + * is guaranteed that num is smaller than the number of values left in the current page. + */ + + private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.IntegerType) { + defColumn.readIntegers( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0); + } else if (column.dataType() == DataTypes.ByteType) { + defColumn.readBytes( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. 
+ // TODO: implement remaining type conversions + if (column.dataType() == DataTypes.LongType) { + defColumn.readLongs( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + // This is where we implement support for the valid type conversions. + // TODO: implement remaining type conversions + if (column.isArray()) { + defColumn.readBinarys( + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + } + + private void readPage() throws IOException { DataPage page = pageReader.readPage(); // TODO: Why is this a visitor? diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index dac0c52ebd2cf..cec2418e46030 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -18,10 +18,13 @@ import java.io.IOException; +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.unsafe.Platform; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. @@ -52,15 +55,53 @@ public void skip(int n) { } @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, ColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public int readInteger() { + public final void readLongs(int total, ColumnVector c, int rowId) { + c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); + offset += 8 * total; + } + + @Override + public final void readBytes(int total, ColumnVector c, int rowId) { + for (int i = 0; i < total; i++) { + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
+ c.putInt(rowId + i, buffer[offset]); + offset += 4; + } + } + + @Override + public final int readInteger() { int v = Platform.getInt(buffer, offset); offset += 4; return v; } + + @Override + public final long readLong() { + long v = Platform.getLong(buffer, offset); + offset += 8; + return v; + } + + @Override + public final byte readByte() { + return (byte)readInteger(); + } + + @Override + public final void readBinary(int total, ColumnVector v, int rowId) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + int start = offset; + offset += len; + v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 493ec9deed499..9bfd74db38766 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -17,12 +17,16 @@ package org.apache.spark.sql.execution.datasources.parquet; +import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.Preconditions; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; import org.apache.parquet.column.values.bitpacking.Packer; import org.apache.parquet.io.ParquetDecodingException; +import org.apache.parquet.io.api.Binary; + +import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -35,7 +39,8 @@ * - Definition/Repetition levels * - Dictionary ids. */ -public final class VectorizedRleValuesReader extends ValuesReader { +public final class VectorizedRleValuesReader extends ValuesReader + implements VectorizedValuesReader { // Current decoding mode. The encoded data contains groups of either run length encoded data // (RLE) or bit packed data. Each group contains a header that indicates which group it is and // the number of values in the group. @@ -121,6 +126,7 @@ public int readValueDictionaryId() { return readInteger(); } + @Override public int readInteger() { if (this.currentCount == 0) { this.readNextGroup(); } @@ -138,7 +144,9 @@ public int readInteger() { /** * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. + * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only + * necessary for readIntegers because we also use it to decode dictionaryIds and want to make + * sure it always has a value in range. * * This is a batched version of this logic: * if (this.readInt() == level) { @@ -180,6 +188,154 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } } + // TODO: can this code duplication be removed without a perf penalty? 
+ public void readBytes(int total, ColumnVector c, + int rowId, int level, VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readBytes(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putByte(rowId + i, data.readByte()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readLongs(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + data.readLongs(n, c, rowId); + c.putNotNulls(rowId, n); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putLong(rowId + i, data.readLong()); + c.putNotNull(rowId + i); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + public void readBinarys(int total, ColumnVector c, int rowId, int level, + VectorizedValuesReader data) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == level) { + c.putNotNulls(rowId, n); + data.readBinary(n, c, rowId); + } else { + c.putNulls(rowId, n); + } + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == level) { + c.putNotNull(rowId + i); + data.readBinary(1, c, rowId); + } else { + c.putNull(rowId + i); + } + } + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + + // The RLE reader implements the vectorized decoding interface when used to decode dictionary + // IDs. This is different than the above APIs that decodes definitions levels along with values. + // Since this is only used to decode dictionary IDs, only decoding integers is supported. + @Override + public void readIntegers(int total, ColumnVector c, int rowId) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + switch (mode) { + case RLE: + c.putInts(rowId, n, currentValue); + break; + case PACKED: + c.putInts(rowId, n, currentBuffer, currentBufferIdx); + currentBufferIdx += n; + break; + } + rowId += n; + left -= n; + currentCount -= n; + } + } + + @Override + public byte readByte() { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBytes(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readLongs(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void readBinary(int total, ColumnVector c, int rowId) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + @Override + public void skip(int n) { + throw new UnsupportedOperationException("only readInts is valid."); + } + + /** * Reads the next varint encoded int. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 49a9ed83d590a..b6ec7311c564a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -24,12 +24,17 @@ * TODO: merge this into parquet-mr. */ public interface VectorizedValuesReader { + byte readByte(); int readInteger(); + long readLong(); /* * Reads `total` values into `c` start at `c[rowId]` */ + void readBytes(int total, ColumnVector c, int rowId); void readIntegers(int total, ColumnVector c, int rowId); + void readLongs(int total, ColumnVector c, int rowId); + void readBinary(int total, ColumnVector c, int rowId); // TODO: add all the other parquet types. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a5bc506a65ac2..0514252a8e53d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -763,7 +763,12 @@ public final int appendStruct(boolean isNull) { /** * Returns the elements appended. */ - public int getElementsAppended() { return elementsAppended; } + public final int getElementsAppended() { return elementsAppended; } + + /** + * Returns true if this column is an array. + */ + public final boolean isArray() { return resultArray != null; } /** * Maximum number of rows that can be stored in this column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala new file mode 100644 index 0000000000000..cef6b79a094d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils +import org.apache.spark.sql.test.SharedSQLContext + +// TODO: this needs a lot more testing but it's currently not easy to test with the parquet +// writer abstractions. Revisit. 
+class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext { + import testImplicits._ + + val ROW = ((1).toByte, 2, 3L, "abc") + val NULL_ROW = ( + null.asInstanceOf[java.lang.Byte], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[String]) + + test("All Types Dictionary") { + (1 :: 1000 :: Nil).foreach { n => { + withTempPath { dir => + List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getByte(i) == 1) + assert(batch.column(1).getInt(i) == 2) + assert(batch.column(2).getLong(i) == 3) + assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + i += 1 + } + reader.close() + } + }} + } + + test("All Types Null") { + (1 :: 100 :: Nil).foreach { n => { + withTempPath { dir => + val data = List.fill(n)(NULL_ROW).toDF + data.repartition(1).write.parquet(dir.getCanonicalPath) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head + + val reader = new UnsafeRowParquetRecordReader + reader.initialize(file.asInstanceOf[String], null) + val batch = reader.resultBatch() + assert(reader.nextBatch()) + assert(batch.numRows() == n) + var i = 0 + while (i < n) { + assert(batch.column(0).getIsNull(i)) + assert(batch.column(1).getIsNull(i)) + assert(batch.column(2).getIsNull(i)) + assert(batch.column(3).getIsNull(i)) + i += 1 + } + reader.close() + }} + } + } +}
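
The core of the change in `VectorizedRleValuesReader` is the batched handling of definition levels: an RLE run of levels lets a whole range of rows be marked null (or read straight from the data page) with a single branch, while bit-packed groups fall back to a per-element loop. Below is a minimal, self-contained sketch of that pattern, assuming toy stand-ins (`SimpleIntVector`, `LevelRun`, `decodeWithLevels`) rather than the real Spark `ColumnVector` or parquet-mr reader classes.

```java
// Illustrative sketch only; class and method names here are hypothetical, not Spark/Parquet APIs.
import java.util.Arrays;
import java.util.List;

public class RleLevelDecodeSketch {

  /** A toy column vector: values plus a null mask, addressed by row id. */
  static final class SimpleIntVector {
    final int[] values;
    final boolean[] isNull;
    SimpleIntVector(int capacity) { values = new int[capacity]; isNull = new boolean[capacity]; }
    void putInt(int rowId, int v) { values[rowId] = v; isNull[rowId] = false; }
    void putNull(int rowId) { isNull[rowId] = true; }
    void putNulls(int rowId, int n) { Arrays.fill(isNull, rowId, rowId + n, true); }
  }

  /** One group of definition levels: either an RLE run or a bit-packed group. */
  static final class LevelRun {
    final boolean rle;    // true = RLE run, false = bit-packed group
    final int count;      // number of levels in this group
    final int rleValue;   // the repeated level, if RLE
    final int[] packed;   // the individual levels, if bit-packed
    LevelRun(int count, int rleValue) {
      this.rle = true; this.count = count; this.rleValue = rleValue; this.packed = null;
    }
    LevelRun(int[] packed) {
      this.rle = false; this.count = packed.length; this.rleValue = -1; this.packed = packed;
    }
  }

  /**
   * Fills out[rowId..rowId+total) from `data`, using the definition levels to decide which
   * slots are null. An RLE run handles a whole range with one branch; a bit-packed group
   * falls back to a per-element loop, mirroring the RLE/PACKED switch in the patch.
   */
  static void decodeWithLevels(List<LevelRun> levels, int[] data, int maxDefLevel,
                               SimpleIntVector out, int rowId, int total) {
    int dataIdx = 0;
    for (LevelRun run : levels) {
      int n = Math.min(run.count, total);
      if (run.rle) {
        if (run.rleValue == maxDefLevel) {
          for (int i = 0; i < n; i++) out.putInt(rowId + i, data[dataIdx++]);
        } else {
          out.putNulls(rowId, n);  // the whole run is null; no per-element branching
        }
      } else {
        for (int i = 0; i < n; i++) {
          if (run.packed[i] == maxDefLevel) out.putInt(rowId + i, data[dataIdx++]);
          else out.putNull(rowId + i);
        }
      }
      rowId += n;
      total -= n;
      if (total == 0) break;
    }
  }

  public static void main(String[] args) {
    // 8 rows: an RLE run of 3 non-null levels, a bit-packed group mixing null/non-null,
    // then an RLE run of 2 nulls. Non-null values are consumed from `data` in order.
    List<LevelRun> levels = List.of(
        new LevelRun(3, 1), new LevelRun(new int[] {1, 0, 1}), new LevelRun(2, 0));
    int[] data = {10, 20, 30, 40, 50};
    SimpleIntVector out = new SimpleIntVector(8);
    decodeWithLevels(levels, data, 1, out, 0, 8);
    System.out.println(Arrays.toString(out.values) + " nulls=" + Arrays.toString(out.isNull));
    // -> [10, 20, 30, 40, 0, 50, 0, 0] nulls=[false, false, false, false, true, false, true, true]
  }
}
```

Generalizing this loop per value width (bytes, longs, binary) rather than funneling everything through `readIntegers` is what lets the patch add INT_8/INT_16, INT64, and BINARY support without a per-row type dispatch.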
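
The dictionary path in `readBatch`/`decodeDictionaryIds` follows a decode-then-lookup split: dictionary ids (and the null mask) are first decoded into a reusable staging vector, and only then resolved against the dictionary into the typed output column, with nulls copied over afterwards. A rough sketch of that split, again using hypothetical stand-in types rather than the parquet-mr `Dictionary` or Spark `ColumnVector`:

```java
// Illustrative sketch only; LongDictionary and the array-based "vectors" are stand-ins.
import java.util.Arrays;

public class DictionaryDecodeSketch {

  /** Stand-in dictionary: id -> long value. */
  static final class LongDictionary {
    private final long[] entries;
    LongDictionary(long... entries) { this.entries = entries; }
    long decodeToLong(int id) { return entries[id]; }
  }

  /**
   * Step 1 (done by the caller): decode dictionary ids plus a null mask into a staging vector.
   * Step 2 (here): resolve each id against the dictionary into the typed output, then copy the
   * null mask over. Splitting the steps keeps the id-decoding loop identical for every type.
   */
  static void decodeDictionaryLongs(int[] dictionaryIds, boolean[] idIsNull,
                                    LongDictionary dict, long[] out, boolean[] outIsNull,
                                    int rowId, int num) {
    for (int i = rowId; i < rowId + num; i++) {
      out[i] = dict.decodeToLong(dictionaryIds[i]);  // ids of null slots are 0, still in range
    }
    for (int i = rowId; i < rowId + num; i++) {
      if (idIsNull[i]) outIsNull[i] = true;          // the null mask wins over the decoded value
    }
  }

  public static void main(String[] args) {
    LongDictionary dict = new LongDictionary(100L, 200L, 300L);
    int[] ids = {2, 0, 0, 1};                        // the 0 at index 2 is a placeholder for a null
    boolean[] idIsNull = {false, false, true, false};
    long[] out = new long[4];
    boolean[] outIsNull = new boolean[4];
    decodeDictionaryLongs(ids, idIsNull, dict, out, outIsNull, 0, 4);
    System.out.println(Arrays.toString(out) + " nulls=" + Arrays.toString(outIsNull));
    // -> [300, 100, 100, 200] nulls=[false, false, true, false]
  }
}
```

This is also why the patch notes that `nullValue` must stay in range when `readIntegers` is used to decode dictionary ids: null slots still get looked up before the null mask is applied, so their placeholder id must be a valid dictionary index.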