Skip to content

Commit

Permalink
Fix Orca read file path (#2525)
Browse files Browse the repository at this point in the history
* resolve conflict

* update

* fix

* meet review
  • Loading branch information
hkvision authored Jul 16, 2020
1 parent 2737981 commit 35149b1
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 32 deletions.
7 changes: 6 additions & 1 deletion pyzoo/test/zoo/orca/data/test_pandas_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def test_read_local_csv(self):
file_path = os.path.join(self.resource_path, "abc")
with self.assertRaises(Exception) as context:
xshards = zoo.orca.data.pandas.read_csv(file_path)
self.assertTrue('The file path is invalid/empty' in str(context.exception))
self.assertTrue('No such file or directory' in str(context.exception))
file_path = os.path.join(self.resource_path, "image3d")
with self.assertRaises(Exception) as context:
xshards = zoo.orca.data.pandas.read_csv(file_path)
# This error is raised by pandas.errors.ParserError
self.assertTrue('Error tokenizing data' in str(context.exception))

def test_read_local_json(self):
ZooContext.orca_pandas_read_backend = "pandas"
Expand Down
3 changes: 2 additions & 1 deletion pyzoo/test/zoo/orca/data/test_spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def test_read_invalid_path(self):
file_path = os.path.join(self.resource_path, "abc")
with self.assertRaises(Exception) as context:
xshards = zoo.orca.data.pandas.read_csv(file_path)
self.assertTrue('The file path is invalid/empty' in str(context.exception))
# This error is raised by pyspark.sql.utils.AnalysisException
self.assertTrue('Path does not exist' in str(context.exception))

def test_read_json(self):
file_path = os.path.join(self.resource_path, "orca/data/json")
Expand Down
27 changes: 14 additions & 13 deletions pyzoo/zoo/orca/data/pandas/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,21 @@ def read_json(file_path, **kwargs):

def read_file_spark(file_path, file_type, **kwargs):
sc = init_nncontext()
file_url_splits = file_path.split("://")
prefix = file_url_splits[0]
node_num, core_num = get_node_and_core_number()

file_paths = []
if isinstance(file_path, list):
[file_paths.extend(extract_one_path(path, file_type, os.environ)) for path in file_path]
else:
file_paths = extract_one_path(file_path, file_type, os.environ)
if ZooContext.orca_pandas_read_backend == "pandas":
file_url_splits = file_path.split("://")
prefix = file_url_splits[0]

file_paths = []
if isinstance(file_path, list):
[file_paths.extend(extract_one_path(path, os.environ)) for path in file_path]
else:
file_paths = extract_one_path(file_path, os.environ)

if not file_paths:
raise Exception("The file path is invalid/empty or does not include csv/json files")
if not file_paths:
raise Exception("The file path is invalid or empty, please check your data")

if ZooContext.orca_pandas_read_backend == "pandas":
num_files = len(file_paths)
total_cores = node_num * core_num
num_partitions = num_files if num_files < total_cores else total_cores
Expand All @@ -78,7 +79,7 @@ def loadFile(iterator):
yield df

pd_rdd = rdd.mapPartitions(loadFile)
else: # Spark backend
else: # Spark backend; spark.read.csv/json accepts a folder path as input
assert file_type == "json" or file_type == "csv", \
"Unsupported file type: %s. Only csv and json files are supported for now" % file_type
from pyspark.sql import SQLContext
Expand Down Expand Up @@ -140,12 +141,12 @@ def loadFile(iterator):
comment = kwargs["comment"]
if not isinstance(comment, str) or len(comment) != 1:
raise ValueError("Only length-1 comment characters supported")
df = spark.read.csv(file_paths, **kwargs)
df = spark.read.csv(file_path, **kwargs)
if header is None:
df = df.selectExpr(
*["`%s` as `%s`" % (field.name, i) for i, field in enumerate(df.schema)])
else:
df = spark.read.json(file_paths, **kwargs)
df = spark.read.json(file_path, **kwargs)

# Handle pandas-compatible postprocessing arguments
if isinstance(names, list):
Expand Down
28 changes: 11 additions & 17 deletions pyzoo/zoo/orca/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from zoo.common import get_file_list


def list_s3_file(file_path, file_type, env):
def list_s3_file(file_path, env):
path_parts = file_path.split('/')
bucket = path_parts.pop(0)
key = "/".join(path_parts)
Expand All @@ -40,35 +40,29 @@ def list_s3_file(file_path, file_type, env):
Prefix=key)
for obj in resp['Contents']:
keys.append(obj['Key'])
# only get json/csv files
files = [file for file in keys if os.path.splitext(file)[1] == "." + file_type]
file_paths = [os.path.join("s3://" + bucket, file) for file in files]
file_paths = [os.path.join("s3://" + bucket, file) for file in keys]
return file_paths


def extract_one_path(file_path, env):
    """Expand a single path (s3://, hdfs:// or local) into a list of file paths.

    :param file_path: A path string. May point to a single file or to a folder,
        with an optional ``s3://`` or ``hdfs://`` scheme prefix; anything else
        is treated as a local path (possibly relative).
    :param env: A mapping of environment variables (e.g. ``os.environ``); used
        by the s3 backend to pick up credentials.
    :return: A list of individual file paths contained under ``file_path``.
    :raises: For a local folder, ``os.listdir`` raises ``FileNotFoundError`` /
        ``NotADirectoryError`` if the path is invalid, so callers get a clear
        error for bad paths without an explicit check here.
    """
    file_url_splits = file_path.split("://")
    prefix = file_url_splits[0]
    if prefix == "s3":
        file_paths = list_s3_file(file_url_splits[1], env)
    elif prefix == "hdfs":
        import pyarrow as pa
        fs = pa.hdfs.connect()
        if fs.isfile(file_path):
            file_paths = [file_path]
        else:
            file_paths = get_file_list(file_path)
    else:  # Local file path; could be a relative path.
        from os.path import isfile, abspath, join
        if isfile(file_path):
            file_paths = [abspath(file_path)]
        else:
            # An error would already be raised here if the path is invalid.
            file_paths = [abspath(join(file_path, name)) for name in os.listdir(file_path)]
    return file_paths


Expand Down

0 comments on commit 35149b1

Please sign in to comment.