
Commit

Cast 's', 'ms' and 'ns' PyArrow timestamp to 'us' precision on write (a…
sungwy authored Jul 10, 2024
1 parent 3f574d3 commit 301e336
Showing 7 changed files with 224 additions and 45 deletions.
6 changes: 5 additions & 1 deletion mkdocs/docs/configuration.md
@@ -305,4 +305,8 @@ PyIceberg uses multiple threads to parallelize operations. The number of workers

# Backward Compatibility

Previous Java implementations (`<1.4.0`) incorrectly assume the optional attribute `current-snapshot-id` to be a required attribute in TableMetadata. This means that if `current-snapshot-id` is missing in the metadata file (e.g. on table creation), the application will throw an exception and be unable to load the table. This assumption has been corrected in more recent Iceberg versions. However, it is possible to force PyIceberg to create a table with a metadata file that is compatible with previous versions. This can be configured by setting the `legacy-current-snapshot-id` entry as "True" in the configuration file, or by setting the `PYICEBERG_LEGACY_CURRENT_SNAPSHOT_ID` environment variable. Refer to the [PR discussion](https://github.com/apache/iceberg-python/pull/473) for more details on the issue.
Previous Java implementations (`<1.4.0`) incorrectly assume the optional attribute `current-snapshot-id` to be a required attribute in TableMetadata. This means that if `current-snapshot-id` is missing in the metadata file (e.g. on table creation), the application will throw an exception and be unable to load the table. This assumption has been corrected in more recent Iceberg versions. However, it is possible to force PyIceberg to create a table with a metadata file that is compatible with previous versions. This can be configured by setting the `legacy-current-snapshot-id` property as "True" in the configuration file, or by setting the `PYICEBERG_LEGACY_CURRENT_SNAPSHOT_ID` environment variable. Refer to the [PR discussion](https://github.com/apache/iceberg-python/pull/473) for more details on the issue.

# Nanoseconds Support

PyIceberg currently supports only up to microsecond precision in its TimestampType. PyArrow timestamps in 's' and 'ms' precision will be upcast automatically to 'us' precision on write. Timestamps in 'ns' precision can also be downcast automatically on write if desired. This can be configured by setting the `downcast-ns-timestamp-to-us-on-write` property as "True" in the configuration file, or by setting the `PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE` environment variable. Refer to the [nanoseconds timestamp proposal document](https://docs.google.com/document/d/1bE1DcEGNzZAMiVJSZ0X1wElKLNkT9kRkk0hDlfkXzvU/edit#heading=h.ibflcctc9i1d) for more details on the long-term roadmap for nanoseconds support.
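
For illustration, a minimal sketch of the write-time behaviour described above. It assumes a configured catalog named "default" and an existing table `default.events` with a single timestamptz column; neither is part of this change.

```python
import os

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Opt in to the downcast described above; the same flag can be set as
# `downcast-ns-timestamp-to-us-on-write: True` in the configuration file.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "True"

# A PyArrow table with a nanosecond-precision timestamp column.
arrow_table = pa.Table.from_pylist(
    [{"ts": 1615967687249846175}],  # 2021-03-17 07:54:47.249846175 UTC
    schema=pa.schema([("ts", pa.timestamp("ns", tz="UTC"))]),
)

catalog = load_catalog("default")           # hypothetical catalog name
tbl = catalog.load_table("default.events")  # hypothetical existing table
tbl.append(arrow_table)                     # values are written with 'us' precision
```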
6 changes: 5 additions & 1 deletion pyiceberg/catalog/__init__.py
@@ -49,6 +49,7 @@
from pyiceberg.schema import Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import (
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE,
CommitTableRequest,
CommitTableResponse,
CreateTableTransaction,
@@ -675,8 +676,11 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
if isinstance(schema, pa.Schema):
schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) # type: ignore
schema: Schema = visit_pyarrow( # type: ignore
schema, _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
)
return schema
except ModuleNotFoundError:
pass
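
The change to `_convert_schema_if_needed` above means the same flag also applies when a table is created directly from a PyArrow schema. A hedged sketch of that path; the catalog name and table identifier are placeholders, not part of this change.

```python
import os

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Without this flag, converting a 'ns' timestamp field raises a TypeError.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "True"

catalog = load_catalog("default")  # hypothetical catalog name
tbl = catalog.create_table(
    "default.events_ns",  # hypothetical identifier
    schema=pa.schema([("ts", pa.timestamp("ns", tz="UTC"))]),
)
print(tbl.schema())  # the 'ts' field is stored as timestamptz ('us' precision)
```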
83 changes: 62 additions & 21 deletions pyiceberg/io/pyarrow.py
@@ -154,6 +154,7 @@
UUIDType,
)
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.config import Config
from pyiceberg.utils.datetime import millis_to_datetime
from pyiceberg.utils.singleton import Singleton
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
@@ -470,7 +471,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None:


def schema_to_pyarrow(
schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
schema: Union[Schema, IcebergType],
metadata: Dict[bytes, bytes] = EMPTY_DICT,
include_field_ids: bool = True,
) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))

@@ -663,21 +666,23 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start
return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)


def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
def pyarrow_to_schema(
schema: pa.Schema, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False
) -> Schema:
has_ids = visit_pyarrow(schema, _HasIds())
if has_ids:
visitor = _ConvertToIceberg()
visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
elif name_mapping is not None:
visitor = _ConvertToIceberg(name_mapping=name_mapping)
visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
else:
raise ValueError(
"Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
)
return visit_pyarrow(schema, visitor)


def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))


def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
@@ -849,9 +854,10 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
_field_names: List[str]
_name_mapping: Optional[NameMapping]

def __init__(self, name_mapping: Optional[NameMapping] = None) -> None:
def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
self._field_names = []
self._name_mapping = name_mapping
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _field_id(self, field: pa.Field) -> int:
if self._name_mapping:
@@ -918,11 +924,24 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
return TimeType()
elif pa.types.is_timestamp(primitive):
primitive = cast(pa.TimestampType, primitive)
if primitive.unit == "us":
if primitive.tz == "UTC" or primitive.tz == "+00:00":
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()
if primitive.unit in ("s", "ms", "us"):
# Supported types, will be upcast automatically to 'us'
pass
elif primitive.unit == "ns":
if self._downcast_ns_timestamp_to_us:
logger.warning("Iceberg does not yet support 'ns' timestamp precision. Downcasting to 'us'.")
else:
raise TypeError(
"Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
)
else:
raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

if primitive.tz == "UTC" or primitive.tz == "+00:00":
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()

elif pa.types.is_binary(primitive) or pa.types.is_large_binary(primitive):
return BinaryType()
elif pa.types.is_fixed_size_binary(primitive):
@@ -1010,8 +1029,11 @@ def _task_to_record_batches(
with fs.open_input_file(path) as fin:
fragment = arrow_format.make_fragment(fin)
physical_schema = fragment.physical_schema
file_schema = pyarrow_to_schema(physical_schema, name_mapping)

# In V1 and V2 table formats, we only support Timestamp 'us' in Iceberg Schema
# Hence it is reasonable to always cast 'ns' timestamp to 'us' on read.
# When V3 support is introduced, we will update the `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
@@ -1049,7 +1071,7 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch)
yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
current_index += len(batch)


@@ -1248,8 +1270,12 @@ def project_batches(
total_row_count += len(batch)


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def to_requested_schema(
requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False
) -> pa.RecordBatch:
struct_array = visit_with_partner(
requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema)
)

arrays = []
fields = []
@@ -1263,8 +1289,9 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema

def __init__(self, file_schema: Schema):
def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False):
self.file_schema = file_schema
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
@@ -1275,7 +1302,15 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
# if file_field and field_type (e.g. String) are the same
# but the pyarrow type of the array is different from the expected type
# (e.g. string vs large_string), we want to cast the array to the larger type
return values.cast(target_type)
safe = True
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
safe = False
return values.cast(target_type, safe=safe)
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1899,7 +1934,7 @@ def data_file_statistics_from_parquet_metadata(


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
from pyiceberg.table import PropertyUtil, TableProperties
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties

parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
row_group_size = PropertyUtil.property_as_int(
@@ -1918,8 +1953,14 @@ def write_parquet(task: WriteTask) -> DataFile:
else:
file_schema = table_schema

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
batches = [
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
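
The `safe=False` cast added to `ArrowProjectionVisitor._cast_if_needed` above is what permits truncating sub-microsecond digits; a standalone sketch of the underlying PyArrow behaviour it works around:

```python
import pyarrow as pa

ns_array = pa.array([1615967687249846175], type=pa.timestamp("ns", tz="UTC"))

# A safe cast refuses to drop sub-microsecond digits and raises ArrowInvalid.
try:
    ns_array.cast(pa.timestamp("us", tz="UTC"))
except pa.ArrowInvalid as exc:
    print(exc)  # e.g. a "would lose data" error

# With safe=False the value is truncated to microsecond precision, which is
# the behaviour the projection visitor relies on for 'ns' -> 'us' reads.
print(ns_array.cast(pa.timestamp("us", tz="UTC"), safe=False))
```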
10 changes: 7 additions & 3 deletions pyiceberg/table/__init__.py
@@ -146,6 +146,7 @@
transform_dict_value_to_str,
)
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.config import Config
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.deprecated import deprecated
from pyiceberg.utils.singleton import _convert_to_hashable_type
@@ -161,7 +162,7 @@

ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1

DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
_JAVA_LONG_MAX = 9223372036854775807


@@ -176,11 +177,14 @@ def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") ->
"""
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema)
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
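
A minimal sketch of how the new write-time flag is resolved through `Config`, mirroring the lookup in `_check_schema_compatible` and `write_file`; it assumes the property is not already set in the configuration file.

```python
import os

from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE
from pyiceberg.utils.config import Config

# When unset, get_bool() returns None, so callers fall back to False.
print(Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False)  # False

# The environment variable takes effect the next time Config() is constructed.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "True"
print(Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False)  # True
```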
59 changes: 59 additions & 0 deletions tests/integration/test_add_files.py
@@ -16,13 +16,15 @@
# under the License.
# pylint:disable=redefined-outer-name

import os
from datetime import date
from typing import Iterator, Optional

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from pyspark.sql import SparkSession
from pytest_mock.plugin import MockerFixture

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
@@ -36,6 +38,7 @@
IntegerType,
NestedField,
StringType,
TimestamptzType,
)

TABLE_SCHEMA = Schema(
@@ -448,3 +451,59 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat

assert "snapshot_prop_a" in summary
assert summary["snapshot_prop_a"] == "test_prop_a"


@pytest.mark.integration
def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

nanoseconds_schema = pa.schema([
("quux", pa.timestamp("ns", tz="UTC")),
])

arrow_table = pa.Table.from_pylist(
[
{
"quux": 1615967687249846175, # 2021-03-17 07:54:47.249846159
}
],
schema=nanoseconds_schema,
)
mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})

identifier = f"default.timestamptz_ns_added{format_version}"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
identifier=identifier,
schema=nanoseconds_schema_iceberg,
properties={"format-version": str(format_version)},
partition_spec=PartitionSpec(),
)

file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
writer.write_table(arrow_table)

# add the parquet files as data files
tbl.add_files(file_paths=file_paths)

assert tbl.scan().to_arrow() == pa.concat_tables(
[
arrow_table.cast(
pa.schema([
("quux", pa.timestamp("us", tz="UTC")),
]),
safe=False,
)
]
* 5
)