Take more rows to infer the schema, or infer the schema by sampling the RDD
davies committed Oct 8, 2014
1 parent f18dd59 commit 3603e00
Showing 2 changed files with 130 additions and 64 deletions.
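In short, schema inference now tolerates None values and empty collections by recording a NullType placeholder, and inferSchema, jsonFile, and jsonRDD gain an optional samplingRatio argument. A hedged usage sketch against a pyspark shell session (the SparkContext `sc` and the example data are assumptions, not part of the diff):

from pyspark.sql import SQLContext, Row

sqlCtx = SQLContext(sc)

# `age` is missing in the first rows; the old code would have raised here
rows = sc.parallelize([Row(name="n%d" % i, age=None if i < 50 else i)
                       for i in range(100)])

# default: infer from the first row, then peek at up to 100 rows
# until no field is left with an undetermined (NullType) type
srdd1 = sqlCtx.inferSchema(rows)

# new: infer by sampling the RDD and merging the per-row schemas
srdd2 = sqlCtx.inferSchema(rows, samplingRatio=0.5)

Without samplingRatio only the first rows are inspected; with it, all sampled rows are scanned and their schemas merged.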
python/pyspark/sql.py: 183 changes (119 additions & 64 deletions)
@@ -84,6 +84,17 @@ def __call__(cls):
return cls._instances[cls]


class NullType(DataType):

"""Spark SQL NullType"""

__metaclass__ = PrimitiveTypeSingleton

def __eq__(self, other):
# because they should be the same object
return self is other
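NullType reuses the existing PrimitiveTypeSingleton metaclass (its __call__ appears in the hunk header above), so every NullType() call returns the same cached object and the identity-based __eq__ is sound. A standalone Python 2 sketch of that pattern, using hypothetical names (_Singleton, MyNullType) rather than the diff's classes:

class _Singleton(type):
    """Cache one instance per class, like PrimitiveTypeSingleton."""
    _instances = {}

    def __call__(cls):
        if cls not in cls._instances:
            cls._instances[cls] = super(_Singleton, cls).__call__()
        return cls._instances[cls]


class MyNullType(object):
    __metaclass__ = _Singleton  # Python 2 metaclass syntax, as in the diff

    def __eq__(self, other):
        return self is other    # safe: there is only ever one instance


assert MyNullType() is MyNullType()
assert MyNullType() == MyNullType()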


class PrimitiveType(DataType):

"""Spark SQL PrimitiveType"""
@@ -443,6 +454,7 @@ def _parse_datatype_string(datatype_string):

# Mapping Python types to Spark SQL DataType
_type_mappings = {
type(None): NullType,
bool: BooleanType,
int: IntegerType,
long: LongType,
@@ -459,22 +471,22 @@ def _parse_datatype_string(datatype_string):

def _infer_type(obj):
"""Infer the DataType from obj"""
if obj is None:
raise ValueError("Can not infer type for None")

dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()

if isinstance(obj, dict):
if not obj:
raise ValueError("Can not infer type for empty dict")
key, value = obj.iteritems().next()
return MapType(_infer_type(key), _infer_type(value), True)
for key, value in obj.iteritems():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
else:
return MapType(NullType(), NullType(), True)
elif isinstance(obj, (list, array)):
if not obj:
raise ValueError("Can not infer type for empty list/array")
return ArrayType(_infer_type(obj[0]), True)
for v in obj:
if v is not None:
return ArrayType(_infer_type(v), True)
else:
return ArrayType(NullType(), True)
else:
try:
return _infer_schema(obj)
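After this hunk, None and empty containers produce NullType placeholders instead of raising ValueError. A hedged check of the new behavior; _infer_type and the type classes are module-level names in pyspark.sql, so these imports reach into internals:

from pyspark.sql import (_infer_type, NullType, ArrayType, MapType,
                         StringType, IntegerType)

assert _infer_type(None) == NullType()                           # used to raise
assert _infer_type([]) == ArrayType(NullType(), True)            # used to raise
assert _infer_type({}) == MapType(NullType(), NullType(), True)  # used to raise
assert _infer_type({"k": 1}) == MapType(StringType(), IntegerType(), True)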
@@ -507,60 +519,85 @@ def _infer_schema(row):
return StructType(fields)


def _create_converter(obj, dataType):
def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
return any(_has_nulltype(f.dataType) for f in dt.fields)
elif isinstance(dt, ArrayType):
return _has_nulltype(dt.elementType)
elif isinstance(dt, MapType):
return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
else:
return isinstance(dt, NullType)


def _merge_type(a, b):
if isinstance(a, NullType):
return b
elif isinstance(b, NullType):
return a
elif type(a) is not type(b):
raise TypeError("Can not merge type %s and %s" % (a, b))

# same type
if isinstance(a, StructType):
# TODO: merge different fields
if [f.name for f in a.fields] != [f.name for f in b.fields]:
raise TypeError("Can not merge two StructType with different fields")
fields = [StructField(fa.name, _merge_type(fa.dataType, fb.dataType), True)
for fa, fb in zip(a.fields, b.fields)]
return StructType(fields)

elif isinstance(a, ArrayType):
return ArrayType(_merge_type(a.elementType, b.elementType), True)
elif isinstance(a, MapType):
return MapType(_merge_type(a.keyType, b.keyType),
_merge_type(a.valueType, b.valueType),
True)
else:
return a
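_has_nulltype and _merge_type are what inferSchema (further down in this file) uses to refine a schema that still contains placeholders. A hedged sketch, again leaning on pyspark.sql internals:

from pyspark.sql import Row, _infer_schema, _merge_type, _has_nulltype

s1 = _infer_schema(Row(a=None, b=1))   # field `a` is a NullType placeholder
s2 = _infer_schema(Row(a="x", b=2))    # field `a` is StringType here

assert _has_nulltype(s1)
merged = _merge_type(s1, s2)
assert not _has_nulltype(merged)       # NullType gave way to StringType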


def _create_converter(dataType):
"""Create a converter that drops the names of fields in obj"""
if isinstance(dataType, ArrayType):
conv = _create_converter(obj[0], dataType.elementType)
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)

elif isinstance(dataType, MapType):
value = obj.values()[0]
conv = _create_converter(value, dataType.valueType)
conv = _create_converter(dataType.valueType)
return lambda row: dict((k, conv(v)) for k, v in row.iteritems())

elif isinstance(dataType, NullType):
return lambda x: None

elif not isinstance(dataType, StructType):
return lambda x: x

# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]

if isinstance(obj, dict):
conv = lambda o: tuple(o.get(n) for n in names)

elif isinstance(obj, tuple):
if hasattr(obj, "_fields"): # namedtuple
conv = tuple
elif hasattr(obj, "__FIELDS__"):
conv = tuple
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
conv = lambda o: tuple(v for k, v in o)
else:
raise ValueError("unexpected tuple")

elif hasattr(obj, "__dict__"): # object
conv = lambda o: [o.__dict__.get(n, None) for n in names]

if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
return conv
def convert_struct(obj):
if obj is None:
return

row = conv(obj)
convs = [_create_converter(v, f.dataType)
for v, f in zip(row, dataType.fields)]
if isinstance(obj, dict):
vs = [obj.get(n) for n in names]

def nested_conv(row):
return tuple(f(v) for f, v in zip(convs, conv(row)))

return nested_conv
elif isinstance(obj, tuple):
if all(isinstance(x, tuple) and len(x) == 2 for x in obj):
vs = [v for k, v in obj]
else:
vs = obj

elif hasattr(obj, "__dict__"): # object
vs = [obj.__dict__.get(n) for n in names]
else:
raise ValueError("Unexpected obj: %s" % obj)
return tuple([conv(v) for conv, v in zip(converters, vs)])

def _drop_schema(rows, schema):
""" all the names of fields, becoming tuples"""
iterator = iter(rows)
row = iterator.next()
converter = _create_converter(row, schema)
yield converter(row)
for i in iterator:
yield converter(i)
return convert_struct


_BRACKETS = {'(': ')', '[': ']', '{': '}'}
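The rewritten _create_converter closes over the schema rather than over a sample row, so a single converter is built per schema and applied to every row via rdd.map. A hedged illustration using the same internal helpers:

from pyspark.sql import Row, _infer_schema, _create_converter

row = Row(name="Alice", scores=[1, 2])
schema = _infer_schema(row)
convert = _create_converter(schema)

# named rows become plain tuples in the schema's field order
assert convert(row) == ("Alice", [1, 2])
# dicts (deprecated as top-level rows) go through the same converter
assert convert({"name": "Bob", "scores": []}) == ("Bob", [])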
@@ -672,7 +709,7 @@ def _infer_schema_type(obj, dataType):
return _infer_type(obj)

if not obj:
raise ValueError("Can not infer type from empty value")
return NullType()

if isinstance(dataType, ArrayType):
eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -994,19 +1031,22 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
str(returnType))

def inferSchema(self, rdd):
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
We peek at the first row of the RDD to determine the fields' names
and types. Nested collections are supported, which include array,
dict, list, Row, tuple, namedtuple, or object.
If `samplingRatio` is present, the schema is inferred from all of the
sampled data.
All the rows in `rdd` should have the same type with the first one,
or it will cause runtime exceptions.
Otherwise, it peeks at the first few rows of the RDD to determine the
fields' names and types. Nested collections are supported, including
array, dict, list, Row, tuple, namedtuple, and object.
Each row could be an L{pyspark.sql.Row} object, a namedtuple, or an
object; using dict is deprecated.
If some rows have types that differ from the inferred types, this may
cause runtime exceptions.
>>> rdd = sc.parallelize(
... [Row(field1=1, field2="row1"),
... Row(field1=2, field2="row2"),
@@ -1042,8 +1082,23 @@ def inferSchema(self, rdd):
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")

schema = _infer_schema(first)
rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
if samplingRatio is None:
schema = _infer_schema(first)
if _has_nulltype(schema):
for row in rdd.take(100)[1:]:
schema = _merge_type(schema, _infer_schema(row))
if not _has_nulltype(schema):
break
else:
warnings.warn("Some of types cannot be determined by the "
"first 100 rows, please try again with sampling")
else:
if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(_infer_schema).reduce(_merge_type)

converter = _create_converter(schema)
rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
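To see the practical difference between the two branches above, here is a hedged shell sketch (`sc`, `sqlCtx`, and the data are assumptions): a field that is None in every early row cannot be pinned down by the default peek at the first 100 rows, but scanning or sampling the whole RDD still resolves it.

from pyspark.sql import Row

rows = sc.parallelize([Row(x=None)] * 200 + [Row(x="hello")])

# the default path would warn that the type of `x` is undetermined;
# the sampling path merges per-row schemas across the whole RDD
srdd = sqlCtx.inferSchema(rows, samplingRatio=1.0)
print srdd.schema()   # `x` comes out as a nullable string field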

def applySchema(self, rdd, schema):
@@ -1161,16 +1216,16 @@ def parquetFile(self, path):
jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)

def jsonFile(self, path, schema=None):
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
L{SchemaRDD}.
If the schema is provided, applies the given schema to this
JSON dataset.
Otherwise, it goes through the entire dataset once to determine
the schema.
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -1216,20 +1271,20 @@ def jsonFile(self, path, schema=None):
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def jsonRDD(self, rdd, schema=None):
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
If the schema is provided, applies the given schema to this
JSON dataset.
Otherwise, it goes through the entire dataset once to determine
the schema.
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.
>>> srdd1 = sqlCtx.jsonRDD(json)
>>> sqlCtx.registerRDDAsTable(srdd1, "table1")
@@ -1286,7 +1341,7 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
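Both JSON entry points forward the new samplingRatio to the Scala side, so JSON schema inference can work from a sample instead of a full pass over the data. A hedged sketch (`sc`, `sqlCtx`, the records, and the file path are assumptions):

import json

strings = sc.parallelize([json.dumps({"a": i, "b": None if i % 2 else "x"})
                          for i in range(100)])

# infer the schema from roughly half of the JSON records
srdd = sqlCtx.jsonRDD(strings, samplingRatio=0.5)

# same idea when loading from a file (illustrative path)
# srdd = sqlCtx.jsonFile("/tmp/people.json", samplingRatio=0.5)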
python/pyspark/tests.py: 11 changes (11 additions & 0 deletions)
@@ -755,6 +755,17 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)

def test_infer_schema(self):
d = [Row(l=[], d={}, s=None),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
srdd = self.sqlCtx.inferSchema(rdd)
self.assertEqual([], srdd.map(lambda r: r.l).first())
self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
self.assertEqual({}, srdd.map(lambda r: r.d).first())
self.assertEqual(srdd.schema(), srdd2.schema())


class InputFormatTests(ReusedPySparkTestCase):

