diff --git a/python/orca/src/bigdl/orca/data/shard.py b/python/orca/src/bigdl/orca/data/shard.py index c2899c63016..587aebeb04a 100644 --- a/python/orca/src/bigdl/orca/data/shard.py +++ b/python/orca/src/bigdl/orca/data/shard.py @@ -598,21 +598,32 @@ def utility_func(x, func, *args, **kwargs): def get_schema(self): if 'schema' in self.type: return self.type['schema'] - else: - if self._get_class_name() == 'pandas.core.frame.DataFrame': - import pandas as pd - columns, dtypes = self.rdd.map(lambda x: (x.columns, x.dtypes)).first() - self.type['schema'] = {'columns': columns, 'dtypes': dtypes} - return self.type['schema'] - return None + + if 'class_name' not in self.type\ + or self.type['class_name'] == 'pandas.core.frame.DataFrame': + class_name, schema = self._get_schema_class_name() + self.type['class_name'] = class_name + self.type['schema'] = schema + return self.type['schema'] def _get_class_name(self): if 'class_name' in self.type: return self.type['class_name'] else: - self.type['class_name'] = self._for_each(get_class_name).first() + class_name, schema = self._get_schema_class_name() + self.type['class_name'] = class_name + self.type['schema'] = schema return self.type['class_name'] + def _get_schema_class_name(self): + def func(x): + class_name = get_class_name(x) + schema = None + if class_name == 'pandas.core.frame.DataFrame': + schema = {'columns': x.columns, 'dtypes': x.dtypes} + return (class_name, schema) + return self.rdd.map(lambda x: func(x)).first() + class SharedValue(object): def __init__(self, data):