[Python] Create PyArrow dataset fragments from delta log (#525)
wjones127 authored Jan 6, 2022
1 parent 8e69e2d commit 3c83a1d
Showing 13 changed files with 211 additions and 50 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
/target
**/target
*.sw*
tlaplus/*.toolbox/*_SnapShot_*/
tlaplus/*.toolbox/*_SnapShot_*.launch
6 changes: 6 additions & 0 deletions python/.cargo/config
@@ -3,3 +3,9 @@ rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]

[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
38 changes: 15 additions & 23 deletions python/deltalake/table.py
@@ -2,11 +2,10 @@
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from urllib.parse import urlparse

import pyarrow
import pyarrow.fs as pa_fs
-from pyarrow.dataset import dataset, partitioning
+from pyarrow.dataset import FileSystemDataset, ParquetFileFormat

if TYPE_CHECKING:
    import pandas
@@ -260,31 +259,24 @@ def to_pyarrow_dataset(
        :param filesystem: A concrete implementation of the Pyarrow FileSystem or a fsspec-compatible interface. If None, the first file path will be used to determine the right FileSystem
        :return: the PyArrow dataset
        """
-        if not partitions:
-            file_paths = self._table.file_uris()
-        else:
-            file_paths = self._table.files_by_partitions(partitions)
-
-        empty_delta_table = len(file_paths) == 0
-        if empty_delta_table:
-            return dataset(
-                [],
-                schema=self.pyarrow_schema(),
-                partitioning=partitioning(flavor="hive"),
-            )
-
-        parsed = urlparse(file_paths[0])
-        if not filesystem and parsed.netloc:
+        if not filesystem:
            filesystem = pa_fs.PyFileSystem(
                DeltaStorageHandler(self._table.table_uri())
            )

-        return dataset(
-            file_paths,
-            schema=self.pyarrow_schema(),
-            format="parquet",
-            filesystem=filesystem,
-            partitioning=partitioning(flavor="hive"),
+        format = ParquetFileFormat()
+
+        fragments = [
+            format.make_fragment(
+                file,
+                filesystem=filesystem,
+                partition_expression=part_expression,
+            )
+            for file, part_expression in self._table.dataset_partitions(partitions)
+        ]
+
+        return FileSystemDataset(
+            fragments, self.pyarrow_schema(), format, filesystem
        )

    def to_pyarrow_table(
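With this change, to_pyarrow_dataset() no longer builds a plain hive-partitioned dataset; it constructs one fragment per data file and attaches a guarantee expression derived from the delta log, so PyArrow can skip fragments before reading them. A minimal usage sketch (the path and column names follow the partitioned test table exercised in the tests below; adjust them for your own table):

    import pyarrow.dataset as ds
    from deltalake import DeltaTable

    dt = DeltaTable("../rust/tests/data/delta-0.8.0-partitioned")
    dataset = dt.to_pyarrow_dataset()

    # Fragments whose guarantees contradict the filter are pruned without being read.
    condition = (ds.field("year") == "2021") & (ds.field("month") == "12")
    table = dataset.to_table(filter=condition)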
120 changes: 115 additions & 5 deletions python/src/lib.rs
@@ -3,6 +3,8 @@
extern crate pyo3;

use chrono::{DateTime, FixedOffset, Utc};
use deltalake::action::Stats;
use deltalake::action::{ColumnCountStat, ColumnValueStat};
use deltalake::arrow::datatypes::Schema as ArrowSchema;
use deltalake::partitions::PartitionFilter;
use deltalake::storage;
@@ -12,6 +14,7 @@ use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyTuple, PyType};
use std::collections::HashMap;
use std::collections::HashSet;
use std::convert::TryFrom;

create_exception!(deltalake, PyDeltaTableError, PyException);
@@ -38,10 +41,7 @@ impl PyDeltaTableError {
    }

    fn from_chrono(err: chrono::ParseError) -> pyo3::PyErr {
-        PyDeltaTableError::new_err(format!(
-            "Parse date and time string failed: {}",
-            err.to_string()
-        ))
+        PyDeltaTableError::new_err(format!("Parse date and time string failed: {}", err))
    }

}
@@ -183,7 +183,7 @@ impl RawDeltaTable {
    }

    pub fn file_uris(&self) -> PyResult<Vec<String>> {
-        Ok(self._table.get_file_uris())
+        Ok(self._table.get_file_uris().collect())
    }

    pub fn schema_json(&self) -> PyResult<String> {
@@ -231,6 +231,116 @@ impl RawDeltaTable {
            .block_on(self._table.update_incremental())
            .map_err(PyDeltaTableError::from_raw)
    }

    pub fn dataset_partitions<'py>(
        &mut self,
        py: Python<'py>,
        partition_filters: Option<Vec<(&str, &str, PartitionFilterValue)>>,
    ) -> PyResult<Vec<(String, Option<&'py PyAny>)>> {
        let path_set = match partition_filters {
            Some(filters) => Some(HashSet::<_>::from_iter(
                self.files_by_partitions(filters)?.iter().cloned(),
            )),
            None => None,
        };

        self._table
            .get_file_uris()
            .zip(self._table.get_partition_values())
            .zip(self._table.get_stats())
            .filter(|((path, _), _)| match &path_set {
                Some(path_set) => path_set.contains(path),
                None => true,
            })
            .map(|((path, partition_values), stats)| {
                let stats = stats.map_err(PyDeltaTableError::from_raw)?;
                let expression = filestats_to_expression(py, partition_values, stats)?;
                Ok((path, expression))
            })
            .collect()
    }
}

fn json_value_to_py(value: &serde_json::Value, py: Python) -> PyObject {
    match value {
        serde_json::Value::Null => py.None(),
        serde_json::Value::Bool(val) => val.to_object(py),
        serde_json::Value::Number(val) => {
            if val.is_f64() {
                val.as_f64().expect("not an f64").to_object(py)
            } else if val.is_i64() {
                val.as_i64().expect("not an i64").to_object(py)
            } else {
                val.as_u64().expect("not an u64").to_object(py)
            }
        }
        serde_json::Value::String(val) => val.to_object(py),
        _ => py.None(),
    }
}

/// Create an expression that the file statistics guarantee to be true.
///
/// PyArrow uses this expression to determine which Dataset fragments may be
/// skipped during a scan.
fn filestats_to_expression<'py>(
    py: Python<'py>,
    partitions_values: &HashMap<String, Option<String>>,
    stats: Option<Stats>,
) -> PyResult<Option<&'py PyAny>> {
    let ds = PyModule::import(py, "pyarrow.dataset")?;
    let field = ds.getattr("field")?;
    let mut expressions: Vec<PyResult<&PyAny>> = Vec::new();

    for (column, value) in partitions_values.iter() {
        if let Some(value) = value {
            expressions.push(
                field
                    .call1((column,))?
                    .call_method1("__eq__", (value.to_object(py),)),
            );
        }
    }

    if let Some(stats) = stats {
        for (column, minimum) in stats.min_values.iter().filter_map(|(k, v)| match v {
            ColumnValueStat::Value(val) => Some((k.clone(), json_value_to_py(val, py))),
            // TODO(wjones127): Handle nested field statistics.
            // Blocked on https://issues.apache.org/jira/browse/ARROW-11259
            _ => None,
        }) {
            expressions.push(field.call1((column,))?.call_method1("__ge__", (minimum,)));
        }

        for (column, maximum) in stats.max_values.iter().filter_map(|(k, v)| match v {
            ColumnValueStat::Value(val) => Some((k.clone(), json_value_to_py(val, py))),
            _ => None,
        }) {
            expressions.push(field.call1((column,))?.call_method1("__le__", (maximum,)));
        }

        for (column, null_count) in stats.null_count.iter().filter_map(|(k, v)| match v {
            ColumnCountStat::Value(val) => Some((k, val)),
            _ => None,
        }) {
            if *null_count == stats.num_records {
                expressions.push(field.call1((column.clone(),))?.call_method0("is_null"));
            }

            if *null_count == 0 {
                expressions.push(field.call1((column.clone(),))?.call_method0("is_valid"));
            }
        }
    }

    if expressions.is_empty() {
        Ok(None)
    } else {
        expressions
            .into_iter()
            .reduce(|accum, item| accum?.getattr("__and__")?.call1((item?,)))
            .transpose()
    }
}

#[pyclass]
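Each element returned by dataset_partitions is a (file URI, expression) pair, and filestats_to_expression assembles the expression from pyarrow.dataset.field comparisons: partition values become equality tests, per-file min/max statistics become >= / <= bounds, and null counts become is_null / is_valid guarantees. For a hypothetical file with partition year=2021 whose value column has min 5, max 9, and no nulls, the guarantee is roughly equivalent to:

    import pyarrow.dataset as ds

    guarantee = (
        (ds.field("year") == "2021")    # partition value
        & (ds.field("value") >= 5)      # per-file minimum from stats
        & (ds.field("value") <= 9)      # per-file maximum from stats
        & ds.field("value").is_valid()  # null_count == 0
    )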
2 changes: 2 additions & 0 deletions python/stubs/pyarrow/dataset.pyi
@@ -3,3 +3,5 @@ from typing import Any
Dataset: Any
dataset: Any
partitioning: Any
FileSystemDataset: Any
ParquetFileFormat: Any
57 changes: 53 additions & 4 deletions python/tests/test_table_read.py
@@ -2,6 +2,7 @@
from threading import Barrier, Thread

import pandas as pd
import pyarrow.dataset as ds
import pytest
from pyarrow.fs import LocalFileSystem

@@ -17,7 +18,8 @@ def test_read_simple_table_by_version_to_dict():
def test_read_simple_table_by_version_to_dict():
    table_path = "../rust/tests/data/delta-0.2.0"
    dt = DeltaTable(table_path, version=2)
-    assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"value": [1, 2, 3]}
+    assert dt.to_pyarrow_dataset().to_table().to_pydict() == {
+        "value": [1, 2, 3]}


def test_load_with_datetime():
@@ -69,7 +71,8 @@ def test_load_with_datetime_bad_format():
def test_read_simple_table_update_incremental():
    table_path = "../rust/tests/data/simple_table"
    dt = DeltaTable(table_path, version=0)
-    assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [0, 1, 2, 3, 4]}
+    assert dt.to_pyarrow_dataset().to_table().to_pydict() == {
+        "id": [0, 1, 2, 3, 4]}
    dt.update_incremental()
    assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}
@@ -121,6 +124,50 @@ def test_read_table_with_column_subset():
)


def test_read_table_with_filter():
    table_path = "../rust/tests/data/delta-0.8.0-partitioned"
    dt = DeltaTable(table_path)
    expected = {
        "value": ["6", "7", "5"],
        "year": ["2021", "2021", "2021"],
        "month": ["12", "12", "12"],
        "day": ["20", "20", "4"],
    }
    filter_expr = (ds.field("year") == "2021") & (ds.field("month") == "12")

    dataset = dt.to_pyarrow_dataset()

    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2
    assert dataset.to_table(filter=filter_expr).to_pydict() == expected


def test_read_table_with_stats():
    table_path = "../rust/tests/data/COVID-19_NYT"
    dt = DeltaTable(table_path)
    dataset = dt.to_pyarrow_dataset()

    filter_expr = ds.field("date") > "2021-02-20"
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 2

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows < 147181 + 47559

    filter_expr = ds.field("cases") < 0
    assert len(list(dataset.get_fragments(filter=filter_expr))) == 0

    data = dataset.to_table(filter=filter_expr)
    assert data.num_rows == 0

    # TODO(wjones127): Enable these tests once C++ Arrow implements is_null and is_valid
    # simplification. Blocked on: https://issues.apache.org/jira/browse/ARROW-12659

    # filter_expr = ds.field("cases").is_null()
    # assert len(list(dataset.get_fragments(filter=filter_expr))) == 0

    # data = dataset.to_table(filter=filter_expr)
    # assert data.num_rows == 0


def test_vacuum_dry_run_simple_table():
    table_path = "../rust/tests/data/delta-0.2.0"
    dt = DeltaTable(table_path)
@@ -258,7 +305,8 @@ def test_delta_table_with_filesystem():
    table_path = "../rust/tests/data/simple_table"
    dt = DeltaTable(table_path)
    filesystem = LocalFileSystem()
-    assert dt.to_pandas(filesystem=filesystem).equals(pd.DataFrame({"id": [5, 7, 9]}))
+    assert dt.to_pandas(filesystem=filesystem).equals(
+        pd.DataFrame({"id": [5, 7, 9]}))


def test_import_delta_table_error():
Expand Down Expand Up @@ -347,7 +395,8 @@ def read_table():
"part-00000-2befed33-c358-4768-a43c-3eda0d2a499d-c000.snappy.parquet",
]

-    threads = [ExcPassThroughThread(target=read_table) for _ in range(thread_count)]
+    threads = [ExcPassThroughThread(target=read_table)
+               for _ in range(thread_count)]
    for t in threads:
        t.start()
    for t in threads:
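The tests above exercise pruning through Dataset.to_table and Dataset.get_fragments; the per-file guarantees can also be inspected directly, since make_fragment stores each expression on its fragment. A quick interactive sketch (not part of the test suite):

    from deltalake import DeltaTable

    dataset = DeltaTable("../rust/tests/data/delta-0.8.0-partitioned").to_pyarrow_dataset()
    for fragment in dataset.get_fragments():
        print(fragment.path, fragment.partition_expression)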
4 changes: 1 addition & 3 deletions rust/src/action.rs
@@ -55,9 +55,7 @@ fn decode_path(raw_path: &str) -> Result<String, ActionError> {
    percent_decode(raw_path.as_bytes())
        .decode_utf8()
        .map(|c| c.to_string())
-        .map_err(|e| {
-            ActionError::InvalidField(format!("Decode path failed for action: {}", e.to_string(),))
-        })
+        .map_err(|e| ActionError::InvalidField(format!("Decode path failed for action: {}", e,)))
}

/// Struct used to represent minValues and maxValues in add action statistics.
2 changes: 1 addition & 1 deletion rust/src/bin/delta-inspect.rs
@@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> {
    };

    if files_matches.is_present("full_uri") {
-        table.get_file_uris().iter().for_each(|f| println!("{}", f));
+        table.get_file_uris().for_each(|f| println!("{}", f));
    } else {
        table.get_files_iter().for_each(|f| println!("{}", f));
    };
15 changes: 10 additions & 5 deletions rust/src/delta.rs
@@ -1009,25 +1009,30 @@ impl DeltaTable {
        note = "Please use the get_file_uris function instead"
    )]
    pub fn get_file_paths(&self) -> Vec<String> {
-        self.get_file_uris()
+        self.get_file_uris().collect()
    }

    /// Returns URIs for all active files present in the current table version.
-    pub fn get_file_uris(&self) -> Vec<String> {
+    pub fn get_file_uris(&self) -> impl Iterator<Item = String> + '_ {
        self.state
            .files()
            .iter()
            .map(|add| self.storage.join_path(&self.table_uri, &add.path))
-            .collect()
    }

    /// Returns statistics for files, in order
-    pub fn get_stats(&self) -> Vec<Result<Option<Stats>, DeltaTableError>> {
+    pub fn get_stats(&self) -> impl Iterator<Item = Result<Option<Stats>, DeltaTableError>> + '_ {
        self.state
            .files()
            .iter()
            .map(|add| add.get_stats().map_err(DeltaTableError::from))
-            .collect()
    }
+
+    /// Returns partition values for files, in order
+    pub fn get_partition_values(
+        &self,
+    ) -> impl Iterator<Item = &HashMap<String, Option<String>>> + '_ {
+        self.state.files().iter().map(|add| &add.partition_values)
+    }

    /// Returns the currently loaded state snapshot.