Apply review comments
maropu committed Feb 18, 2017
1 parent 763601d commit 873a383
Showing 3 changed files with 53 additions and 49 deletions.
@@ -101,11 +101,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
-    val columnNameOfCorruptRecord = csvOptions.columnNameOfCorruptRecord.getOrElse(sparkSession
-      .sessionState.conf.columnNameOfCorruptRecord)
-    val shouldHandleCorruptRecord = csvOptions.permissive && requiredSchema.exists { f =>
-      f.name == columnNameOfCorruptRecord && f.dataType == StringType && f.nullable
-    }
+    val parsedOptions = new CSVOptions(
+      options,
+      sparkSession.sessionState.conf.sessionLocalTimeZone,
+      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
     (file: PartitionedFile) => {
       val lines = {
@@ -126,11 +125,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       }
 
       val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
-      val parser = if (shouldHandleCorruptRecord) {
-        new UnivocityParser(dataSchema, requiredSchema, csvOptions, Some(columnNameOfCorruptRecord))
-      } else {
-        new UnivocityParser(dataSchema, requiredSchema, csvOptions)
-      }
+      val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
       filteredLines.flatMap(parser.parse)
     }
   }
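Net effect of the two hunks above: CSVFileFormat.buildReader no longer resolves the corrupt-record column itself; it builds one CSVOptions with the session time zone and the session's default corrupt-record column, and hands that to UnivocityParser. For orientation, a minimal user-facing sketch of the behavior being wired up here (assumes a live SparkSession named `spark`; the path and column values are hypothetical):

    import org.apache.spark.sql.types._

    // The corrupt-record column must be a nullable StringType field in the read schema.
    val schema = new StructType()
      .add("a", IntegerType)
      .add("b", IntegerType)
      .add("_corrupt_record", StringType)

    val df = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .csv("/tmp/example.csv") // hypothetical path; malformed lines land in _corrupt_record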
@@ -26,12 +26,21 @@ import org.apache.commons.lang3.time.FastDateFormat
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
 
-private[sql] class CSVOptions(
-    @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String)
+private[csv] class CSVOptions(
+    @transient private val parameters: CaseInsensitiveMap[String],
+    defaultTimeZoneId: String,
+    defaultColumnNameOfCorruptRecord: String)
   extends Logging with Serializable {
 
-  def this(parameters: Map[String, String], defaultTimeZoneId: String) =
-    this(CaseInsensitiveMap(parameters), defaultTimeZoneId)
+  def this(
+      parameters: Map[String, String],
+      defaultTimeZoneId: String,
+      defaultColumnNameOfCorruptRecord: String = "") = {
+    this(
+      CaseInsensitiveMap(parameters),
+      defaultTimeZoneId,
+      defaultColumnNameOfCorruptRecord)
+  }
 
   private def getChar(paramName: String, default: Char): Char = {
     val paramValue = parameters.get(paramName)
@@ -95,7 +104,8 @@ private[sql] class CSVOptions(
   val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
   val permissive = ParseModes.isPermissiveMode(parseMode)
 
-  val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
+  val columnNameOfCorruptRecord =
+    parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
 
   val nullValue = parameters.getOrElse("nullValue", "")
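The constructor change above means callers that pass no default fall back to "", and columnNameOfCorruptRecord now always resolves to a String instead of an Option[String]. A self-contained sketch of the fallback semantics in plain Scala (lower-casing stands in for CaseInsensitiveMap; the names here are illustrative):

    val parameters = Map("mode" -> "PERMISSIVE").map { case (k, v) => (k.toLowerCase, v) }
    val defaultColumnNameOfCorruptRecord = "_corrupt_record" // would come from the session conf

    val columnNameOfCorruptRecord =
      parameters.getOrElse("columnnameofcorruptrecord", defaultColumnNameOfCorruptRecord)
    assert(columnNameOfCorruptRecord == "_corrupt_record") // option absent, so the default wins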
@@ -36,8 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String
 private[csv] class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
-    options: CSVOptions,
-    columnNameOfCorruptRecord: Option[String] = None) extends Logging {
+    options: CSVOptions) extends Logging {
   require(requiredSchema.toSet.subsetOf(schema.toSet),
     "requiredSchema should be the subset of schema.")
 
@@ -46,9 +45,15 @@ private[csv] class UnivocityParser(
   // A `ValueConverter` is responsible for converting the given value to a desired type.
   private type ValueConverter = String => Any
 
-  private val inputSchema = columnNameOfCorruptRecord.map { fn =>
-    StructType(schema.filter(_.name != fn))
-  }.getOrElse(schema)
+  private val shouldHandleCorruptRecord = options.permissive && requiredSchema.exists { f =>
+    f.name == options.columnNameOfCorruptRecord && f.dataType == StringType && f.nullable
+  }
+
+  private val inputSchema = if (shouldHandleCorruptRecord) {
+    StructType(schema.filter(_.name != options.columnNameOfCorruptRecord))
+  } else {
+    schema
+  }
 
   private val valueConverters =
     inputSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
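Because the corrupt-record column never occurs in the input data, it has to be dropped before the per-field converters are built; otherwise valueConverters would expect one token too many. A runnable sketch of that filtering using only public Spark SQL types:

    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("a", IntegerType)
      .add("_corrupt_record", StringType)

    // StructType is a Seq[StructField], so filter plus the companion apply rebuild it.
    val inputSchema = StructType(schema.filter(_.name != "_corrupt_record"))
    assert(inputSchema.fieldNames.sameElements(Array("a")))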
@@ -59,10 +64,8 @@ private[csv] class UnivocityParser(
 
   private val row = new GenericInternalRow(requiredSchema.length)
 
-  private val shouldHandleCorruptRecord = columnNameOfCorruptRecord.isDefined
-  private val corruptIndex = columnNameOfCorruptRecord.flatMap { fn =>
-    requiredSchema.getFieldIndex(fn)
-  }.getOrElse(-1)
+  private val corruptIndex =
+    requiredSchema.getFieldIndex(options.columnNameOfCorruptRecord).getOrElse(-1)
 
   private val indexArr: Array[(Int, Int)] = {
     val fields = if (options.dropMalformed) {
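getFieldIndex returns Option[Int], and collapsing the absent case to a -1 sentinel keeps the hot parsing path free of Option handling. A short behavior sketch with the same public types:

    import org.apache.spark.sql.types._

    val requiredSchema = new StructType()
      .add("a", IntegerType)
      .add("_corrupt_record", StringType)

    assert(requiredSchema.getFieldIndex("_corrupt_record").getOrElse(-1) == 1)
    assert(new StructType().add("a", IntegerType).getFieldIndex("_corrupt_record").isEmpty)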
@@ -185,32 +188,13 @@ private[csv] class UnivocityParser(
    */
   def parse(input: String): Option[InternalRow] = {
     convertWithParseMode(parser.parseLine(input)) { tokens =>
-      var foundMalformed: Boolean = false
       indexArr.foreach { case (pos, i) =>
-        try {
-          // It anyway needs to try to parse since it decides if this row is malformed
-          // or not after trying to cast in `DROPMALFORMED` mode even if the casted
-          // value is not stored in the row.
-          val value = valueConverters(pos).apply(tokens(pos))
-          if (i < requiredSchema.length) {
-            row(i) = value
-          }
-        } catch {
-          case _: NumberFormatException | _: IllegalArgumentException
-              if options.permissive && shouldHandleCorruptRecord =>
-            foundMalformed = true
-            if (i < requiredSchema.length) {
-              row.setNullAt(i)
-            }
-          case e: Throwable =>
-            throw e
-        }
+        // It anyway needs to try to parse since it decides if this row is malformed
+        // or not after trying to cast in `DROPMALFORMED` mode even if the casted
+        // value is not stored in the row.
+        val value = valueConverters(pos).apply(tokens(pos))
+        if (i < requiredSchema.length) {
+          row(i) = value
+        }
       }
-      if (shouldHandleCorruptRecord) {
-        if (foundMalformed) {
-          row(corruptIndex) = UTF8String.fromString(tokens.mkString(options.delimiter.toString))
-        } else {
-          row.setNullAt(corruptIndex)
-        }
-      }
       row
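The simplified loop keeps the indexArr projection: each pair maps a token position to a slot in the required row, and conversion still runs for every selected token so DROPMALFORMED can spot bad casts even when the value is discarded. A plain-Scala sketch of that (pos, i) projection (converters and values are illustrative):

    val tokens = Array("1", "true", "3")
    val valueConverters: Array[String => Any] =
      Array(_.toInt, _.toBoolean, _.toInt) // one converter per input-schema field
    val indexArr = Array((0, 0), (2, 1))   // (token position, required-row slot)
    val row = new Array[Any](2)

    indexArr.foreach { case (pos, i) =>
      val value = valueConverters(pos).apply(tokens(pos))
      if (i < row.length) {
        row(i) = value
      }
    }
    assert(row.sameElements(Array(1, 3)))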
@@ -234,6 +218,9 @@ private[csv] class UnivocityParser(
       throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
         s"${tokens.mkString(options.delimiter.toString)}")
     } else {
+      if (options.permissive && shouldHandleCorruptRecord) {
+        row.setNullAt(corruptIndex)
+      }
       val checkedTokens = if (options.permissive && inputSchema.length > tokens.length) {
         tokens ++ new Array[String](inputSchema.length - tokens.length)
       } else if (options.permissive && inputSchema.length < tokens.length) {
@@ -245,11 +232,23 @@ private[csv] class UnivocityParser(
       try {
         Some(convert(checkedTokens))
       } catch {
+        // We only catch exceptions about malformed values here and pass over other exceptions
+        // (e.g., SparkException about unsupported types).
+        case _: NumberFormatException | _: IllegalArgumentException
+            if options.permissive && shouldHandleCorruptRecord =>
+          row(corruptIndex) = UTF8String.fromString(tokens.mkString(options.delimiter.toString))
+          Some(row)
         case NonFatal(e) if options.dropMalformed =>
           if (numMalformedRecords < options.maxMalformedLogPerPartition) {
             logWarning("Parse exception. " +
               s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
           }
           if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
             logWarning(
               s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
                 "found on this partition. Malformed records from now on will not be logged.")
           }
           numMalformedRecords += 1
           None
       }
     }
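With the per-field try/catch gone, this relocated catch block does the triage in one place: conversion errors (NumberFormatException, IllegalArgumentException) keep the row in PERMISSIVE mode with the raw line stored in the corrupt column, while other non-fatal failures drop the row only under DROPMALFORMED. A stripped-down sketch of that triage, with a hypothetical helper and Strings standing in for InternalRow:

    import scala.util.control.NonFatal

    def triage(convert: () => String, permissive: Boolean, dropMalformed: Boolean): Option[String] =
      try {
        Some(convert())
      } catch {
        case _: NumberFormatException | _: IllegalArgumentException if permissive =>
          Some("<raw line in corrupt column>") // row kept; payload fields stay null
        case NonFatal(_) if dropMalformed =>
          None // row dropped (the real parser also counts and logs it)
        // anything else propagates, just like the original code
      }

    assert(triage(() => "42".toInt.toString, permissive = true, dropMalformed = false) == Some("42"))
    assert(triage(() => "x".toInt.toString, permissive = true, dropMalformed = false)
      == Some("<raw line in corrupt column>"))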
