Skip to content

Commit

Permalink
Python: Delegate JSON serialization to Pydantic (#8286)
Browse files Browse the repository at this point in the history
* Refactor TableMetadataUtil.parse_obj to use the Annotated TableMetadata type

* refactor TableMetadataUtil parse_obj as parse_raw (from dict to string)

* Move the parse_raw and deserialization logic into a separate TableMetadataFactory
  • Loading branch information
aless10 authored Aug 15, 2023
1 parent 06f3929 commit 700127d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 42 deletions.
2 changes: 1 addition & 1 deletion python/pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Endpoints:

# Model with three fields: the metadata location (aliased to "metadata-location"),
# the table metadata itself, and a config mapping that defaults to an empty dict.
# NOTE(review): `metadata` is annotated twice below (with and without `= Field()`).
# This looks like the removed and added sides of the diff rendered together —
# confirm against the actual file which one is current.
class TableResponse(IcebergBaseModel):
metadata_location: str = Field(alias="metadata-location")
metadata: TableMetadata = Field()
metadata: TableMetadata
config: Properties = Field(default_factory=dict)


Expand Down
5 changes: 2 additions & 3 deletions python/pyiceberg/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import codecs
import gzip
import json
from abc import ABC, abstractmethod
from typing import Callable

Expand Down Expand Up @@ -89,9 +88,9 @@ def table_metadata(
with compression.stream_decompressor(byte_stream) as byte_stream:
reader = codecs.getreader(encoding)
json_bytes = reader(byte_stream)
metadata = json.load(json_bytes)
metadata = json_bytes.read()

return TableMetadataUtil.parse_obj(metadata)
return TableMetadataUtil.parse_raw(metadata)


class FromInputFile:
Expand Down
4 changes: 2 additions & 2 deletions python/pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,13 @@ class CommitTableRequest(IcebergBaseModel):


# Model with two fields: the table metadata and its location (aliased to
# "metadata-location").
# NOTE(review): `metadata` is annotated twice below (with and without `= Field()`).
# This looks like the removed and added sides of the diff rendered together —
# confirm against the actual file which one is current.
class CommitTableResponse(IcebergBaseModel):
metadata: TableMetadata = Field()
metadata: TableMetadata
metadata_location: str = Field(alias="metadata-location")


class Table:
identifier: Identifier = Field()
metadata: TableMetadata = Field()
metadata: TableMetadata
metadata_location: str = Field()
io: FileIO
catalog: Catalog
Expand Down
26 changes: 21 additions & 5 deletions python/pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
Union,
)

from pydantic import Field, root_validator
from pydantic import Field
from pydantic import ValidationError as PydanticValidationError
from pydantic import root_validator
from typing_extensions import Annotated

from pyiceberg.exceptions import ValidationError
from pyiceberg.partitioning import PartitionSpec, assign_fresh_partition_spec_ids
Expand Down Expand Up @@ -353,7 +356,16 @@ def check_sort_orders(cls, values: Dict[str, Any]) -> Dict[str, Any]:
increasing long that tracks the order of snapshots in a table."""


TableMetadata = Union[TableMetadataV1, TableMetadataV2]
TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")]


class TableMetadataFactory(IcebergBaseModel):
    """Wrapper model whose single field holds a ``TableMetadata`` value.

    The metadata is nested under one named field so that the model's own
    ``parse_raw`` can be used to validate a raw JSON metadata document.
    """

    table_metadata: TableMetadata

    @classmethod
    def parse_data(cls, data: str) -> "TableMetadataFactory":
        """Build a factory instance from a raw JSON metadata string.

        Args:
            data: The table metadata document as a JSON string.

        Returns:
            An instance whose ``table_metadata`` field holds the parsed value.
        """
        # Nest the document under the "table_metadata" key so that it lines up
        # with this model's only field before handing it to parse_raw.
        return cls.parse_raw(f'{{"table_metadata": {data}}}')


def new_table_metadata(
Expand All @@ -380,14 +392,18 @@ def new_table_metadata(
class TableMetadataUtil:
"""Helper class for parsing TableMetadata."""

# Once this has been resolved, we can simplify this: https://github.com/samuelcolvin/pydantic/issues/3846
# TableMetadata = Annotated[TableMetadata, Field(alias="format-version", discriminator="format-version")]
@staticmethod
def parse_raw(data: str) -> TableMetadata:
    """Parse a JSON string into a TableMetadata instance.

    Args:
        data: The table metadata document as a JSON string.

    Returns:
        The ``table_metadata`` value produced by ``TableMetadataFactory.parse_data``.

    Raises:
        ValidationError: Wraps any pydantic validation error raised while parsing.
    """
    try:
        return TableMetadataFactory.parse_data(data).table_metadata
    except PydanticValidationError as exc:
        # Re-raise as the project's own ValidationError, keeping the cause chained.
        raise ValidationError(exc) from exc

@staticmethod
def parse_obj(data: Dict[str, Any]) -> TableMetadata:
if "format-version" not in data:
raise ValidationError(f"Missing format-version in TableMetadata: {data}")

format_version = data["format-version"]

if format_version == 1:
Expand Down
5 changes: 2 additions & 3 deletions python/tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,redefined-outer-name
import json
import uuid
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -255,9 +254,9 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
)

with open(metadata_location, encoding="utf-8") as f:
payload = json.load(f)
payload = f.read()

metadata = TableMetadataUtil.parse_obj(payload)
metadata = TableMetadataUtil.parse_raw(payload)

assert "database/table" in metadata.location

Expand Down
79 changes: 53 additions & 26 deletions python/tests/table/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pyiceberg.serializers import FromByteStream
from pyiceberg.table import SortOrder
from pyiceberg.table.metadata import (
TableMetadataFactory,
TableMetadataUtil,
TableMetadataV1,
TableMetadataV2,
Expand All @@ -49,41 +50,67 @@
StringType,
StructType,
)
from tests.conftest import EXAMPLE_TABLE_METADATA_V2

# Example format-version 1 table metadata payload used by the tests in this module.
EXAMPLE_TABLE_METADATA_V1 = {
    "format-version": 1,
    "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
    "location": "s3://bucket/test/location",
    "last-updated-ms": 1602638573874,
    "last-column-id": 3,
    # Three required long columns; only "y" carries a doc string.
    "schema": {
        "type": "struct",
        "fields": [
            {"id": 1, "name": "x", "required": True, "type": "long"},
            {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
            {"id": 3, "name": "z", "required": True, "type": "long"},
        ],
    },
    "partition-spec": [
        {"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000},
    ],
    "properties": {},
    "current-snapshot-id": -1,
    "snapshots": [
        {"snapshot-id": 1925, "timestamp-ms": 1602638573822},
    ],
}


# Session-scoped fixture supplying the example format-version 1 table metadata as a dict.
# NOTE(review): as rendered here the body contains two consecutive `return` statements,
# so the final `return EXAMPLE_TABLE_METADATA_V1` would be unreachable. This looks like
# the removed side (inline dict) and the added side (module-level constant) of the diff
# shown together — confirm against the actual file before relying on either.
@pytest.fixture(scope="session")
def example_table_metadata_v1() -> Dict[str, Any]:
return {
"format-version": 1,
"table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
"location": "s3://bucket/test/location",
"last-updated-ms": 1602638573874,
"last-column-id": 3,
"schema": {
"type": "struct",
"fields": [
{"id": 1, "name": "x", "required": True, "type": "long"},
{"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
{"id": 3, "name": "z", "required": True, "type": "long"},
],
},
"partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}],
"properties": {},
"current-snapshot-id": -1,
"snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}],
}
return EXAMPLE_TABLE_METADATA_V1


def test_from_dict_v1(example_table_metadata_v1: Dict[str, Any]) -> None:
    """Test initialization of a TableMetadata instance from a dictionary"""
    table_metadata = TableMetadataUtil.parse_obj(example_table_metadata_v1)

    # Without an assertion this test passes on any non-raising parse; pin the
    # version so a payload resolved to the wrong metadata model is caught.
    assert table_metadata.format_version == 1


def test_from_dict_v1_parse_raw(example_table_metadata_v1: Dict[str, Any]) -> None:
    """Test initialization of a TableMetadata instance from a str"""
    table_metadata = TableMetadataUtil.parse_raw(json.dumps(example_table_metadata_v1))

    # Without an assertion this test passes on any non-raising parse; pin the
    # version so a payload resolved to the wrong metadata model is caught.
    assert table_metadata.format_version == 1


def test_from_dict_v2(example_table_metadata_v2: Dict[str, Any]) -> None:
    """Test initialization of a TableMetadata instance from a dictionary"""
    table_metadata = TableMetadataUtil.parse_obj(example_table_metadata_v2)

    # Without an assertion this test passes on any non-raising parse; pin the
    # version so a payload resolved to the wrong metadata model is caught.
    assert table_metadata.format_version == 2


def test_from_dict_v2_parse_raw(example_table_metadata_v2: Dict[str, Any]) -> None:
    """Test initialization of a TableMetadata instance from a str"""
    table_metadata = TableMetadataUtil.parse_raw(json.dumps(example_table_metadata_v2))

    # Without an assertion this test passes on any non-raising parse; pin the
    # version so a payload resolved to the wrong metadata model is caught.
    assert table_metadata.format_version == 2


@pytest.mark.parametrize(
    "table_metadata, expected_version",
    [
        (EXAMPLE_TABLE_METADATA_V1, 1),
        (EXAMPLE_TABLE_METADATA_V2, 2),
    ],
)
def test_table_metadata_factory(table_metadata: Dict[str, Any], expected_version: int) -> None:
    """Check that TableMetadataFactory yields metadata with the expected format version"""
    parsed = TableMetadataFactory(table_metadata=table_metadata).table_metadata
    assert parsed.format_version == expected_version


def test_from_byte_stream(example_table_metadata_v2: Dict[str, Any]) -> None:
"""Test generating a TableMetadata instance from a file-like byte stream"""
data = bytes(json.dumps(example_table_metadata_v2), encoding="utf-8")
Expand All @@ -93,7 +120,7 @@ def test_from_byte_stream(example_table_metadata_v2: Dict[str, Any]) -> None:

def test_v2_metadata_parsing(example_table_metadata_v2: Dict[str, Any]) -> None:
"""Test retrieving values from a TableMetadata instance of version 2"""
table_metadata = TableMetadataUtil.parse_obj(example_table_metadata_v2)
table_metadata = TableMetadataFactory(table_metadata=example_table_metadata_v2).table_metadata

assert table_metadata.format_version == 2
assert table_metadata.table_uuid == UUID("9c12d441-03fe-4693-9a96-a0705ddf69c1")
Expand Down Expand Up @@ -233,9 +260,9 @@ def test_invalid_format_version() -> None:
}

with pytest.raises(ValidationError) as exc_info:
TableMetadataUtil.parse_obj(table_metadata_invalid_format_version)
TableMetadataUtil.parse_raw(json.dumps(table_metadata_invalid_format_version))

assert "Unknown format version: -1" in str(exc_info.value)
assert "No match for discriminator 'format_version' and value -1 (allowed values: 1, 2)" in str(exc_info.value)


def test_current_schema_not_found() -> None:
Expand Down Expand Up @@ -271,7 +298,7 @@ def test_current_schema_not_found() -> None:
}

with pytest.raises(ValidationError) as exc_info:
TableMetadataUtil.parse_obj(table_metadata_schema_not_found)
TableMetadataUtil.parse_raw(json.dumps(table_metadata_schema_not_found))

assert "current-schema-id 2 can't be found in the schemas" in str(exc_info.value)

Expand Down Expand Up @@ -317,7 +344,7 @@ def test_sort_order_not_found() -> None:
}

with pytest.raises(ValidationError) as exc_info:
TableMetadataUtil.parse_obj(table_metadata_schema_not_found)
TableMetadataUtil.parse_raw(json.dumps(table_metadata_schema_not_found))

assert "default-sort-order-id 4 can't be found" in str(exc_info.value)

Expand Down Expand Up @@ -354,7 +381,7 @@ def test_sort_order_unsorted() -> None:
"snapshots": [],
}

table_metadata = TableMetadataUtil.parse_obj(table_metadata_schema_not_found)
table_metadata = TableMetadataUtil.parse_raw(json.dumps(table_metadata_schema_not_found))

# Most important here is that we correctly handle sort-order-id 0
assert len(table_metadata.sort_orders) == 0
Expand Down Expand Up @@ -389,7 +416,7 @@ def test_invalid_partition_spec() -> None:
"last-partition-id": 1000,
}
with pytest.raises(ValidationError) as exc_info:
TableMetadataUtil.parse_obj(table_metadata_spec_not_found)
TableMetadataUtil.parse_raw(json.dumps(table_metadata_spec_not_found))

assert "default-spec-id 1 can't be found" in str(exc_info.value)

Expand Down
4 changes: 2 additions & 2 deletions python/tests/table/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,eval-used

import json
from typing import Any, Dict

import pytest
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_deserialize_sort_order(sort_order: SortOrder) -> None:


def test_sorting_schema(example_table_metadata_v2: Dict[str, Any]) -> None:
table_metadata = TableMetadataUtil.parse_obj(example_table_metadata_v2)
table_metadata = TableMetadataUtil.parse_raw(json.dumps(example_table_metadata_v2))

assert table_metadata.sort_orders == [
SortOrder(
Expand Down

0 comments on commit 700127d

Please sign in to comment.