diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index ae4320d4583d6..e3d81a6be5383 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -131,9 +131,9 @@ class HadoopRDD[K, V]( minPartitions) } - protected val jobConfCacheKey = "rdd_%d_job_conf".format(id) + protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id) - protected val inputFormatCacheKey = "rdd_%d_input_format".format(id) + protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id) // used to build JobTracker ID private val createTime = new Date() @@ -210,22 +210,24 @@ class HadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] + private val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - val jobConf = getJobConf() + private val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.inputSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None @@ -235,14 +237,14 @@ class HadoopRDD[K, V]( // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null - val inputFormat = getInputFormat(jobConf) + private var reader: RecordReader[K, V] = null + private val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration( new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) @@ -250,8 +252,8 @@ class HadoopRDD[K, V]( // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() + private val key: K = reader.createKey() + private val value: V = reader.createValue() override def getNext(): (K, V) = { try { @@ -270,7 +272,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala new file mode 100644 index 0000000000000..9ba476d2ba26a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileBlockHolder { + /** + * A wrapper around some input file information. + * + * @param filePath path of the file read, or empty string if not available. + * @param startOffset starting offset, in bytes, or -1 if not available. + * @param length size of the block, in bytes, or -1 if not available. + */ + private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) { + def this() { + this(UTF8String.fromString(""), -1, -1) + } + } + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputBlock: ThreadLocal[FileBlock] = new ThreadLocal[FileBlock] { + override protected def initialValue(): FileBlock = new FileBlock + } + + /** + * Returns the holding file name or empty string if it is unknown. + */ + def getInputFilePath: UTF8String = inputBlock.get().filePath + + /** + * Returns the starting offset of the block currently being read, or -1 if it is unknown. + */ + def getStartOffset: Long = inputBlock.get().startOffset + + /** + * Returns the length of the block being read, or -1 if it is unknown. + */ + def getLength: Long = inputBlock.get().length + + /** + * Sets the thread-local input block. + */ + def set(filePath: String, startOffset: Long, length: Long): Unit = { + require(filePath != null, "filePath cannot be null") + require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") + require(length >= 0, s"length ($length) cannot be negative") + inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) + } + + /** + * Clears the input file block to default value. + */ + def unset(): Unit = inputBlock.remove() +} diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala deleted file mode 100644 index 960c91a154db1..0000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.unsafe.types.UTF8String - -/** - * This holds file names of the current Spark task. This is used in HadoopRDD, - * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. - * - * The returned value should never be null but empty string if it is unknown. - */ -private[spark] object InputFileNameHolder { - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - /** - * Returns the holding file name or empty string if it is unknown. - */ - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = { - require(file != null, "The input file name cannot be null") - inputFileName.set(UTF8String.fromString(file)) - } - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() - -} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index c783e1375283a..e90e84c45904c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -132,54 +132,57 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] + private val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf + private val conf = getConf - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.serializableHadoopSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + private val getBytesReadCallback: Option[() => Long] = + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - val format = inputFormatClass.newInstance + private val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + private var havePair = false + private var finished = false + private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { @@ -215,7 +218,7 @@ class NewHadoopRDD[K, V]( private def close() { if (reader != null) { - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e41f1cad93d4c..5d065d736ecd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -371,6 +371,8 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), + expression[InputFileBlockStart]("input_file_block_start"), + expression[InputFileBlockLength]("input_file_block_length"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), expression[CurrentDatabase]("current_database"), expression[CallMethodViaReflection]("reflect"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala deleted file mode 100644 index d412336699d80..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.rdd.InputFileNameHolder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.types.{DataType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -/** - * Expression that returns the name of the current file being read. - */ -@ExpressionDescription( - usage = "_FUNC_() - Returns the name of the current file being read if available.") -case class InputFileName() extends LeafExpression with Nondeterministic { - - override def nullable: Boolean = false - - override def dataType: DataType = StringType - - override def prettyName: String = "input_file_name" - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override protected def evalInternal(input: InternalRow): UTF8String = { - InputFileNameHolder.getInputFileName() - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false") - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala new file mode 100644 index 0000000000000..7a8edabed1757 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.InputFileBlockHolder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.") +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = false + + override def dataType: DataType = StringType + + override def prettyName: String = "input_file_name" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + InputFileBlockHolder.getInputFilePath + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getInputFilePath();", isNull = "false") + } +} + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.") +case class InputFileBlockStart() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def prettyName: String = "input_file_block_start" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Long = { + InputFileBlockHolder.getStartOffset + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getStartOffset();", isNull = "false") + } +} + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.") +case class InputFileBlockLength() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def prettyName: String = "input_file_block_length" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Long = { + InputFileBlockHolder.getLength + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getLength();", isNull = "false") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ccfc759c8fa7e..5f605b965b231 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -84,7 +84,7 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) - lazy val sourceInfo = sourceSchema() + lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 89944570df662..306dc6527e5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.{Partition => RDDPartition, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{InputFileNameHolder, RDD} +import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch @@ -121,7 +121,8 @@ class FileScanRDD( if (files.hasNext) { currentFile = files.next() logInfo(s"Reading File $currentFile") - InputFileNameHolder.setInputFileName(currentFile.filePath) + // Sets InputFileBlockHolder for the file block's information + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) try { if (ignoreCorruptFiles) { @@ -162,7 +163,7 @@ class FileScanRDD( hasNext } else { currentFile = null - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() false } } @@ -170,7 +171,7 @@ class FileScanRDD( override def close(): Unit = { updateBytesRead() updateBytesReadWithFileSize() - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 26e1a9f75da13..b0339a88fbf62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -533,31 +533,54 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("input_file_name - FileScanRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) - .head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = spark.read.parquet(dir.getCanonicalPath).select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. + checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } } - test("input_file_name - HadoopRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - HadoopRDD") { withTempPath { dir => val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() data.write.text(dir.getCanonicalPath) val df = spark.sparkContext.textFile(dir.getCanonicalPath).toDF() - val answer = df.select(input_file_name()).head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = df.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. + checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } } - test("input_file_name - NewHadoopRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - NewHadoopRDD") { withTempPath { dir => val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() data.write.text(dir.getCanonicalPath) @@ -567,10 +590,22 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { classOf[LongWritable], classOf[Text]) val df = rdd.map(pair => pair._2.toString).toDF() - val answer = df.select(input_file_name()).head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = df.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. + checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } }