Skip to content

Commit

Permalink
Python: Infer Iceberg schema from the Parquet file (apache#6997)
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasJ-ap authored and manisin committed May 9, 2023
1 parent d96a368 commit 0353ed5
Showing 4 changed files with 584 additions and 34 deletions.
219 changes: 207 additions & 12 deletions python/pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
"""FileIO implementation for reading and writing table files that uses pyarrow.fs
This file contains a FileIO implementation that relies on the filesystem interface provided
@@ -26,19 +26,23 @@

import multiprocessing
import os
from functools import lru_cache
from abc import ABC, abstractmethod
from functools import lru_cache, singledispatch
from multiprocessing.pool import ThreadPool
from multiprocessing.sharedctypes import Synchronized
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse

@@ -122,6 +126,12 @@
ONE_MEGABYTE = 1024 * 1024
BUFFER_SIZE = "buffer-size"
# Key looked up in the Parquet file's schema metadata to find an embedded Iceberg schema.
ICEBERG_SCHEMA = b"iceberg.schema"
# Metadata keys written onto pyarrow fields when converting an Iceberg schema to pyarrow.
FIELD_ID = "field_id"
DOC = "doc"
# Metadata keys probed, in this order, when reading a field ID / doc back from a pyarrow field.
PYARROW_FIELD_ID_KEYS = [b"PARQUET:field_id", b"field_id"]
PYARROW_FIELD_DOC_KEYS = [b"PARQUET:field_doc", b"field_doc", b"doc"]

# Result type produced by a PyArrowSchemaVisitor.
T = TypeVar("T")


class PyArrowFile(InputFile, OutputFile):
@@ -357,14 +367,17 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
name=field.name,
type=field_result,
nullable=field.optional,
metadata={"doc": field.doc, "id": str(field.field_id)} if field.doc else {},
metadata={DOC: field.doc, FIELD_ID: str(field.field_id)} if field.doc else {FIELD_ID: str(field.field_id)},
)

def list(self, _: ListType, element_result: pa.DataType) -> pa.DataType:
    """Build the pyarrow list type whose values use the already-converted element type."""
    arrow_list_type = pa.list_(value_type=element_result)
    return arrow_list_type
def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
    """Build the pyarrow list type for an Iceberg list.

    The element is wrapped with ``self.field`` first, so the list's value
    field is a full pyarrow field rather than a bare data type.
    """
    return pa.list_(value_type=self.field(list_type.element_field, element_result))

def map(self, _: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
    """Build the pyarrow map type from the already-converted key and value types."""
    arrow_map_type = pa.map_(key_type=key_result, item_type=value_result)
    return arrow_map_type
def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
    """Build the pyarrow map type for an Iceberg map.

    Key and value are each wrapped with ``self.field`` (key first), so the
    resulting map is declared with full pyarrow fields rather than bare types.
    """
    return pa.map_(
        key_type=self.field(map_type.key_field, key_result),
        item_type=self.field(map_type.value_field, value_result),
    )

def visit_fixed(self, fixed_type: FixedType) -> pa.DataType:
    """Map an Iceberg fixed type to a pyarrow binary type of the same length."""
    byte_length = len(fixed_type)
    return pa.binary(byte_length)
@@ -485,6 +498,190 @@ def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
return boolean_expression_visit(expr, _ConvertToArrowExpression())


def pyarrow_to_schema(schema: pa.Schema) -> Schema:
    """Convert a pyarrow schema into an Iceberg schema.

    Args:
        schema (pa.Schema): The pyarrow schema to convert.

    Returns:
        Schema: The result of visiting ``schema`` with ``_ConvertToIceberg``.
    """
    return visit_pyarrow(schema, _ConvertToIceberg())


@singledispatch
def visit_pyarrow(obj: pa.DataType | pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T:
    """Apply a pyarrow schema visitor to any point within a pyarrow schema.

    The schema is traversed in post-order fashion. This is the fallback
    implementation, reached only when no handler is registered for the type
    of ``obj``.

    Args:
        obj (pa.DataType | pa.Schema): The pyarrow schema or data type to visit.
        visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the
            generic PyArrowSchemaVisitor base class.

    Raises:
        NotImplementedError: If attempting to visit an unrecognized object type.
    """
    # An f-string is used instead of %-formatting so that a tuple argument is
    # reported as-is instead of being unpacked as format arguments (which would
    # raise TypeError and hide the real problem).
    raise NotImplementedError(f"Cannot visit non-type: {obj}")


@visit_pyarrow.register(pa.Schema)
def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
    """Visit every top-level field of a pyarrow schema, then the schema itself."""

    def visit_field(field: pa.Field) -> Optional[T]:
        # Bracket the visit of the field's type with the visitor's field hooks.
        visitor.before_field(field)
        field_result = visit_pyarrow(field.type, visitor)
        visitor.after_field(field)
        return field_result

    return visitor.schema(obj, [visit_field(field) for field in obj])


@visit_pyarrow.register(pa.StructType)
def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
    """Visit every child field of a pyarrow struct, then the struct itself."""

    def visit_field(field: pa.Field) -> Optional[T]:
        # Bracket the visit of the field's type with the visitor's field hooks.
        visitor.before_field(field)
        field_result = visit_pyarrow(field.type, visitor)
        visitor.after_field(field)
        return field_result

    return visitor.struct(obj, [visit_field(field) for field in obj])


@visit_pyarrow.register(pa.ListType)
def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
    """Visit the element of a pyarrow list, then the list itself."""
    element_field = obj.value_field

    visitor.before_field(element_field)
    element_result = visit_pyarrow(element_field.type, visitor)
    visitor.after_field(element_field)

    return visitor.list(obj, element_result)


@visit_pyarrow.register(pa.MapType)
def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
    """Visit the key, then the value of a pyarrow map, then the map itself."""
    map_key_field = obj.key_field
    visitor.before_field(map_key_field)
    key_result = visit_pyarrow(map_key_field.type, visitor)
    visitor.after_field(map_key_field)

    map_item_field = obj.item_field
    visitor.before_field(map_item_field)
    value_result = visit_pyarrow(map_item_field.type, visitor)
    visitor.after_field(map_item_field)

    return visitor.map(obj, key_result, value_result)


@visit_pyarrow.register(pa.DataType)
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
    """Visit a primitive pyarrow type.

    Raises:
        TypeError: If ``obj`` is a nested type that reached this handler.
    """
    if not pa.types.is_nested(obj):
        return visitor.primitive(obj)
    raise TypeError(f"Expected primitive type, got: {type(obj)}")


class PyArrowSchemaVisitor(Generic[T], ABC):
    """Base class for visitors that are applied to a pyarrow schema via ``visit_pyarrow``.

    Subclasses implement one method per kind of node. Each method receives the
    pyarrow object together with the results already produced for its children.
    """

    def before_field(self, field: pa.Field) -> None:
        """Hook called right before a field's type is visited; does nothing by default."""

    def after_field(self, field: pa.Field) -> None:
        """Hook called right after a field's type has been visited; does nothing by default."""

    @abstractmethod
    def schema(self, schema: pa.Schema, field_results: List[Optional[T]]) -> Optional[T]:
        """Produce the result for a schema from the results of its fields."""

    @abstractmethod
    def struct(self, struct: pa.StructType, field_results: List[Optional[T]]) -> Optional[T]:
        """Produce the result for a struct from the results of its fields."""

    @abstractmethod
    def list(self, list_type: pa.ListType, element_result: Optional[T]) -> Optional[T]:
        """Produce the result for a list from the result of its element."""

    @abstractmethod
    def map(self, map_type: pa.MapType, key_result: Optional[T], value_result: Optional[T]) -> Optional[T]:
        """Produce the result for a map from the results of its key and value."""

    @abstractmethod
    def primitive(self, primitive: pa.DataType) -> Optional[T]:
        """Produce the result for a primitive (non-nested) type."""


def _get_field_id(field: pa.Field) -> Optional[int]:
    """Read the Iceberg field ID stored in a pyarrow field's metadata.

    The keys in ``PYARROW_FIELD_ID_KEYS`` are tried in order and the first one
    holding a non-empty value wins.

    Args:
        field (pa.Field): The pyarrow field to inspect.

    Returns:
        Optional[int]: The field ID, or None when the field has no metadata or
        none of the known keys is present.
    """
    metadata = field.metadata
    if not metadata:
        # pyarrow reports None (not an empty mapping) for a field without
        # metadata, so calling .get on it unguarded would raise AttributeError.
        return None
    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
        if field_id_str := metadata.get(pyarrow_field_id_key):
            return int(field_id_str.decode())
    return None


def _get_field_doc(field: pa.Field) -> Optional[str]:
    """Read the documentation string stored in a pyarrow field's metadata.

    The keys in ``PYARROW_FIELD_DOC_KEYS`` are tried in order and the first one
    holding a non-empty value wins.

    Args:
        field (pa.Field): The pyarrow field to inspect.

    Returns:
        Optional[str]: The decoded doc string, or None when the field has no
        metadata or none of the known keys is present.
    """
    metadata = field.metadata
    if not metadata:
        # pyarrow reports None (not an empty mapping) for a field without
        # metadata, so calling .get on it unguarded would raise AttributeError.
        return None
    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
        if doc_str := metadata.get(pyarrow_doc_key):
            return doc_str.decode()
    return None


class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
    """Visitor that converts a pyarrow schema into the corresponding Iceberg schema.

    Field IDs and docs are read from the pyarrow field metadata through
    ``_get_field_id`` and ``_get_field_doc``. A field, list element, or map
    key/value without a field ID cannot be converted: such fields are left out
    of their parent, and such lists and maps convert to None.
    """

    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
        """Pair each pyarrow field with its converted type and build the Iceberg fields.

        Fields whose type could not be converted, or that carry no field ID,
        are skipped.
        """
        nested_fields = []
        for position, arrow_field in enumerate(arrow_fields):
            iceberg_id = _get_field_id(arrow_field)
            iceberg_doc = _get_field_doc(arrow_field)
            iceberg_type = field_results[position]
            if iceberg_type is None or iceberg_id is None:
                continue
            nested_fields.append(
                NestedField(iceberg_id, arrow_field.name, iceberg_type, required=not arrow_field.nullable, doc=iceberg_doc)
            )
        return nested_fields

    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
        """Assemble the Iceberg schema from the converted top-level fields."""
        top_level_fields = self._convert_fields(schema, field_results)
        return Schema(*top_level_fields)

    def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
        """Assemble an Iceberg struct from the converted child fields."""
        child_fields = self._convert_fields(struct, field_results)
        return StructType(*child_fields)

    def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
        """Convert a pyarrow list, or return None if its element is unconvertible or has no ID."""
        value_field = list_type.value_field
        value_id = _get_field_id(value_field)
        if element_result is None or value_id is None:
            return None
        return ListType(value_id, element_result, element_required=not value_field.nullable)

    def map(
        self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
    ) -> Optional[IcebergType]:
        """Convert a pyarrow map, or return None if key/value is unconvertible or has no ID."""
        arrow_key = map_type.key_field
        key_id = _get_field_id(arrow_key)
        arrow_item = map_type.item_field
        item_id = _get_field_id(arrow_item)
        if key_result is None or value_result is None or key_id is None or item_id is None:
            return None
        return MapType(key_id, key_result, item_id, value_result, value_required=not arrow_item.nullable)

    def primitive(self, primitive: pa.DataType) -> IcebergType:
        """Map a primitive pyarrow type onto the matching Iceberg primitive type.

        Raises:
            TypeError: If there is no matching Iceberg type. This includes
                timestamps that are not microsecond precision, or whose
                timezone is neither unset nor "UTC" / "+00:00".
        """
        if pa.types.is_boolean(primitive):
            return BooleanType()
        if pa.types.is_int32(primitive):
            return IntegerType()
        if pa.types.is_int64(primitive):
            return LongType()
        if pa.types.is_float32(primitive):
            return FloatType()
        if pa.types.is_float64(primitive):
            return DoubleType()
        if isinstance(primitive, pa.Decimal128Type):
            decimal_type = cast(pa.Decimal128Type, primitive)
            return DecimalType(decimal_type.precision, decimal_type.scale)
        if pa.types.is_string(primitive):
            return StringType()
        if pa.types.is_date32(primitive):
            return DateType()
        if isinstance(primitive, pa.Time64Type) and primitive.unit == "us":
            return TimeType()
        if pa.types.is_timestamp(primitive):
            timestamp_type = cast(pa.TimestampType, primitive)
            if timestamp_type.unit == "us":
                if timestamp_type.tz == "UTC" or timestamp_type.tz == "+00:00":
                    return TimestamptzType()
                if timestamp_type.tz is None:
                    return TimestampType()
            # Any other timestamp falls through to the error below.
        if pa.types.is_binary(primitive):
            return BinaryType()
        if pa.types.is_fixed_size_binary(primitive):
            fixed_type = cast(pa.FixedSizeBinaryType, primitive)
            return FixedType(fixed_type.byte_width)

        raise TypeError(f"Unsupported type: {primitive}")


def _file_to_table(
fs: FileSystem,
task: FileScanTask,
@@ -506,11 +703,9 @@ def _file_to_table(
schema_raw = None
if metadata := physical_schema.metadata:
schema_raw = metadata.get(ICEBERG_SCHEMA)
if schema_raw is None:
raise ValueError(
"Iceberg schema is not embedded into the Parquet file, see https://github.com/apache/iceberg/issues/6505"
)
file_schema = Schema.parse_raw(schema_raw)
# TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
# see https://github.com/apache/iceberg/issues/7451
file_schema = Schema.parse_raw(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
Loading

0 comments on commit 0353ed5

Please sign in to comment.