
Commit 2c9d7e4

move import to global setup; update needsConversion
mengxr committed Nov 3, 2014
1 parent 7c4a6a9 commit 2c9d7e4
Showing 2 changed files with 4 additions and 6 deletions.
python/pyspark/sql.py (8 changes: 3 additions & 5 deletions)
@@ -527,7 +527,6 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
- >>> from pyspark.tests import ExamplePointUDT
>>> check_datatype(ExamplePointUDT())
True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
@@ -581,7 +580,6 @@ def _parse_datatype_json_value(json_value):
def _infer_type(obj):
"""Infer the DataType from obj
- >>> from pyspark.tests import ExamplePoint
>>> p = ExamplePoint(1.0, 2.0)
>>> _infer_type(p)
ExamplePointUDT
@@ -648,7 +646,6 @@ def _need_python_to_sql_conversion(dataType):
... StructField("values", ArrayType(DoubleType(), False), False)])
>>> _need_python_to_sql_conversion(schema0)
False
- >>> from pyspark.tests import ExamplePointUDT
>>> _need_python_to_sql_conversion(ExamplePointUDT())
True
>>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -682,7 +679,6 @@ def _python_to_sql_converter(dataType):
>>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
>>> conv([1.0, 2.0])
[1.0, 2.0]
- >>> from pyspark.tests import ExamplePointUDT, ExamplePoint
>>> conv = _python_to_sql_converter(ExamplePointUDT())
>>> conv(ExamplePoint(1.0, 2.0))
[1.0, 2.0]
@@ -953,7 +949,6 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
- >>> from pyspark.tests import ExamplePoint, ExamplePointUDT
>>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
>>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
@@ -2015,6 +2010,7 @@ def _test():
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
@@ -2026,6 +2022,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
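The `_test()` hunks above show the pattern this commit relies on: the doctests run against a `globs` dictionary, so a name registered there once is visible to every docstring, and the per-docstring `from pyspark.tests import ...` lines become unnecessary. A minimal, self-contained sketch of that pattern (the toy function and the `Fraction` example are illustrative, not part of pyspark):

```python
import doctest

def _half(x):
    """Halve x.

    ``Fraction`` is not imported at this module's top level; it is
    injected into the doctest globals by ``_test()`` below.

    >>> _half(Fraction(1, 2))
    Fraction(1, 4)
    """
    return x / 2

def _test():
    # Import once here instead of repeating the import inside every
    # docstring that needs the name.
    from fractions import Fraction

    globs = globals().copy()
    globs['Fraction'] = Fraction
    failed, _ = doctest.testmod(globs=globs)
    if failed:
        raise SystemExit(1)

if __name__ == "__main__":
    _test()
```

In the commit itself, the same idea registers `ExamplePoint` and `ExamplePointUDT` in `globs`, so the five docstrings above keep working without their local imports.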
SQLContext.scala (2 changes: 1 addition & 1 deletion)
@@ -484,7 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
case ArrayType(_, _) => true
case MapType(_, _, _) => true
case StructType(_) => true
- case _: UserDefinedType[_] => true
+ case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
case other => false
}

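The Scala change above makes `needsConversion` recursive for user-defined types: instead of unconditionally returning `true`, a UDT now needs conversion only when its underlying `sqlType` does (arrays, maps, structs, or nested UDTs), so a UDT backed by a primitive SQL type no longer forces a conversion pass. A toy Python sketch of that logic, using stand-in type classes rather than Spark's (MapType omitted for brevity), just to show the shape of the recursion:

```python
# Stand-in type classes; illustrative only, not Spark's actual types.
class DataType: ...
class DoubleType(DataType): ...
class ArrayType(DataType):
    def __init__(self, element_type): self.element_type = element_type
class StructType(DataType):
    def __init__(self, fields): self.fields = fields
class UserDefinedType(DataType):
    def __init__(self, sql_type): self.sql_type = sql_type

def needs_conversion(dt):
    if isinstance(dt, (ArrayType, StructType)):
        return True
    if isinstance(dt, UserDefinedType):
        # Before this commit: always True.
        # After: defer to the UDT's underlying SQL type.
        return needs_conversion(dt.sql_type)
    return False

assert needs_conversion(ArrayType(DoubleType()))                   # unchanged
assert needs_conversion(UserDefinedType(ArrayType(DoubleType())))  # still True
assert not needs_conversion(UserDefinedType(DoubleType()))         # newly False
```

Nested UDTs keep recursing until a non-UDT type is reached.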
