Skip to content

Commit

Permalink
Workaround for old pandas.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Nov 1, 2017
1 parent 5c08ecf commit ee1a1c8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
23 changes: 16 additions & 7 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2560,19 +2560,25 @@ def count_bucketed_cols(names, table="pyspark_bucket"):

@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas(self):
from datetime import datetime, date
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
data = [
(1, "foo", True, 3.0), (2, "foo", True, 5.0),
(3, "bar", False, -1.0), (4, "bar", False, 6.0),
(1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(2, "foo", True, 5.0, None, None),
(3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
(4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
]
df = self.spark.createDataFrame(data, schema)
types = df.toPandas().dtypes
self.assertEquals(types[0], np.int32)
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
self.assertEquals(types[4], 'datetime64[ns]')
self.assertEquals(types[5], 'datetime64[ns]')

@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas_avoid_astype(self):
Expand Down Expand Up @@ -3544,6 +3550,7 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
(3, datetime(2100, 3, 3, 3, 3, 3))]
df = self.spark.createDataFrame(data, schema=schema)

f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType())
internal_value = pandas_udf(lambda ts: ts.apply(lambda ts: ts.value), LongType())

orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
Expand All @@ -3552,16 +3559,18 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_la = df.withColumn("internal_value", internal_value(col("timestamp")))
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
result_la_corrected = \
df_la.select(col("idx"), col("internal_value") + diff).collect()
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

df_ny = df.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("internal_value")).collect()
df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

self.assertNotEqual(result_ny, result_la)
self.assertEqual(result_ny, result_la_corrected)
Expand Down
51 changes: 38 additions & 13 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,17 +1685,36 @@ def _check_dataframe_localize_timestamps(pdf, schema, timezone):
:param pdf: pandas.DataFrame
:return pandas.DataFrame where any timezone aware columns have be converted to tz-naive
"""
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
if type(schema[str(column)].dataType) == TimestampType:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
elif is_datetime64_dtype(series.dtype) and timezone is not None:
# `series.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
pdf[column] = series.apply(lambda ts: ts.tz_localize('tzlocal()')) \
.dt.tz_convert(tz).dt.tz_localize(None)
try:
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
if type(schema[str(column)].dataType) == TimestampType:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
elif is_datetime64_dtype(series.dtype) and timezone is not None:
# `series.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
pdf[column] = series.apply(lambda ts: ts.tz_localize('tzlocal()')) \
.dt.tz_convert(tz).dt.tz_localize(None)
except ImportError:
import pandas as pd
from pandas.core.common import is_datetime64tz_dtype, is_datetime64_dtype
from pandas.tslib import _dateutil_tzlocal
tzlocal = _dateutil_tzlocal()
tz = timezone or tzlocal
for column, series in pdf.iteritems():
if type(schema[str(column)].dataType) == TimestampType:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(series.dtype):
# `series.dt.tz_convert(tzlocal).dt.tz_localize(None)` doesn't work properly.
pdf[column] = pd.Series([ts.tz_convert(tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT for ts in series])
elif is_datetime64_dtype(series.dtype) and timezone is not None:
# `series.dt.tz_localize(tzlocal)` doesn't work properly.
pdf[column] = pd.Series(
[ts.tz_localize(tzlocal).tz_convert(tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT for ts in series])
return pdf


Expand All @@ -1705,10 +1724,16 @@ def _check_series_convert_timestamps_internal(s, timezone):
:param s: a pandas.Series
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
try:
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
tzlocal = 'tzlocal()'
except ImportError:
from pandas.core.common import is_datetime64tz_dtype, is_datetime64_dtype
from pandas.tslib import _dateutil_tzlocal
tzlocal = _dateutil_tzlocal()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
tz = timezone or 'tzlocal()'
tz = timezone or tzlocal
return s.dt.tz_localize(tz).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
Expand Down

0 comments on commit ee1a1c8

Please sign in to comment.