[Draft] Use spark.read.csv/json #2473

Closed
wants to merge 1 commit into from
74 changes: 21 additions & 53 deletions pyzoo/zoo/orca/data/pandas/preprocessing.py
@@ -107,63 +107,31 @@ def read_file_spark(context, file_path, file_type, **kwargs):
     if not file_paths:
         raise Exception("The file path is invalid/empty or does not include csv/json files")
 
-    rdd = context.parallelize(file_paths, node_num * core_num)
-
-    if prefix == "hdfs":
-        def loadFile(iterator):
-            import pandas as pd
-            import pyarrow as pa
-            fs = pa.hdfs.connect()
-
-            for x in iterator:
-                with fs.open(x, 'rb') as f:
-                    if file_type == "csv":
-                        df = pd.read_csv(f, **kwargs)
-                    elif file_type == "json":
-                        df = pd.read_json(f, **kwargs)
-                    else:
-                        raise Exception("Unsupported file type")
-                    yield df
-
-        pd_rdd = rdd.mapPartitions(loadFile)
-    elif prefix == "s3":
-        def loadFile(iterator):
-            access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
-            secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]
-            import boto3
-            import pandas as pd
-            s3_client = boto3.Session(
-                aws_access_key_id=access_key_id,
-                aws_secret_access_key=secret_access_key,
-            ).client('s3', verify=False)
-            for x in iterator:
-                path_parts = x.split("://")[1].split('/')
-                bucket = path_parts.pop(0)
-                key = "/".join(path_parts)
-                obj = s3_client.get_object(Bucket=bucket, Key=key)
-                if file_type == "json":
-                    df = pd.read_json(obj['Body'], **kwargs)
-                elif file_type == "csv":
-                    df = pd.read_csv(obj['Body'], **kwargs)
-                else:
-                    raise Exception("Unsupported file type")
-                yield df
-
-        pd_rdd = rdd.mapPartitions(loadFile)
+    num_files = len(file_paths)
+    total_cores = node_num * core_num
+    num_partitions = num_files if num_files < total_cores else total_cores
+
+    from pyspark.sql import SQLContext
+    sqlContext = SQLContext.getOrCreate(context)
+    spark = sqlContext.sparkSession
+    # TODO: add S3 credentials
+    if file_type == "json":
+        df = spark.read.json(file_paths).repartition(num_partitions)
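
The S3 TODO above is left open in this draft. One possible way to wire in credentials for spark.read (a sketch, not part of this PR; it assumes the paths use the s3a connector and that hadoop-aws is on the classpath) is to copy the same environment variables the removed boto3 path relied on into Spark's Hadoop configuration:

import os

# Sketch only: assumes s3a:// paths and the hadoop-aws connector; not part of this PR.
hadoop_conf = spark.sparkContext._jsc.hadoopConfiguration()
hadoop_conf.set("fs.s3a.access.key", os.environ["AWS_ACCESS_KEY_ID"])
hadoop_conf.set("fs.s3a.secret.key", os.environ["AWS_SECRET_ACCESS_KEY"])

# After that, spark.read.json(file_paths) / spark.read.csv(file_paths) can read
# the s3a:// paths directly.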
Contributor:

do we need to repartition?

Contributor Author:

I'm not sure. How many partitions does spark.read.json produce for each file? (When I run it locally, each file has three partitions.)
Or we could let the user specify the number of partitions.
Otherwise, if the number of partitions is small, then after reading and converting each partition to a pandas DataFrame, there is no way to split the DataFrames.

Contributor Author:

compare with node number

Contributor Author:

or we need to support splitting DataFrames for repartition
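
One way to make the partition question above concrete (a sketch outside this PR; choose_num_partitions and the num_partitions override are hypothetical names): inspect how many partitions spark.read actually produced with getNumPartitions(), let the caller override the count, and otherwise fall back to the node*core heuristic used in this PR:

# Sketch only: `choose_num_partitions` and the `num_partitions` override are
# hypothetical, not part of this PR.
def choose_num_partitions(df, file_paths, node_num, core_num, num_partitions=None):
    current = df.rdd.getNumPartitions()  # what spark.read decided on its own
    if num_partitions is None:
        # fall back to the heuristic used in this PR: min(num_files, total_cores)
        num_partitions = min(len(file_paths), node_num * core_num)
    return df if current == num_partitions else df.repartition(num_partitions)

With something like df = choose_num_partitions(df, file_paths, node_num, core_num), the repartition only happens when the default partitioning does not already match the target.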

+    elif file_type == "csv":
+        df = spark.read.csv(file_paths).repartition(num_partitions)
     else:
-        def loadFile(iterator):
+        raise Exception("Unsupported file type")
+
+    def to_pandas(columns):
+        def f(iter):
             import pandas as pd
-            for x in iterator:
-                if file_type == "csv":
-                    df = pd.read_csv(x, **kwargs)
-                elif file_type == "json":
-                    df = pd.read_json(x, **kwargs)
-                else:
-                    raise Exception("Unsupported file type")
-                yield df
+            data = list(iter)
+            yield pd.DataFrame(data, columns=columns)
 
-        pd_rdd = rdd.mapPartitions(loadFile)
+        return f
 
+    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
+    pd_rdd.map(lambda df: print(len(df))).count()
     data_shards = SparkXShards(pd_rdd)
     return data_shards