From c86febe6b018faafa62e0bf6444f8cd4326fb021 Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro
Date: Thu, 23 Feb 2017 13:50:15 +0900
Subject: [PATCH] Apply review comments

---
 python/pyspark/sql/readwriter.py                   |  2 +-
 python/pyspark/sql/streaming.py                    | 16 +++++++++++++---
 .../org/apache/spark/sql/DataFrameReader.scala     |  2 +-
 .../datasources/csv/UnivocityParser.scala          |  2 +-
 .../spark/sql/streaming/DataStreamReader.scala     |  2 +-
 .../sql/execution/datasources/csv/CSVSuite.scala   |  1 -
 6 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index d12ceb7900d67..251df07577f3f 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -368,7 +368,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                 If None is set, it uses the default value, session local timezone.
 
                 * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
-                    If users set a string-type field named ``columnNameOfCorruptRecord`` in a
+                    If users set a string type field named ``columnNameOfCorruptRecord`` in a
                     user-specified ``schema``, it puts the malformed string into the field. When
                     a ``schema`` is set by user, it sets ``null`` for extra fields.
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 965c8f6b269e9..60ca2b9e8fd7e 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -558,7 +558,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
-            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None):
+            maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None,
+            columnNameOfCorruptRecord=None):
         """Loads a CSV file stream and returns the result as a :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -619,10 +620,18 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                 If None is set, it uses the default value, session local timezone.
 
                 * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
-                    When a schema is set by user, it sets ``null`` for extra fields.
+                    If users set a string type field named ``columnNameOfCorruptRecord`` in a
+                    user-specified ``schema``, it puts the malformed string into the field. When
+                    a ``schema`` is set by user, it sets ``null`` for extra fields.
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
                 * ``FAILFAST`` : throws an exception when it meets corrupted records.
+        :param columnNameOfCorruptRecord: defines a field name for malformed strings created
+                                          by ``PERMISSIVE`` mode. If a user-specified `schema`
+                                          has this named field, Spark puts malformed strings
+                                          in this field. This overrides
+                                          `spark.sql.columnNameOfCorruptRecord`.
+
         >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
         >>> csv_sdf.isStreaming
         True
@@ -636,7 +645,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
             dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns,
             maxCharsPerColumn=maxCharsPerColumn,
-            maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone)
+            maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone,
+            columnNameOfCorruptRecord=columnNameOfCorruptRecord)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 5208c72363760..07206532879bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -423,7 +423,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * during parsing.
    *
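
For reference, a minimal usage sketch of the option this patch documents, assuming a Spark build that includes this change. The input path, the schema, and the `_corrupt` field name below are placeholders, not part of the patch; the batch reader is shown because its output is easy to inspect, but `spark.readStream.csv` accepts the same keyword.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.appName("corrupt-record-sketch").getOrCreate()

# Declare a string type field whose name matches columnNameOfCorruptRecord;
# in PERMISSIVE mode each malformed input line is stored there instead of
# being dropped (DROPMALFORMED) or raising an exception (FAILFAST).
schema = StructType([
    StructField("id", IntegerType()),
    StructField("name", StringType()),
    StructField("_corrupt", StringType()),
])

df = spark.read.csv(
    "/tmp/input.csv",                      # hypothetical input path
    schema=schema,
    mode="PERMISSIVE",
    columnNameOfCorruptRecord="_corrupt",  # overrides spark.sql.columnNameOfCorruptRecord
)

# Well-formed rows parse normally; malformed rows carry the raw line in
# `_corrupt` and null in the other fields.
df.filter(df["_corrupt"].isNotNull()).show(truncate=False)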