diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 690b13072244b..726de4a965418 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
              allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
              mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
              multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
-             dropFieldIfAllNull=None, encoding=None):
+             dropFieldIfAllNull=None, encoding=None, locale=None):
         """
         Loads JSON files and returns the results as a :class:`DataFrame`.

@@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         :param dropFieldIfAllNull: whether to ignore column of all null values or empty
                                    array/struct during schema inference. If None is set, it
                                    uses the default value, ``false``.
+        :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+                       it uses the default value, ``en-US``. For instance, ``locale`` is used while
+                       parsing dates and timestamps.

         >>> df1 = spark.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
@@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
             mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
             timestampFormat=timestampFormat, multiLine=multiLine,
             allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
-            samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
+            samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
+            locale=locale)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
@@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
              negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
              maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
              columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-             samplingRatio=None, enforceSchema=None, emptyValue=None):
+             samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
         r"""Loads a CSV file and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                               If None is set, it uses the default value, ``1.0``.
         :param emptyValue: sets the string representation of an empty value. If None is set, it uses
                            the default value, empty string.
+        :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+                       it uses the default value, ``en-US``. For instance, ``locale`` is used while
+                       parsing dates and timestamps.

         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
@@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
-            enforceSchema=enforceSchema, emptyValue=emptyValue)
+            enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index b18453b2a4f96..02b14ea187cba 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
              allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
              allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
              mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
-             multiLine=None, allowUnquotedControlChars=None, lineSep=None):
+             multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None):
         """
         Loads a JSON file stream and returns the results as a :class:`DataFrame`.

@@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
                                           including tab and line feed characters) or not.
         :param lineSep: defines the line separator that should be used for parsing. If None is
                         set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+        :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+                       it uses the default value, ``en-US``. For instance, ``locale`` is used while
+                       parsing dates and timestamps.

         >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
         >>> json_sdf.isStreaming
@@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
             allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
             timestampFormat=timestampFormat, multiLine=multiLine,
-            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
+            allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale)
         if isinstance(path, basestring):
             return self._df(self._jreader.json(path))
         else:
@@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
              negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
              maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
              columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-             enforceSchema=None, emptyValue=None):
+             enforceSchema=None, emptyValue=None, locale=None):
         r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.

         This function will go through the input once to determine the input schema if
@@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
                                           different, ``\0`` otherwise..
         :param emptyValue: sets the string representation of an empty value. If None is set, it uses
                            the default value, empty string.
+        :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+                       it uses the default value, ``en-US``. For instance, ``locale`` is used while
+                       parsing dates and timestamps.

         >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
         >>> csv_sdf.isStreaming
@@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
-            emptyValue=emptyValue)
+            emptyValue=emptyValue, locale=locale)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index cdaaa172e8367..642823582a645 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -131,13 +131,16 @@ class CSVOptions(
   val timeZone: TimeZone = DateTimeUtils.getTimeZone(
     parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))

+  // A language tag in IETF BCP 47 format
+  val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
   // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
   val dateFormat: FastDateFormat =
-    FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
+    FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)

   val timestampFormat: FastDateFormat =
     FastDateFormat.getInstance(
-      parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
+      parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)

   val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
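Both options classes resolve the new key the same way: `Locale.forLanguageTag` turns the BCP 47 tag into a `java.util.Locale`, which `FastDateFormat` then uses for month and day names. A minimal standalone sketch of that mechanic (not part of the patch; it assumes commons-lang3 on the classpath, and the printed Cyrillic form depends on the JDK's locale data):

```scala
import java.util.{Calendar, GregorianCalendar, Locale}

import org.apache.commons.lang3.time.FastDateFormat

object LocaleFormatSketch {
  def main(args: Array[String]): Unit = {
    // 2018-11-05, the same date the new expression tests use.
    val date = new GregorianCalendar(2018, Calendar.NOVEMBER, 5).getTime
    Seq("en-US", "ru-RU").foreach { langTag =>
      // Resolve the IETF BCP 47 tag; malformed subtags degrade gracefully rather than throwing.
      val locale = Locale.forLanguageTag(langTag)
      val fmt = FastDateFormat.getInstance("MMM yyyy", locale)
      // en-US prints "Nov 2018"; ru-RU prints a Cyrillic abbreviation such as "нояб. 2018"
      // (the exact short month name depends on the JDK's locale data).
      println(s"$langTag -> ${fmt.format(date)}")
    }
  }
}
```

`JSONOptions` below mirrors the same change for the JSON path.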
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 64152e04928d2..e10b8a327c01a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -76,16 +76,19 @@ private[sql] class JSONOptions(
   // Whether to ignore column of all null values or empty array/struct during schema inference
   val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)

+  // A language tag in IETF BCP 47 format
+  val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
   val timeZone: TimeZone = DateTimeUtils.getTimeZone(
     parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))

   // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
   val dateFormat: FastDateFormat =
-    FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
+    FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)

   val timestampFormat: FastDateFormat =
     FastDateFormat.getInstance(
-      parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
+      parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)

   val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
index d006197bd5678..f5aaaec456153 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -17,7 +17,8 @@

 package org.apache.spark.sql.catalyst.expressions

-import java.util.Calendar
+import java.text.SimpleDateFormat
+import java.util.{Calendar, Locale}

 import org.scalatest.exceptions.TestFailedException

@@ -209,4 +210,20 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
       "2015-12-31T16:00:00"
     )
   }
+
+  test("parse date with locale") {
+    Seq("en-US", "ru-RU").foreach { langTag =>
+      val locale = Locale.forLanguageTag(langTag)
+      val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
+      val schema = new StructType().add("d", DateType)
+      val dateFormat = "MMM yyyy"
+      val sdf = new SimpleDateFormat(dateFormat, locale)
+      val dateStr = sdf.format(date)
+      val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
+
+      checkEvaluation(
+        CsvToStructs(schema, options, Literal.create(dateStr), gmtId),
+        InternalRow(17836)) // number of days from 1970-01-01
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 304642161146b..6ee8c74010d3d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -17,7 +17,8 @@

 package org.apache.spark.sql.catalyst.expressions

-import java.util.Calendar
+import java.text.SimpleDateFormat
+import java.util.{Calendar, Locale}

 import org.scalatest.exceptions.TestFailedException

@@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))),
       "struct<col:bigint>")
   }
+
+  test("parse date with locale") {
+    Seq("en-US", "ru-RU").foreach { langTag =>
+      val locale = Locale.forLanguageTag(langTag)
+      val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
+      val schema = new StructType().add("d", DateType)
+      val dateFormat = "MMM yyyy"
+      val sdf = new SimpleDateFormat(dateFormat, locale)
+      val dateStr = s"""{"d":"${sdf.format(date)}"}"""
+      val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
+
+      checkEvaluation(
+        JsonToStructs(schema, options, Literal.create(dateStr), gmtId),
+        InternalRow(17836)) // number of days from 1970-01-01
+    }
+  }
 }
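With the options classes wired up, the key flows through the generic `option` call on the user-facing readers, whose scaladoc is updated below. A usage sketch (not from the patch; it assumes a live `spark` session, and the path and file contents are hypothetical):

```scala
// Suppose /tmp/dates_ru.csv holds one Russian month-year string per line,
// such as "нояб. 2018" (hypothetical sample data).
val df = spark.read
  .schema("d DATE")                  // DDL-string schema
  .option("dateFormat", "MMM yyyy")
  .option("locale", "ru-RU")
  .csv("/tmp/dates_ru.csv")
// Under the previous hard-wired Locale.US the same input would yield null
// (or fail, depending on the parse mode), because "нояб." is not a US month name.
```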
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 95c97e5c9433c..02ffc940184db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -384,6 +384,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * for schema inferring.</li>
    * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
    * empty array/struct during schema inference.</li>
+   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+   * For instance, this is used while parsing dates and timestamps.</li>
    * </ul>
    *
    * @since 2.0.0
@@ -604,6 +606,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
    * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
    * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+   * For instance, this is used while parsing dates and timestamps.</li>
    * </ul>
    *
    * @since 2.0.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4c7dcedafeeae..20c84305776ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -296,6 +296,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * that should be used for parsing.</li>
    * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
    * empty array/struct during schema inference.</li>
+   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+   * For instance, this is used while parsing dates and timestamps.</li>
    * </ul>
    *
    * @since 2.0.0
@@ -372,6 +374,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
    * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
    * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+   * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+   * For instance, this is used while parsing dates and timestamps.</li>
    * </ul>
    *
    * @since 2.0.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index b97ac380def63..1c359ce1d2014 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.sql

+import java.text.SimpleDateFormat
+import java.util.Locale
+
 import scala.collection.JavaConverters._

 import org.apache.spark.SparkException
@@ -164,4 +167,18 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
     val df1 = Seq(Tuple1(Tuple1(1))).toDF("a")
     checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil)
   }
+
+  test("parse timestamps with locale") {
+    Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
+      val locale = Locale.forLanguageTag(langTag)
+      val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
+      val timestampFormat = "dd MMM yyyy HH:mm"
+      val sdf = new SimpleDateFormat(timestampFormat, locale)
+      val input = Seq(s"""${sdf.format(ts)}""").toDS()
+      val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
+      val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava))
+
+      checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index d6b73387e84b3..24e7564259c83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.sql

+import java.text.SimpleDateFormat
+import java.util.Locale
+
 import collection.JavaConverters._

 import org.apache.spark.SparkException
@@ -591,4 +594,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))),
       Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil)
   }
+
+  test("parse timestamps with locale") {
+    Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
+      val locale = Locale.forLanguageTag(langTag)
+      val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
+      val timestampFormat = "dd MMM yyyy HH:mm"
+      val sdf = new SimpleDateFormat(timestampFormat, locale)
+      val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS()
+      val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
+      val df = input.select(from_json($"value", "time timestamp", options))
+
+      checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
+    }
+  }
 }
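The suites above drive the option through `from_csv`/`from_json`; an interactive sketch along the same lines (not part of the patch; it assumes a SparkSession with `spark.implicits._` imported, and the Cyrillic literal depends on the JDK's locale data):

```scala
import org.apache.spark.sql.functions.from_json
import spark.implicits._

// "06 нояб. 2018 18:00" approximates what SimpleDateFormat("dd MMM yyyy HH:mm")
// renders for 2018-11-06 18:00 under ru-RU.
val input = Seq("""{"time": "06 нояб. 2018 18:00"}""").toDS()
val options = Map("timestampFormat" -> "dd MMM yyyy HH:mm", "locale" -> "ru-RU")
// The DDL-string from_json overload is the same one the new test exercises.
input.select(from_json($"value", "time timestamp", options)).show(false)
```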