Skip to content

Commit

Permalink
Optimize get schema and get class name in shards into one RDD operation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dding3 authored Aug 31, 2022
1 parent 0ee9418 commit 6958678
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions python/orca/src/bigdl/orca/data/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,21 +598,32 @@ def utility_func(x, func, *args, **kwargs):
def get_schema(self):
    """
    Return the cached schema of the shards, computing it on first use.

    The schema and the element class name are obtained together through
    ``self._get_schema_class_name()`` and both are cached in ``self.type``,
    so a later lookup of either value does not need another RDD action.

    :return: the value stored under ``self.type['schema']`` (for pandas
             DataFrame shards, a dict with keys ``'columns'`` and
             ``'dtypes'``), or None when no schema is available.
    """
    if 'schema' in self.type:
        return self.type['schema']

    # Only query the RDD when the class name is still unknown, or when the
    # shards are known to be DataFrames whose schema is not cached yet.
    if 'class_name' not in self.type \
            or self.type['class_name'] == 'pandas.core.frame.DataFrame':
        class_name, schema = self._get_schema_class_name()
        self.type['class_name'] = class_name
        self.type['schema'] = schema

    # .get(): a cached non-DataFrame class name with no cached schema must
    # yield None rather than raise KeyError.
    return self.type.get('schema')

def _get_class_name(self):
    """
    Return the class name of the shard elements, computing it on first use.

    When the name is not cached yet, it is fetched together with the schema
    through a single call to ``self._get_schema_class_name()`` and both
    values are stored in ``self.type``.

    :return: the value stored under ``self.type['class_name']``.
    """
    if 'class_name' in self.type:
        return self.type['class_name']

    # One combined lookup replaces the former separate class-name query,
    # whose result was overwritten immediately afterwards.
    class_name, schema = self._get_schema_class_name()
    self.type['class_name'] = class_name
    self.type['schema'] = schema
    return self.type['class_name']

def _get_schema_class_name(self):
    """
    Look up the class name and schema of the first element of ``self.rdd``.

    :return: a ``(class_name, schema)`` tuple. ``schema`` is a dict with
             keys ``'columns'`` and ``'dtypes'`` when the class name is
             ``'pandas.core.frame.DataFrame'``, otherwise None.
    """
    def describe(shard):
        # Computed per element so name and schema come from the same pass.
        name = get_class_name(shard)
        if name == 'pandas.core.frame.DataFrame':
            return (name, {'columns': shard.columns, 'dtypes': shard.dtypes})
        return (name, None)

    described = self.rdd.map(describe)
    return described.first()


class SharedValue(object):
def __init__(self, data):
Expand Down

0 comments on commit 6958678

Please sign in to comment.