Fix multi hive-partition parquet reading in dask-cudf #9122

Merged · 4 commits · Sep 9, 2021
Changes from 3 commits
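For context, a minimal sketch (not taken from the PR itself) of the round trip this change is meant to fix: a dataset written with more than one hive-partition column and read back through dask_cudf. The path "dataset_root" and the column names are illustrative.

import cudf
import dask_cudf

# Write a small dataset partitioned on two hive columns ("year" and "month"),
# then read it back with dask_cudf. Multi-column hive partitioning is the
# case this PR repairs on the read path.
df = cudf.DataFrame(
    {
        "year": [2019, 2019, 2020, 2020],
        "month": [1, 2, 1, 2],
        "value": [1.0, 2.0, 3.0, 4.0],
    }
)
dask_cudf.from_cudf(df, npartitions=2).to_parquet(
    "dataset_root", partition_on=["year", "month"]
)

ddf = dask_cudf.read_parquet("dataset_root")
print(ddf.compute())  # "year" and "month" come back as categorical columns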
196 changes: 137 additions & 59 deletions python/dask_cudf/dask_cudf/io/parquet.py
@@ -4,7 +4,7 @@
from io import BufferedWriter, BytesIO, IOBase

import numpy as np
from pyarrow import parquet as pq
from pyarrow import dataset as pa_ds, parquet as pq

from dask import dataframe as dd
from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine
@@ -54,9 +54,85 @@ def multi_support(cls):
# and that multi-part reading is supported
return cls == CudfEngine

@staticmethod
@classmethod
def _read_paths(
cls,
paths,
fs,
columns=None,
row_groups=None,
strings_to_categorical=None,
partitions=None,
partitioning=None,
partition_keys=None,
**kwargs,
):

# Use cudf to read in data
df = cudf.read_parquet(
paths,
engine="cudf",
columns=columns,
row_groups=row_groups if row_groups else None,
strings_to_categorical=strings_to_categorical,
**kwargs,
)

if partitions and partition_keys is None:
ds = pa_ds.dataset(
paths,
filesystem=fs,
format="parquet",
partitioning=partitioning["obj"].discover(
*partitioning.get("args", []),
**partitioning.get("kwargs", {}),
),
)
frag = next(ds.get_fragments())
if frag:
# Extract hive-partition keys, and make sure they
# are ordered the same as they are in `partitions`
raw_keys = pa_ds._get_partition_keys(frag.partition_expression)
partition_keys = [
(hive_part.name, raw_keys[hive_part.name])
for hive_part in partitions
]

if partition_keys:
if partitions is None:
raise ValueError("Must pass partition sets")

for i, (name, index2) in enumerate(partition_keys):

# Build the column from `codes` directly
# (since the category is often a larger dtype)
codes = (
as_column(partitions[i].keys.index(index2))
.as_frame()
.repeat(len(df))
._data[None]
)
df[name] = build_categorical_column(
categories=partitions[i].keys,
codes=codes,
size=codes.size,
offset=codes.offset,
ordered=False,
)

return df

@classmethod
def read_partition(
fs, pieces, columns, index, categories=(), partitions=(), **kwargs
cls,
fs,
pieces,
columns,
index,
categories=(),
partitions=(),
partitioning=None,
**kwargs,
):
if columns is not None:
columns = [c for c in columns]
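The discovery step in `_read_paths` above leans on pyarrow.dataset when no partition keys are passed in; a rough standalone sketch of that mechanism follows. The path "dataset_root" is assumed, and `pa_ds._get_partition_keys` is the same private pyarrow helper the engine calls.

import pyarrow.dataset as pa_ds

# Discover hive partitioning for a directory of parquet files and pull the
# partition key/value pairs off the first fragment, roughly as _read_paths
# does when `partition_keys` is not supplied.
ds = pa_ds.dataset(
    "dataset_root",
    format="parquet",
    partitioning=pa_ds.HivePartitioning.discover(),
)
frag = next(ds.get_fragments())
raw_keys = pa_ds._get_partition_keys(frag.partition_expression)
print(raw_keys)  # e.g. {"year": 2020, "month": 1}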
@@ -67,64 +143,88 @@ def read_partition(
pieces = [pieces]

strings_to_cats = kwargs.get("strings_to_categorical", False)

if len(pieces) > 1:

# Multi-piece read
paths = []
rgs = []
partition_keys = []

for piece in pieces:
if isinstance(piece, str):
paths.append(piece)
rgs.append(None)
else:
(path, row_group, partition_keys) = piece

row_group = None if row_group == [None] else row_group

paths.append(path)
rgs.append(
[row_group]
if not isinstance(row_group, list)
else row_group
last_partition_keys = None
dfs = []

for i, piece in enumerate(pieces):

(path, row_group, partition_keys) = piece
row_group = None if row_group == [None] else row_group

if i > 0 and partition_keys != last_partition_keys:
dfs.append(
cls._read_paths(
paths,
fs,
columns=columns,
row_groups=rgs if rgs else None,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
partition_keys=last_partition_keys,
**kwargs.get("read", {}),
)
)
paths = []
rgs = []
last_partition_keys = None
paths.append(path)
rgs.append(
[row_group]
if not isinstance(row_group, list)
else row_group
)
last_partition_keys = partition_keys

df = cudf.read_parquet(
paths,
engine="cudf",
columns=columns,
row_groups=rgs if rgs else None,
strings_to_categorical=strings_to_cats,
**kwargs.get("read", {}),
dfs.append(
cls._read_paths(
paths,
fs,
columns=columns,
row_groups=rgs if rgs else None,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
partition_keys=last_partition_keys,
**kwargs.get("read", {}),
)
)
df = cudf.concat(dfs)

else:

# Single-piece read
if isinstance(pieces[0], str):
path = pieces[0]
row_group = None
partition_keys = []
else:
(path, row_group, partition_keys) = pieces[0]
row_group = None if row_group == [None] else row_group
(path, row_group, partition_keys) = pieces[0]
row_group = None if row_group == [None] else row_group

if cudf.utils.ioutils._is_local_filesystem(fs):
df = cudf.read_parquet(
df = cls._read_paths(
path,
engine="cudf",
fs,
columns=columns,
row_groups=row_group,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
partition_keys=partition_keys,
**kwargs.get("read", {}),
)
else:
with fs.open(path, mode="rb") as f:
df = cudf.read_parquet(
df = cls._read_paths(
f,
engine="cudf",
fs,
columns=columns,
row_groups=row_group,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
partition_keys=partition_keys,
**kwargs.get("read", {}),
)
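As a reading aid for the multi-piece branch above: the loop buffers paths and row groups until the hive keys change, then flushes a batched read. A stripped-down sketch of that grouping pattern, with plain tuples standing in for the real pieces:

# Each piece is (path, row_groups, partition_keys); pieces sharing the same
# partition keys are grouped so they can be read together in one call.
pieces = [
    ("year=2020/month=1/a.parquet", None, [("year", 2020), ("month", 1)]),
    ("year=2020/month=1/b.parquet", None, [("year", 2020), ("month", 1)]),
    ("year=2020/month=2/c.parquet", None, [("year", 2020), ("month", 2)]),
]

groups = []  # batches of (paths, partition_keys) to read together
paths, last_keys = [], None
for i, (path, _, keys) in enumerate(pieces):
    if i > 0 and keys != last_keys:
        groups.append((paths, last_keys))  # flush the finished batch
        paths = []
    paths.append(path)
    last_keys = keys
groups.append((paths, last_keys))  # flush the final batch

# groups == [(["...a.parquet", "...b.parquet"], month=1 keys),
#            (["...c.parquet"], month=2 keys)]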

@@ -138,28 +238,6 @@ def read_partition(
# names in `columns` are actually in `df.columns`
df.reset_index(inplace=True)

if partition_keys:
if partitions is None:
raise ValueError("Must pass partition sets")

for i, (name, index2) in enumerate(partition_keys):

# Build the column from `codes` directly
# (since the category is often a larger dtype)
codes = (
as_column(partitions[i].keys.index(index2))
.as_frame()
.repeat(len(df))
._data[None]
)
df[name] = build_categorical_column(
categories=partitions[i].keys,
codes=codes,
size=codes.size,
offset=codes.offset,
ordered=False,
)

return df

@staticmethod
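The block removed at the end of read_partition (and re-homed in `_read_paths`) builds each partition column as a categorical whose codes repeat one value per file. A rough public-API equivalent in cudf, with made-up category values, just to show the shape of the result:

import cudf

df = cudf.DataFrame({"x": range(4)})   # data read from one file
categories = ["2019", "2020", "2021"]  # every value seen for the "year" key
value = "2020"                         # this file's partition value

# The engine builds the column from codes via build_categorical_column; the
# same result can be expressed with public API as a repeated value cast to a
# categorical dtype that spans all partition values.
df["year"] = cudf.Series([value] * len(df)).astype(
    cudf.CategoricalDtype(categories=categories, ordered=False)
)
print(df["year"].dtype)  # categorical dtype covering the three year values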
10 changes: 10 additions & 0 deletions python/dask_cudf/dask_cudf/io/tests/test_parquet.py
@@ -288,6 +288,15 @@ def test_roundtrip_from_dask_partitioned(tmpdir, parts, daskcudf, metadata):
if not fn.startswith("_"):
assert "part" in fn

if parse_version(dask.__version__) > parse_version("2021.07.0"):
# This version of Dask supports `aggregate_files=True`.
# Check that we can aggregate by a partition name.
df_read = dd.read_parquet(
tmpdir, engine="pyarrow", aggregate_files="year"
)
gdf_read = dask_cudf.read_parquet(tmpdir, aggregate_files="year")
dd.assert_eq(df_read, gdf_read)


@pytest.mark.parametrize("metadata", [True, False])
@pytest.mark.parametrize("chunksize", [None, 1024, 4096, "1MiB"])
@@ -327,6 +336,7 @@ def test_chunksize(tmpdir, chunksize, metadata):
split_row_groups=True,
gather_statistics=True,
)
ddf2.compute(scheduler="synchronous")

dd.assert_eq(ddf1, ddf2, check_divisions=False)

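The new assertions in test_roundtrip_from_dask_partitioned exercise Dask's aggregate_files option; a brief usage sketch of what aggregating on a partition name implies (the layout and path are illustrative):

import dask_cudf

# With files laid out as year=2019/month=1/..., year=2019/month=2/..., and so
# on, aggregate_files="year" allows the planner to combine files that share
# the same "year" value into a single partition, so the partition count is
# bounded by the number of distinct years rather than the number of files.
ddf = dask_cudf.read_parquet("dataset_root", aggregate_files="year")
print(ddf.npartitions)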