From afcf047cf7c58264bb0a831dc26531fdcd5bc1d9 Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Wed, 17 Jun 2020 16:11:05 +0800
Subject: [PATCH] initial

---
 pyzoo/zoo/orca/data/pandas/preprocessing.py | 74 ++++++---------------
 1 file changed, 21 insertions(+), 53 deletions(-)

diff --git a/pyzoo/zoo/orca/data/pandas/preprocessing.py b/pyzoo/zoo/orca/data/pandas/preprocessing.py
index 944335908ac..3955f935662 100644
--- a/pyzoo/zoo/orca/data/pandas/preprocessing.py
+++ b/pyzoo/zoo/orca/data/pandas/preprocessing.py
@@ -107,63 +107,31 @@ def read_file_spark(context, file_path, file_type, **kwargs):
     if not file_paths:
         raise Exception("The file path is invalid/empty or does not include csv/json files")
 
-    rdd = context.parallelize(file_paths, node_num * core_num)
-
-    if prefix == "hdfs":
-        def loadFile(iterator):
-            import pandas as pd
-            import pyarrow as pa
-            fs = pa.hdfs.connect()
-
-            for x in iterator:
-                with fs.open(x, 'rb') as f:
-                    if file_type == "csv":
-                        df = pd.read_csv(f, **kwargs)
-                    elif file_type == "json":
-                        df = pd.read_json(f, **kwargs)
-                    else:
-                        raise Exception("Unsupported file type")
-                    yield df
-
-        pd_rdd = rdd.mapPartitions(loadFile)
-    elif prefix == "s3":
-        def loadFile(iterator):
-            access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
-            secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
-            import boto3
-            import pandas as pd
-            s3_client = boto3.Session(
-                aws_access_key_id=access_key_id,
-                aws_secret_access_key=secret_access_key,
-            ).client('s3', verify=False)
-            for x in iterator:
-                path_parts = x.split("://")[1].split('/')
-                bucket = path_parts.pop(0)
-                key = "/".join(path_parts)
-                obj = s3_client.get_object(Bucket=bucket, Key=key)
-                if file_type == "json":
-                    df = pd.read_json(obj['Body'], **kwargs)
-                elif file_type == "csv":
-                    df = pd.read_csv(obj['Body'], **kwargs)
-                else:
-                    raise Exception("Unsupported file type")
-                yield df
-
-        pd_rdd = rdd.mapPartitions(loadFile)
+    num_files = len(file_paths)
+    total_cores = node_num * core_num
+    num_partitions = num_files if num_files < total_cores else total_cores
+
+    from pyspark.sql import SQLContext
+    sqlContext = SQLContext.getOrCreate(context)
+    spark = sqlContext.sparkSession
+    # TODO: add S3 credentials
+    if file_type == "json":
+        df = spark.read.json(file_paths).repartition(num_partitions)
+    elif file_type == "csv":
+        df = spark.read.csv(file_paths).repartition(num_partitions)
     else:
-        def loadFile(iterator):
+        raise Exception("Unsupported file type")
+
+    def to_pandas(columns):
+        def f(iter):
             import pandas as pd
-            for x in iterator:
-                if file_type == "csv":
-                    df = pd.read_csv(x, **kwargs)
-                elif file_type == "json":
-                    df = pd.read_json(x, **kwargs)
-                else:
-                    raise Exception("Unsupported file type")
-                yield df
+            data = list(iter)
+            yield pd.DataFrame(data, columns=columns)
 
-        pd_rdd = rdd.mapPartitions(loadFile)
+        return f
 
+    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
+    pd_rdd.map(lambda df: print(len(df))).count()
     data_shards = SparkXShards(pd_rdd)
     return data_shards
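
Note: below is a minimal standalone sketch of the new read path, for reference
only. The local SparkSession, the two CSV paths, and the total_cores value are
all illustrative stand-ins; min(num_files, total_cores) restates the patch's
conditional partition count, and to_pandas mirrors how each Spark partition is
collected into a single pandas DataFrame. Treat it as an approximation under
those assumptions, not the Orca API.

    from pyspark.sql import SparkSession

    # Hypothetical local session standing in for Orca's Spark context.
    spark = SparkSession.builder.master("local[2]").appName("sketch").getOrCreate()

    file_paths = ["/tmp/a.csv", "/tmp/b.csv"]  # hypothetical input files
    total_cores = 4                            # stands in for node_num * core_num
    num_partitions = min(len(file_paths), total_cores)

    # As in the patch, no header/schema options are passed, so Spark names
    # the CSV columns _c0, _c1, ...
    df = spark.read.csv(file_paths).repartition(num_partitions)

    def to_pandas(columns):
        def f(iterator):
            import pandas as pd
            # Spark Rows are tuple-like, so a list of Rows feeds pd.DataFrame.
            yield pd.DataFrame(list(iterator), columns=columns)
        return f

    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
    print(pd_rdd.map(len).collect())  # one row count per partition

Capping the partition count at min(num_files, total_cores) appears intended to
keep at most one shard per core while not splitting a small number of files
into near-empty shards.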