diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 7492b88c471a4..1a351933a366c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -177,6 +177,10 @@ public void pointTo(byte[] buf, int sizeInBytes) {
     pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
   }
 
+  public void setTotalSize(int sizeInBytes) {
+    this.sizeInBytes = sizeInBytes;
+  }
+
   public void setNotNullAt(int i) {
     assertIndexIsValid(i);
     BitSetMethods.unset(baseObject, baseOffset, i);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index a6758bddfa7d0..198bfb6d67aee 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -256,6 +256,15 @@ private boolean loadBatch() throws IOException {
       numBatched = num;
       batchIdx = 0;
     }
+
+    // Update the total row lengths if the schema contained variable length. We did not maintain
+    // this as we populated the columns.
+    if (containsVarLenFields) {
+      for (int i = 0; i < numBatched; ++i) {
+        rows[i].setTotalSize(rowWriters[i].holder().totalSize());
+      }
+    }
+
     return true;
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 0c5d4887ed799..b0581e8b35510 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -38,6 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -618,6 +619,29 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       readResourceParquetFile("dec-in-fixed-len.parquet"),
       sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec))
   }
+
+  test("SPARK-12589 copy() on rows returned from reader works for strings") {
+    withTempPath { dir =>
+      val data = (1, "abc") :: (2, "helloabcde") :: Nil
+      data.toDF().write.parquet(dir.getCanonicalPath)
+      var hash1: Int = 0
+      var hash2: Int = 0
+      (false :: true :: Nil).foreach { v =>
+        withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> v.toString) {
+          val df = sqlContext.read.parquet(dir.getCanonicalPath)
+          val rows = df.queryExecution.toRdd.map(_.copy()).collect()
+          val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow])
+          if (!v) {
+            hash1 = unsafeRows(0).hashCode()
+            hash2 = unsafeRows(1).hashCode()
+          } else {
+            assert(hash1 == unsafeRows(0).hashCode())
+            assert(hash2 == unsafeRows(1).hashCode())
+          }
+        }
+      }
+    }
+  }
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
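
Why the patch is needed, in brief: an UnsafeRow is a pointer (baseObject, baseOffset, sizeInBytes) into a shared buffer, and copy() duplicates exactly sizeInBytes bytes. The batched Parquet reader writes variable-length values such as strings into that buffer after the fixed-width size was recorded, so without the setTotalSize() call above, a copied row silently drops the variable-length tail, which is why the test's hash codes diverged between the two reader paths. The following standalone sketch (plain Java, no Spark dependencies; RowView and its methods are hypothetical stand-ins for UnsafeRow, not Spark's actual classes) reproduces the failure mode under those assumptions:

import java.util.Arrays;

// Sketch of the bug: a row view tracks (buffer, sizeInBytes), and copy()
// trusts sizeInBytes. If a writer appends variable-length data past the
// initially recorded size without refreshing it, the copy truncates the row.
public class StaleSizeDemo {

  // Minimal stand-in for UnsafeRow: a window of sizeInBytes bytes over buffer.
  static final class RowView {
    byte[] buffer;
    int sizeInBytes;

    void pointTo(byte[] buf, int size) {
      this.buffer = buf;
      this.sizeInBytes = size;
    }

    // Analogue of the setTotalSize() added by this patch.
    void setTotalSize(int size) {
      this.sizeInBytes = size;
    }

    // Like UnsafeRow.copy(): copies only the first sizeInBytes bytes.
    byte[] copy() {
      return Arrays.copyOf(buffer, sizeInBytes);
    }
  }

  public static void main(String[] args) {
    byte[] buffer = new byte[64];
    RowView row = new RowView();

    // The fixed-width part of the row is 16 bytes; the reader records
    // that size up front.
    row.pointTo(buffer, 16);

    // A variable-length field ("helloabcde") is then written past offset 16,
    // growing the row to 32 bytes -- but row.sizeInBytes still says 16.
    byte[] str = "helloabcde".getBytes();
    System.arraycopy(str, 0, buffer, 16, str.length);
    int totalSize = 16 + 16;  // fixed part + padded variable-length part

    byte[] truncated = row.copy();
    System.out.println("stale size copies " + truncated.length + " bytes (string lost)");

    // The fix: refresh the recorded size once the row is fully written,
    // so copy() sees the whole row.
    row.setTotalSize(totalSize);
    byte[] full = row.copy();
    System.out.println("fixed size copies " + full.length + " bytes");
  }
}

This mirrors what the new code in loadBatch() does: after a batch is fully populated, it refreshes each row's recorded size from the writer's buffer holder, so copy() and hashCode() operate on the complete row bytes.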