[SPARK-22395][SQL][PYTHON] Fix the behavior of timestamp values for Pandas to respect session timezone #19607

Closed. Wants to merge 29 commits.
Changes shown are from 9 commits.

Commits (29):
4735e59
Add a conf to make Pandas DataFrame respect session local timezone.
ueshin Oct 23, 2017
1f85150
Fix toPandas() behavior.
ueshin Oct 23, 2017
5c08ecf
Modify pandas UDFs to respect session timezone.
ueshin Oct 23, 2017
ee1a1c8
Workaround for old pandas.
ueshin Nov 1, 2017
b1436b8
Don't use is_datetime64tz_dtype for old pandas.
ueshin Nov 1, 2017
6872516
Fix one of the failed tests.
ueshin Nov 1, 2017
1f096bf
Modify check_data udf for debug messages.
ueshin Nov 2, 2017
569bb63
Remove unused method.
ueshin Nov 3, 2017
ce07f39
Modify a test.
ueshin Nov 3, 2017
ba3d6e3
Add debug print, which will be removed later.
ueshin Nov 6, 2017
9101a3a
Fix style.
ueshin Nov 6, 2017
ab13baf
Remove debug prints.
ueshin Nov 8, 2017
4adb073
Modify tests to avoid times within DST.
ueshin Nov 8, 2017
1e0f217
Clean up.
ueshin Nov 8, 2017
d18cd36
Merge branch 'master' into issues/SPARK-22395
ueshin Nov 8, 2017
292678f
Fix the behavior of createDataFrame from pandas DataFrame.
ueshin Nov 8, 2017
f37c067
Merge branch 'master' into issues/SPARK-22395
ueshin Nov 13, 2017
8b1a4d8
Add a test to check the behavior of createDataFrame from pandas DataF…
ueshin Nov 13, 2017
e919ed5
Clarify the usage of Row.
ueshin Nov 13, 2017
9c94f90
Merge branch 'master' into issues/SPARK-22395
ueshin Nov 20, 2017
9cfdde2
Add TODOs for nested timestamp fields.
ueshin Nov 21, 2017
8b1a4a1
Remove workarounds for old Pandas but add some error messages saying …
ueshin Nov 21, 2017
3db2bea
Fix tests.
ueshin Nov 21, 2017
3e23653
Use `_exception_message()` to access error messages.
ueshin Nov 21, 2017
d741171
Fix a test.
ueshin Nov 21, 2017
e240631
Add a description about deprecation of the config.
ueshin Nov 27, 2017
f92eae3
Add migration guide.
ueshin Nov 27, 2017
40a9735
Merge branch 'master' into issues/SPARK-22395
ueshin Nov 27, 2017
9200f38
Address comments.
ueshin Nov 28, 2017
15 changes: 10 additions & 5 deletions python/pyspark/serializers.py
@@ -213,7 +213,7 @@ def __repr__(self):
return "ArrowSerializer"


def _create_batch(series):
def _create_batch(series, timezone):
from pyspark.sql.types import _check_series_convert_timestamps_internal
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -227,7 +227,7 @@ def _create_batch(series):
def cast_series(s, t):
if type(t) == pa.TimestampType:
# NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
return _check_series_convert_timestamps_internal(s.fillna(0))\
return _check_series_convert_timestamps_internal(s.fillna(0), timezone)\
.values.astype('datetime64[us]', copy=False)
elif t == pa.date32():
# TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
@@ -252,6 +252,10 @@ class ArrowStreamPandasSerializer(Serializer):
Serializes Pandas.Series as Arrow data with Arrow streaming format.
"""

def __init__(self, timezone):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -261,7 +265,7 @@ def dump_stream(self, iterator, stream):
writer = None
try:
for series in iterator:
batch = _create_batch(series)
batch = _create_batch(series, self._timezone)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
@@ -274,12 +278,13 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import _check_dataframe_localize_timestamps
from pyspark.sql.types import _check_dataframe_localize_timestamps, from_arrow_schema
import pyarrow as pa
reader = pa.open_stream(stream)
schema = from_arrow_schema(reader.schema)
for batch in reader:
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
pdf = _check_dataframe_localize_timestamps(batch.to_pandas())
pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), schema, self._timezone)
Member:
If self._timezone is not None, then it will be the SESSION_LOCAL_TIMEZONE, and the Arrow data will already have that timezone set, so nothing needs to be done here, right?

Member:
Oh, maybe I misunderstood the purpose of the conf "spark.sql.execution.pandas.respectSessionTimeZone". If it is true, what is the behavior of Spark?

  1. Convert the Pandas timestamps to SESSION_LOCAL_TIMEZONE and remove the timezone info, or

  2. Show the Pandas timestamps with SESSION_LOCAL_TIMEZONE set as their timezone.

It seems this change is doing (1), but what's wrong with doing (2)? I think that would be a lot cleaner.

Member (Author):
Yes, the current implementation is doing (1). I'm not sure whether we should keep the timezone. cc @cloud-fan @gatorsmile
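For reference, a minimal pandas-only sketch of the two options being discussed (this is not the PR's helper; the sample values and the America/New_York session timezone are made up for illustration):

```python
import pandas as pd

session_tz = "America/New_York"  # stand-in for SESSION_LOCAL_TIMEZONE

# Timestamps as they arrive from Arrow: timezone-aware, here in UTC.
s = pd.Series(pd.to_datetime(["2017-11-01 00:00:00", "2017-11-02 12:00:00"], utc=True))

# Option (1): convert to the session timezone, then drop the timezone info,
# leaving a plain datetime64[ns] column.
option1 = s.dt.tz_convert(session_tz).dt.tz_localize(None)

# Option (2): convert to the session timezone but keep it attached,
# leaving a timezone-aware datetime64[ns, America/New_York] column.
option2 = s.dt.tz_convert(session_tz)

print(option1.dtype)  # datetime64[ns]
print(option2.dtype)  # datetime64[ns, America/New_York]
```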

yield [c for _, c in pdf.iteritems()]

def __repr__(self):
13 changes: 10 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -1881,15 +1881,22 @@ def toPandas(self):
1 5 Bob
"""
import pandas as pd
from pyspark.sql.types import _check_dataframe_localize_timestamps

if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
== "true":
timezone = self.sql_ctx.getConf("spark.sql.session.timeZone")
else:
timezone = None

if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
import pyarrow
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
pdf = table.to_pandas()
return _check_dataframe_localize_timestamps(pdf)
return _check_dataframe_localize_timestamps(pdf, self.schema, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
@@ -1913,7 +1920,7 @@ def toPandas(self):

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
return pdf
return _check_dataframe_localize_timestamps(pdf, self.schema, timezone)

def _collectAsArrow(self):
"""
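For context, a rough usage sketch of the two configs touched above, assuming a local SparkSession with this change applied and pandas installed (the sample data is made up):

```python
from datetime import datetime
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# With respectSessionTimeZone enabled, toPandas() renders timestamps in the session timezone.
spark.conf.set("spark.sql.session.timeZone", "America/New_York")
spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

df = spark.createDataFrame([(1, datetime(2017, 11, 1, 12, 0, 0))], ["id", "ts"])
pdf = df.toPandas()  # the "ts" column comes back as naive datetime64[ns] in America/New_York
print(pdf)
```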
133 changes: 118 additions & 15 deletions python/pyspark/sql/tests.py
@@ -2560,19 +2560,25 @@ def count_bucketed_cols(names, table="pyspark_bucket"):

@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas(self):
from datetime import datetime, date
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
data = [
(1, "foo", True, 3.0), (2, "foo", True, 5.0),
(3, "bar", False, -1.0), (4, "bar", False, 6.0),
(1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(2, "foo", True, 5.0, None, None),
(3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
(4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
]
df = self.spark.createDataFrame(data, schema)
types = df.toPandas().dtypes
self.assertEquals(types[0], np.int32)
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
self.assertEquals(types[4], 'datetime64[ns]')
self.assertEquals(types[5], 'datetime64[ns]')

@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas_avoid_astype(self):
@@ -3136,14 +3142,42 @@ def test_null_conversion(self):
null_counts = pdf.isnull().sum().tolist()
self.assertTrue(all([c == 1 for c in null_counts]))

def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
def _toPandas_arrow_toggle(self, df):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
pdf = df.toPandas()
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
try:
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf_arrow = df.toPandas()
return pdf, pdf_arrow

def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
self.assertFramesEqual(pdf_arrow, pdf)

def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
self.assertFramesEqual(pdf_arrow_la, pdf_la)
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
self.assertFramesEqual(pdf_arrow_ny, pdf_ny)

from pyspark.sql.types import _check_dataframe_localize_timestamps
self.assertFalse(pdf_ny.equals(pdf_la))
self.assertTrue(pdf_ny.equals(
_check_dataframe_localize_timestamps(pdf_la, self.schema, timezone)))
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_pandas_round_trip(self):
import pandas as pd
import numpy as np
@@ -3169,6 +3203,27 @@ def test_filtered_frame(self):
@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDFTests(ReusedSQLTestCase):

@classmethod
def setUpClass(cls):
ReusedSQLTestCase.setUpClass()

# Synchronize default timezone between Python and Java
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
tz = "America/Los_Angeles"
os.environ["TZ"] = tz
time.tzset()

cls.sc.environment["TZ"] = tz
cls.spark.conf.set("spark.sql.session.timeZone", tz)

@classmethod
def tearDownClass(cls):
del os.environ["TZ"]
if cls.tz_prev is not None:
os.environ["TZ"] = cls.tz_prev
time.tzset()
ReusedSQLTestCase.tearDownClass()

def test_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
@@ -3429,22 +3484,29 @@ def test_vectorized_udf_timestamps(self):
f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))

@pandas_udf(returnType=BooleanType())
@pandas_udf(returnType=StringType())
def check_data(idx, timestamp, timestamp_copy):
import pandas as pd
msgs = []
is_equal = timestamp.isnull() # use this array to check values are equal
for i in range(len(idx)):
# Check that timestamps are as expected in the UDF
is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
timestamp[i].to_pydatetime() == data[idx[i]][1]
return is_equal

result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
col("timestamp_copy"))).collect()
if (is_equal[i] and data[idx[i]][1] is None) or \
timestamp[i].to_pydatetime() == data[idx[i]][1]:
msgs.append(None)
else:
msgs.append(
"timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
% (timestamp[i], idx[i], data[idx[i]][1]))
return pd.Series(msgs)

result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
col("timestamp_copy"))).collect()
# Check that collection values are correct
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertTrue(result[i][3]) # "is_equal" data in udf was as expected
self.assertIsNone(result[i][3]) # "check_data" col

def test_vectorized_udf_return_timestamp_tz(self):
from pyspark.sql.functions import pandas_udf, col
@@ -3484,6 +3546,47 @@ def check_records_per_batch(x):
else:
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)

def test_vectorized_udf_timestamps_respect_session_timezone(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import datetime
import pandas as pd
schema = StructType([
StructField("idx", LongType(), True),
StructField("timestamp", TimestampType(), True)])
data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
(2, datetime(2012, 2, 2, 2, 2, 2)),
(3, None),
(4, datetime(2100, 4, 4, 4, 4, 4))]
df = self.spark.createDataFrame(data, schema=schema)

f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType())
internal_value = pandas_udf(
lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
Member:
Here too, it took me a while to figure out where this 3 comes from.

Member (Author):
Yes, I'll add some comments.
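For the record, a small sketch of where the constant above comes from: America/New_York is three hours ahead of America/Los_Angeles (both zones switch DST on the same dates, so the gap is constant for the tested values), and pandas Timestamp.value is expressed in nanoseconds since the epoch:

```python
# 3 hours, expressed in nanoseconds (the unit of pandas Timestamp.value).
HOURS_BETWEEN_NY_AND_LA = 3
diff = HOURS_BETWEEN_NY_AND_LA * 60 * 60 * 1000 * 1000 * 1000
print(diff)  # 10800000000000
```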

result_la_corrected = \
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

self.assertNotEqual(result_ny, result_la)
self.assertEqual(result_ny, result_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedSQLTestCase):