diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3814108471a9d..04c7740ef1c6e 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -527,7 +527,6 @@ def _parse_datatype_json_string(json_string):
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
     True
-    >>> from pyspark.tests import ExamplePointUDT
     >>> check_datatype(ExamplePointUDT())
     True
     >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
@@ -581,7 +580,6 @@ def _parse_datatype_json_value(json_value):
 def _infer_type(obj):
     """Infer the DataType from obj

-    >>> from pyspark.tests import ExamplePoint
     >>> p = ExamplePoint(1.0, 2.0)
     >>> _infer_type(p)
     ExamplePointUDT
@@ -648,7 +646,6 @@ def _need_python_to_sql_conversion(dataType):
     ...                       StructField("values", ArrayType(DoubleType(), False), False)])
     >>> _need_python_to_sql_conversion(schema0)
     False
-    >>> from pyspark.tests import ExamplePointUDT
     >>> _need_python_to_sql_conversion(ExamplePointUDT())
     True
     >>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -682,7 +679,6 @@ def _python_to_sql_converter(dataType):
     >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
     >>> conv([1.0, 2.0])
     [1.0, 2.0]
-    >>> from pyspark.tests import ExamplePointUDT, ExamplePoint
     >>> conv = _python_to_sql_converter(ExamplePointUDT())
     >>> conv(ExamplePoint(1.0, 2.0))
     [1.0, 2.0]
@@ -953,7 +949,6 @@ def _verify_type(obj, dataType):
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> from pyspark.tests import ExamplePoint, ExamplePointUDT
     >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
     >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
@@ -2015,6 +2010,7 @@ def _test():
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
     from pyspark.sql import Row, SQLContext
+    from pyspark.tests import ExamplePoint, ExamplePointUDT
     globs = pyspark.sql.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
@@ -2026,6 +2022,8 @@ def _test():
          Row(field1=2, field2="row2"),
          Row(field1=3, field2="row3")]
     )
+    globs['ExamplePoint'] = ExamplePoint
+    globs['ExamplePointUDT'] = ExamplePointUDT
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 83d18daf7a259..84eaf401f240c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -484,7 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     case ArrayType(_, _) => true
     case MapType(_, _, _) => true
     case StructType(_) => true
-    case _: UserDefinedType[_] => true
+    case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
     case other => false
   }

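
The Python hunks above drop the repeated `from pyspark.tests import ExamplePoint, ExamplePointUDT` lines from the individual docstrings and instead inject both names into the doctest globals built in `_test()`. As a rough illustration of that pattern (a standalone sketch with made-up names, not code from this patch), any name placed in the `globs` dict handed to `doctest.testmod` is visible to every docstring example without a per-docstring import:

import doctest


def double(x):
    """Double a number.

    `shared_value` is never imported in this docstring; it is supplied
    through the globals dict passed to doctest.testmod below.

    >>> double(shared_value)
    4.0
    """
    return x * 2


if __name__ == "__main__":
    globs = globals().copy()
    # Same idea as globs['ExamplePoint'] = ExamplePoint in _test().
    globs['shared_value'] = 2.0
    (failure_count, test_count) = doctest.testmod(globs=globs)
    print(failure_count, test_count)

The SQLContext.scala change is separate: rather than treating every UserDefinedType as needing Python-to-SQL conversion, needsConversion now recurses into the UDT's underlying sqlType, so a UDT whose backing SQL type requires no conversion is left untouched.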