def name_mapping(self) -> Optional[NameMapping]:
    """Return the table's field-id NameMapping, if one is stored on the table.

    Returns:
        The mapping parsed from the ``TableProperties.DEFAULT_NAME_MAPPING``
        table property, or ``None`` when that property is absent or empty.
    """
    serialized_mapping = self.properties.get(TableProperties.DEFAULT_NAME_MAPPING)
    if not serialized_mapping:
        # No mapping configured: report that explicitly rather than inventing one.
        return None
    return parse_mapping_from_json(serialized_mapping)
class NameMappingVisitor(Generic[S, T], ABC):
    """Visitor interface for walking a NameMapping.

    Two result types are involved:

    * ``T`` is what ``field`` produces for one ``MappedField``.
    * ``S`` is what ``fields`` produces for a list of mapped fields; it is also
      the type that ``field`` receives for a field's children and that
      ``mapping`` receives and returns for the mapping as a whole.
    """

    @abstractmethod
    def field(self, field: MappedField, field_result: S) -> T:
        """Visit a MappedField."""

    @abstractmethod
    def fields(self, struct: List[MappedField], field_results: List[T]) -> S:
        """Visit a List[MappedField]."""

    @abstractmethod
    def mapping(self, nm: NameMapping, field_results: S) -> S:
        """Visit a NameMapping."""
class _UpdateMapping(NameMappingVisitor[List[MappedField], MappedField]):
    """Rebuild a name mapping so that it reflects renamed and newly added fields.

    The visitor never modifies the mapping it walks: every ``MappedField`` it
    returns is a new object, so the caller's original mapping stays usable.

    Args:
        updates: Updated fields keyed by field id. An updated field's current
            name is appended to the names already mapped to that id.
        adds: Newly added fields keyed by the id of their parent field.
            Fields stored under the key ``-1`` are added at the top level.
    """

    _updates: Dict[int, NestedField]
    _adds: Dict[int, List[NestedField]]

    def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]):
        self._updates = updates
        self._adds = adds

    @staticmethod
    def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]:
        """Drop the names of ``field`` that now belong to a different field id.

        Args:
            field: The mapped field to filter.
            assignments: Name -> field id that owns the name after the change.

        Returns:
            A new ``MappedField`` holding the remaining names, or ``None`` when
            every name was taken over by another field.
        """
        # Compare against None explicitly so that a reassignment to field id 0
        # is honoured instead of being skipped as a falsy value.
        remaining_names = [
            name
            for name in field.names
            if (assigned_id := assignments.get(name)) is None or assigned_id == field.field_id
        ]
        if not remaining_names:
            return None
        return MappedField(field_id=field.field_id, names=remaining_names, fields=field.fields)

    def _add_new_fields(self, mapped_fields: List[MappedField], parent_id: int) -> List[MappedField]:
        """Append the fields added under ``parent_id`` to ``mapped_fields``.

        Existing siblings lose any name that one of the added fields now uses.
        When nothing was added under ``parent_id`` the input list is returned
        as is.
        """
        fields_to_add = self._adds.get(parent_id)
        if not fields_to_add:
            return mapped_fields

        new_fields = [
            MappedField(field_id=add.field_id, names=[add.name], fields=visit(add.field_type, _CreateMapping()))
            for add in fields_to_add
        ]
        reassignments = {add.name: add.field_id for add in fields_to_add}
        kept_fields = [
            updated_field
            for field in mapped_fields
            if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
        ]
        return kept_fields + new_fields

    def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[MappedField]:
        """Finish the top level: add the fields registered under parent id -1."""
        return self._add_new_fields(field_results, -1)

    def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]:
        """Remove names from siblings that an updated field in the same list now uses."""
        reassignments: Dict[str, int] = {
            update.name: update.field_id
            for f in field_results
            if (update := self._updates.get(f.field_id)) is not None
        }
        return [
            updated_field
            for field in field_results
            if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
        ]

    def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField:
        """Return a new mapped field carrying the updated name and any added children."""
        # Work on a copy: appending to ``field.names`` directly would mutate the
        # mapping passed in by the caller.
        field_names = list(field.names)
        update = self._updates.get(field.field_id)
        if update is not None and update.name not in field_names:
            field_names.append(update.name)

        return MappedField(
            field_id=field.field_id,
            names=field_names,
            fields=self._add_new_fields(field_result, field.field_id),
        )
pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, prune_columns -from pyiceberg.table import Table, UpdateSchema +from pyiceberg.table import Table, TableProperties, UpdateSchema +from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema from pyiceberg.table.sorting import SortField, SortOrder from pyiceberg.transforms import IdentityTransform from pyiceberg.types import ( @@ -73,7 +74,11 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table: catalog.drop_table(tbl_name) except NoSuchTableError: pass - return catalog.create_table(identifier=tbl_name, schema=schema) + return catalog.create_table( + identifier=tbl_name, + schema=schema, + properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, + ) @pytest.mark.integration @@ -674,6 +679,13 @@ def test_rename_simple(simple_table: Table) -> None: identifier_field_ids=[2], ) + # Check that the name mapping gets updated + assert simple_table.name_mapping() == NameMapping([ + MappedField(field_id=1, names=['foo', 'vo']), + MappedField(field_id=2, names=['bar']), + MappedField(field_id=3, names=['baz']), + ]) + @pytest.mark.integration def test_rename_simple_nested(catalog: Catalog) -> None: @@ -701,6 +713,11 @@ def test_rename_simple_nested(catalog: Catalog) -> None: ), ) + # Check that the name mapping gets updated + assert tbl.name_mapping() == NameMapping([ + MappedField(field_id=1, names=['foo'], fields=[MappedField(field_id=2, names=['bar', 'vo'])]), + ]) + @pytest.mark.integration def test_rename_simple_nested_with_dots(catalog: Catalog) -> None: diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py index d74aa3234c..e039415ce3 100644 --- a/tests/table/test_name_mapping.py +++ b/tests/table/test_name_mapping.py @@ -17,7 +17,14 @@ import pytest from pyiceberg.schema import Schema -from pyiceberg.table.name_mapping import MappedField, 
def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
    """Check update_mapping with one updated field and fields added at two levels.

    Field 1 is updated to the name ``foo_update``, field 18 is added under
    parent id -1, and fields 19 and 20 are added under field 15. In the expected
    mapping ``name`` appears under ``person`` only for id 19 -- presumably the
    fixture's previous holder of that name is dropped (fixture not shown here).
    """
    renamed_fields = {1: NestedField(1, "foo_update", StringType(), True)}
    added_fields = {
        -1: [NestedField(18, "add_18", StringType(), True)],
        15: [NestedField(19, "name", StringType(), True), NestedField(20, "add_20", StringType(), True)],
    }

    # Expected nested sub-trees, built separately to keep the final list readable.
    expected_quux = MappedField(
        field_id=6,
        names=['quux'],
        fields=[
            MappedField(field_id=7, names=['key']),
            MappedField(
                field_id=8,
                names=['value'],
                fields=[
                    MappedField(field_id=9, names=['key']),
                    MappedField(field_id=10, names=['value']),
                ],
            ),
        ],
    )
    expected_location = MappedField(
        field_id=11,
        names=['location'],
        fields=[
            MappedField(
                field_id=12,
                names=['element'],
                fields=[
                    MappedField(field_id=13, names=['latitude']),
                    MappedField(field_id=14, names=['longitude']),
                ],
            )
        ],
    )
    expected_person = MappedField(
        field_id=15,
        names=['person'],
        fields=[
            MappedField(field_id=17, names=['age']),
            MappedField(field_id=19, names=['name']),
            MappedField(field_id=20, names=['add_20']),
        ],
    )

    expected = NameMapping([
        MappedField(field_id=1, names=['foo', 'foo_update']),
        MappedField(field_id=2, names=['bar']),
        MappedField(field_id=3, names=['baz']),
        MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]),
        expected_quux,
        expected_location,
        expected_person,
        MappedField(field_id=18, names=['add_18']),
    ])

    actual = update_mapping(table_name_mapping_nested, renamed_fields, added_fields)
    assert actual == expected