[SPARK-25945][SQL] Support locale while parsing date/timestamp from CSV/JSON

## What changes were proposed in this pull request?

In this PR, I propose adding a new option `locale` to CSVOptions/JSONOptions so that dates/timestamps written in local languages can be parsed. Currently the locale is hard-coded to `Locale.US`.

## How was this patch tested?

Added two tests for parsing a localized date from CSV/JSON: `ноя 2018` (Russian for `Nov 2018`).
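For reference, a minimal end-to-end sketch of what the new option enables (not part of this commit; the `spark` session and the file path are assumed):

```scala
import org.apache.spark.sql.types.{DateType, StructType}

// "MMM yyyy" carries no day-of-month, so a parsed value defaults to the 1st.
// With the default en-US locale the Russian month name would not parse
// (yielding null under the default PERMISSIVE mode).
val schema = new StructType().add("d", DateType)
val df = spark.read
  .schema(schema)
  .option("dateFormat", "MMM yyyy")
  .option("locale", "ru-RU")
  .csv("/tmp/dates_ru.csv")  // hypothetical file containing the line: ноя 2018

df.show()  // d = 2018-11-01
```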

Closes apache#22951 from MaxGekk/locale.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
MaxGekk authored and jackylee-ch committed Feb 18, 2019
1 parent 9f51022 commit 587b063
Showing 10 changed files with 109 additions and 14 deletions.
15 changes: 11 additions & 4 deletions python/pyspark/sql/readwriter.py
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
- dropFieldIfAllNull=None, encoding=None):
+ dropFieldIfAllNull=None, encoding=None, locale=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param dropFieldIfAllNull: whether to ignore column of all null values or empty
array/struct during schema inference. If None is set, it
uses the default value, ``false``.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+     it uses the default value, ``en-US``. For instance, ``locale`` is used while
+     parsing dates and timestamps.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
- samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
+ samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
+ locale=locale)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- samplingRatio=None, enforceSchema=None, emptyValue=None):
+ samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None):
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -446,6 +450,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
If None is set, it uses the default value, ``1.0``.
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
the default value, empty string.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+     it uses the default value, ``en-US``. For instance, ``locale`` is used while
+     parsing dates and timestamps.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -465,7 +472,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
- enforceSchema=enforceSchema, emptyValue=emptyValue)
+ enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
14 changes: 10 additions & 4 deletions python/pyspark/sql/streaming.py
@@ -404,7 +404,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None, lineSep=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
@@ -469,6 +469,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
including tab and line feed characters) or not.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+     it uses the default value, ``en-US``. For instance, ``locale`` is used while
+     parsing dates and timestamps.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -483,7 +486,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
@@ -564,7 +567,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- enforceSchema=None, emptyValue=None):
+ enforceSchema=None, emptyValue=None, locale=None):
r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -660,6 +663,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
different, ``\0`` otherwise..
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
the default value, empty string.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+     it uses the default value, ``en-US``. For instance, ``locale`` is used while
+     parsing dates and timestamps.
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -677,7 +683,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
- emptyValue=emptyValue)
+ emptyValue=emptyValue, locale=locale)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -131,13 +131,16 @@ class CSVOptions(
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))

+ // A language tag in IETF BCP 47 format
+ val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)

val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)

val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)

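The change above matters because `FastDateFormat` resolves month and weekday names from the locale it is constructed with; previously that was always `Locale.US`. A standalone sketch (assumes commons-lang3, which Spark already depends on):

```scala
import java.util.{Date, Locale}
import org.apache.commons.lang3.time.FastDateFormat

val ru = FastDateFormat.getInstance("MMM yyyy", Locale.forLanguageTag("ru-RU"))
val us = FastDateFormat.getInstance("MMM yyyy", Locale.US)

// Round-trip through the Russian formatter; the exact month abbreviation
// (e.g. "ноя 2018") depends on the JDK's locale data.
val s = ru.format(new Date())
ru.parse(s)    // succeeds
// us.parse(s) // throws java.text.ParseException: "ноя" is not an English month
```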
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -76,16 +76,19 @@ private[sql] class JSONOptions(
// Whether to ignore column of all null values or empty array/struct during schema inference
val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)

+ // A language tag in IETF BCP 47 format
+ val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))

// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)

val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)

val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)

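The option-to-`Locale` mapping is identical in both classes; a small sketch of its semantics:

```scala
import java.util.Locale

// Mirrors the pattern used in CSVOptions/JSONOptions above: an explicit tag
// wins, otherwise the previous hard-coded default Locale.US is preserved.
def localeFor(parameters: Map[String, String]): Locale =
  parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)

localeFor(Map("locale" -> "ru-RU"))  // ru_RU
localeFor(Map.empty)                 // en_US (backward-compatible default)
// Locale.forLanguageTag never throws; an unrecognized tag yields a
// best-effort Locale rather than an error.
```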
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

- import java.util.Calendar
+ import java.text.SimpleDateFormat
+ import java.util.{Calendar, Locale}

import org.scalatest.exceptions.TestFailedException

@@ -209,4 +210,20 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
"2015-12-31T16:00:00"
)
}

test("parse date with locale") {
Seq("en-US", "ru-RU").foreach { langTag =>
val locale = Locale.forLanguageTag(langTag)
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
val schema = new StructType().add("d", DateType)
val dateFormat = "MMM yyyy"
val sdf = new SimpleDateFormat(dateFormat, locale)
val dateStr = sdf.format(date)
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)

checkEvaluation(
CsvToStructs(schema, options, Literal.create(dateStr), gmtId),
InternalRow(17836)) // number of days from 1970-01-01
}
}
}
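Why 17836: `MMM yyyy` has no day-of-month field, so parsing defaults the day to 1, and Spark's `DateType` is stored internally as days since the Unix epoch. A quick check:

```scala
import java.time.LocalDate

// 2018-11-01 expressed as days since 1970-01-01.
LocalDate.of(2018, 11, 1).toEpochDay  // 17836L — the value asserted in InternalRow above
```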
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

- import java.util.Calendar
+ import java.text.SimpleDateFormat
+ import java.util.{Calendar, Locale}

import org.scalatest.exceptions.TestFailedException

@@ -737,4 +738,20 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))),
"struct<col:bigint>")
}

test("parse date with locale") {
Seq("en-US", "ru-RU").foreach { langTag =>
val locale = Locale.forLanguageTag(langTag)
val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
val schema = new StructType().add("d", DateType)
val dateFormat = "MMM yyyy"
val sdf = new SimpleDateFormat(dateFormat, locale)
val dateStr = s"""{"d":"${sdf.format(date)}"}"""
val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)

checkEvaluation(
JsonToStructs(schema, options, Literal.create(dateStr), gmtId),
InternalRow(17836)) // number of days from 1970-01-01
}
}
}
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -384,6 +384,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* for schema inferring.</li>
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
* empty array/struct during schema inference.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
* </ul>
*
* @since 2.0.0
@@ -604,6 +606,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
* </ul>
*
* @since 2.0.0
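A sketch of the documented `locale` option on the reader side (input path and sample record are hypothetical):

```scala
// Read JSON whose timestamps use Russian month names,
// e.g. {"ts": "06 ноя 2018 18:00"} (exact spelling depends on JDK locale data).
val df = spark.read
  .schema("ts TIMESTAMP")
  .option("timestampFormat", "dd MMM yyyy HH:mm")
  .option("locale", "ru-RU")
  .json("/tmp/events_ru.json")
```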
sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -296,6 +296,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* that should be used for parsing.</li>
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
* empty array/struct during schema inference.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
* </ul>
*
* @since 2.0.0
@@ -372,6 +374,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
* <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
* For instance, this is used while parsing dates and timestamps.</li>
* </ul>
*
* @since 2.0.0
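And the streaming counterpart (sketch; schema and input directory are assumptions):

```scala
// The same option applies per micro-batch on DataStreamReader.
val stream = spark.readStream
  .schema("time TIMESTAMP")
  .option("timestampFormat", "dd MMM yyyy HH:mm")
  .option("locale", "ru-RU")
  .csv("/tmp/incoming/")
```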
sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql

+ import java.text.SimpleDateFormat
+ import java.util.Locale
+
import scala.collection.JavaConverters._

import org.apache.spark.SparkException
@@ -164,4 +167,18 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
val df1 = Seq(Tuple1(Tuple1(1))).toDF("a")
checkAnswer(df1.selectExpr("to_csv(a)"), Row("1") :: Nil)
}

test("parse timestamps with locale") {
Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
val locale = Locale.forLanguageTag(langTag)
val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
val timestampFormat = "dd MMM yyyy HH:mm"
val sdf = new SimpleDateFormat(timestampFormat, locale)
val input = Seq(s"""${sdf.format(ts)}""").toDS()
val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
val df = input.select(from_csv($"value", lit("time timestamp"), options.asJava))

checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
}
}
}
sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql

+ import java.text.SimpleDateFormat
+ import java.util.Locale
+
import collection.JavaConverters._

import org.apache.spark.SparkException
@@ -591,4 +594,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
df.select(from_json($"value", schema, Map("columnNameOfCorruptRecord" -> "_unparsed"))),
Row(Row(null, badRec, null)) :: Row(Row(2, null, 12)) :: Nil)
}

test("parse timestamps with locale") {
Seq("en-US", "ko-KR", "zh-CN", "ru-RU").foreach { langTag =>
val locale = Locale.forLanguageTag(langTag)
val ts = new SimpleDateFormat("dd/MM/yyyy HH:mm").parse("06/11/2018 18:00")
val timestampFormat = "dd MMM yyyy HH:mm"
val sdf = new SimpleDateFormat(timestampFormat, locale)
val input = Seq(s"""{"time": "${sdf.format(ts)}"}""").toDS()
val options = Map("timestampFormat" -> timestampFormat, "locale" -> langTag)
val df = input.select(from_json($"value", "time timestamp", options))

checkAnswer(df, Row(Row(java.sql.Timestamp.valueOf("2018-11-06 18:00:00.0"))))
}
}
}
