Skip to content

Commit

Permalink
Update NameMapping on update_schema() (apache#441)
Browse files Browse the repository at this point in the history
* update name-mapping

* Update __init__.py

Co-authored-by: Fokko Driesprong <[email protected]>

* Update pyiceberg/table/name_mapping.py

Co-authored-by: Fokko Driesprong <[email protected]>

* validation mode after

* type

---------

Co-authored-by: Fokko Driesprong <[email protected]>
  • Loading branch information
sungwy and Fokko authored Feb 20, 2024
1 parent b32d3a5 commit fd9dc88
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 19 deletions.
16 changes: 9 additions & 7 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@
TableMetadata,
TableMetadataUtil,
)
from pyiceberg.table.name_mapping import (
NameMapping,
create_mapping_from_schema,
parse_mapping_from_json,
)
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json, update_mapping
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import (
Operation,
Expand Down Expand Up @@ -994,12 +990,12 @@ def update_snapshot(self) -> UpdateSnapshot:
"""
return UpdateSnapshot(self)

def name_mapping(self) -> NameMapping:
def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
return parse_mapping_from_json(name_mapping_json)
else:
return create_mapping_from_schema(self.schema())
return None

def append(self, df: pa.Table) -> None:
"""
Expand Down Expand Up @@ -1950,6 +1946,12 @@ def commit(self) -> None:
else:
updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore

if name_mapping := self._table.name_mapping():
updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds)
updates += ( # type: ignore
SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}),
)

if self._transaction is not None:
self._transaction._append_updates(*updates) # pylint: disable=W0212
self._transaction._append_requirements(*requirements) # pylint: disable=W0212
Expand Down
97 changes: 88 additions & 9 deletions pyiceberg/table/name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from abc import ABC, abstractmethod
from collections import ChainMap
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, conlist, field_validator, model_serializer

Expand All @@ -45,6 +45,18 @@ class MappedField(IcebergBaseModel):
def convert_null_to_empty_List(cls, v: Any) -> Any:
return v or []

@field_validator('names', mode='after')
@classmethod
def check_at_least_one(cls, v: List[str]) -> Any:
    """
    Reject a mapped field that carries no names.

    The conlist constraint does not seem to be validating the class on instantiation,
    so this custom validator enforces the min_length=1 constraint explicitly.

    Raises:
        ValueError: If `v` is empty.
    """
    if not v:
        raise ValueError("At least one mapped name must be provided for the field")
    return v

@model_serializer
def ser_model(self) -> Dict[str, Any]:
"""Set custom serializer to leave out the field when it is empty."""
Expand Down Expand Up @@ -93,24 +105,25 @@ def __str__(self) -> str:
return "[\n " + "\n ".join([str(e) for e in self.root]) + "\n]"


S = TypeVar('S')
T = TypeVar("T")


class NameMappingVisitor(Generic[T], ABC):
class NameMappingVisitor(Generic[S, T], ABC):
@abstractmethod
def mapping(self, nm: NameMapping, field_results: T) -> T:
def mapping(self, nm: NameMapping, field_results: S) -> S:
"""Visit a NameMapping."""

@abstractmethod
def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
def fields(self, struct: List[MappedField], field_results: List[T]) -> S:
"""Visit a List[MappedField]."""

@abstractmethod
def field(self, field: MappedField, field_result: T) -> T:
def field(self, field: MappedField, field_result: S) -> T:
"""Visit a MappedField."""


class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
class _IndexByName(NameMappingVisitor[Dict[str, MappedField], Dict[str, MappedField]]):
def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
return field_results

Expand All @@ -129,18 +142,18 @@ def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dic


@singledispatch
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[S, T]) -> S:
"""Traverse the name mapping in post-order traversal."""
raise NotImplementedError(f"Cannot visit non-type: {obj}")


@visit_name_mapping.register(NameMapping)
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
def _(obj: NameMapping, visitor: NameMappingVisitor[S, T]) -> S:
return visitor.mapping(obj, visit_name_mapping(obj.root, visitor))


@visit_name_mapping.register(list)
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
def _(fields: List[MappedField], visitor: NameMappingVisitor[S, T]) -> S:
results = [visitor.field(field, visit_name_mapping(field.fields, visitor)) for field in fields]
return visitor.fields(fields, results)

Expand Down Expand Up @@ -175,5 +188,71 @@ def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
return []


class _UpdateMapping(NameMappingVisitor[List[MappedField], MappedField]):
    """Visitor that rebuilds a name mapping to reflect updated and newly added fields.

    For every visited field the updated name (if any) is appended to the list of
    known names. Newly added fields are appended under their parent. A name is
    removed from a sibling field when an update or an add hands that same name to
    a different field id; a sibling left without any name is dropped entirely.

    The visitor builds new ``MappedField`` objects and does not modify the
    mapping it is visiting.
    """

    # Updated fields, looked up by the id of the field being visited.
    _updates: Dict[int, NestedField]
    # Fields to add, looked up by the id of their parent (-1 is used for the top level).
    _adds: Dict[int, List[NestedField]]

    def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]):
        self._updates = updates
        self._adds = adds

    @staticmethod
    def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]:
        """Strip names from `field` that `assignments` gives to another field id.

        Args:
            field: The mapped field to filter.
            assignments: Name to field-id assignments that take precedence.

        Returns:
            A new MappedField with the remaining names, or None when no name is left.
        """
        remaining_names = []
        for name in field.names:
            assigned_id = assignments.get(name)
            # Compare against None explicitly so that a falsy id (0) still counts as an assignment.
            if assigned_id is not None and assigned_id != field.field_id:
                continue
            remaining_names.append(name)

        if not remaining_names:
            return None
        return MappedField(field_id=field.field_id, names=remaining_names, fields=field.fields)

    def _add_new_fields(self, mapped_fields: List[MappedField], parent_id: int) -> List[MappedField]:
        """Append the fields added under `parent_id` to `mapped_fields`.

        Existing siblings lose any name that one of the added fields now uses.
        When nothing is added under `parent_id`, `mapped_fields` is returned as-is.
        """
        fields_to_add = self._adds.get(parent_id)
        if not fields_to_add:
            return mapped_fields

        new_fields = [
            MappedField(field_id=add.field_id, names=[add.name], fields=visit(add.field_type, _CreateMapping()))
            for add in fields_to_add
        ]
        reassignments = {add.name: add.field_id for add in fields_to_add}
        kept_fields = [
            updated_field
            for field in mapped_fields
            if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
        ]
        return kept_fields + new_fields

    def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[MappedField]:
        """Visit the NameMapping: add the fields registered under the top-level parent id -1."""
        return self._add_new_fields(field_results, -1)

    def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]:
        """Visit a list of sibling fields: drop names that an update moved to another sibling."""
        reassignments: Dict[str, int] = {
            update.name: update.field_id
            for f in field_results
            if (update := self._updates.get(f.field_id)) is not None
        }
        return [
            updated_field
            for field in field_results
            if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
        ]

    def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField:
        """Visit a single field: record its updated name and attach newly added children."""
        # Copy the names so that appending does not mutate the mapping that is being visited.
        field_names = list(field.names)
        update = self._updates.get(field.field_id)
        if update is not None and update.name not in field_names:
            field_names.append(update.name)

        return MappedField(field_id=field.field_id, names=field_names, fields=self._add_new_fields(field_result, field.field_id))


def create_mapping_from_schema(schema: Schema) -> NameMapping:
    """Build a NameMapping from the result of visiting `schema` with a `_CreateMapping` visitor."""
    mapped_fields = visit(schema, _CreateMapping())
    return NameMapping(mapped_fields)


def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
    """Return a new NameMapping built by visiting `mapping` with an `_UpdateMapping` visitor.

    Args:
        mapping: The name mapping to start from.
        updates: Updated fields, passed through to `_UpdateMapping`.
        adds: Added fields, passed through to `_UpdateMapping`.
    """
    visitor = _UpdateMapping(updates, adds)
    updated_fields = visit_name_mapping(mapping, visitor)
    return NameMapping(updated_fields)
21 changes: 19 additions & 2 deletions tests/integration/test_rest_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, prune_columns
from pyiceberg.table import Table, UpdateSchema
from pyiceberg.table import Table, TableProperties, UpdateSchema
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema
from pyiceberg.table.sorting import SortField, SortOrder
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import (
Expand Down Expand Up @@ -73,7 +74,11 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
catalog.drop_table(tbl_name)
except NoSuchTableError:
pass
return catalog.create_table(identifier=tbl_name, schema=schema)
return catalog.create_table(
identifier=tbl_name,
schema=schema,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)


@pytest.mark.integration
Expand Down Expand Up @@ -674,6 +679,13 @@ def test_rename_simple(simple_table: Table) -> None:
identifier_field_ids=[2],
)

# Check that the name mapping gets updated
assert simple_table.name_mapping() == NameMapping([
MappedField(field_id=1, names=['foo', 'vo']),
MappedField(field_id=2, names=['bar']),
MappedField(field_id=3, names=['baz']),
])


@pytest.mark.integration
def test_rename_simple_nested(catalog: Catalog) -> None:
Expand Down Expand Up @@ -701,6 +713,11 @@ def test_rename_simple_nested(catalog: Catalog) -> None:
),
)

# Check that the name mapping gets updated
assert tbl.name_mapping() == NameMapping([
MappedField(field_id=1, names=['foo'], fields=[MappedField(field_id=2, names=['bar', 'vo'])]),
])


@pytest.mark.integration
def test_rename_simple_nested_with_dots(catalog: Catalog) -> None:
Expand Down
73 changes: 72 additions & 1 deletion tests/table/test_name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
import pytest

from pyiceberg.schema import Schema
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema, parse_mapping_from_json
from pyiceberg.table.name_mapping import (
MappedField,
NameMapping,
create_mapping_from_schema,
parse_mapping_from_json,
update_mapping,
)
from pyiceberg.types import NestedField, StringType


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -238,3 +245,67 @@ def test_mapping_lookup_by_name(table_name_mapping_nested: NameMapping) -> None:

with pytest.raises(ValueError, match="Could not find field with name: boom"):
table_name_mapping_nested.find("boom")


def test_invalid_mapped_field() -> None:
    """Constructing a MappedField with an empty list of names must raise a ValueError."""
    no_names = []
    with pytest.raises(ValueError):
        MappedField(field_id=1, names=no_names)


def test_update_mapping_no_updates_or_adds(table_name_mapping_nested: NameMapping) -> None:
    """Updating with empty updates and empty adds yields a mapping equal to the input."""
    result = update_mapping(table_name_mapping_nested, {}, {})
    assert result == table_name_mapping_nested


def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
    """Check a combined update: one updated field, one top-level add and two nested adds."""
    # Field 1 gets the additional name 'foo_update'.
    renamed_fields = {1: NestedField(1, "foo_update", StringType(), True)}
    # -1 is the parent id used for the top level; 15 is the 'person' field.
    added_fields = {
        -1: [NestedField(18, "add_18", StringType(), True)],
        15: [NestedField(19, "name", StringType(), True), NestedField(20, "add_20", StringType(), True)],
    }

    actual = update_mapping(table_name_mapping_nested, renamed_fields, added_fields)

    # The nested 'quux' and 'location' subtrees are expected to come through unchanged.
    quux_value = MappedField(
        field_id=8,
        names=['value'],
        fields=[
            MappedField(field_id=9, names=['key']),
            MappedField(field_id=10, names=['value']),
        ],
    )
    location_element = MappedField(
        field_id=12,
        names=['element'],
        fields=[
            MappedField(field_id=13, names=['latitude']),
            MappedField(field_id=14, names=['longitude']),
        ],
    )
    # 'name' now belongs to the added field 19.
    # NOTE(review): presumably the fixture had an older child of 'person' holding 'name'
    # that is dropped here; the fixture is not visible from this file, so confirm.
    person = MappedField(
        field_id=15,
        names=['person'],
        fields=[
            MappedField(field_id=17, names=['age']),
            MappedField(field_id=19, names=['name']),
            MappedField(field_id=20, names=['add_20']),
        ],
    )

    expected = NameMapping([
        MappedField(field_id=1, names=['foo', 'foo_update']),
        MappedField(field_id=2, names=['bar']),
        MappedField(field_id=3, names=['baz']),
        MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]),
        MappedField(
            field_id=6,
            names=['quux'],
            fields=[
                MappedField(field_id=7, names=['key']),
                quux_value,
            ],
        ),
        MappedField(field_id=11, names=['location'], fields=[location_element]),
        person,
        MappedField(field_id=18, names=['add_18']),
    ])
    assert actual == expected

0 comments on commit fd9dc88

Please sign in to comment.