diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 03dc3199bf..5277eed9e6 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -37,7 +37,7 @@ from pyiceberg.io import FileIO, InputFile, OutputFile from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.typedef import EMPTY_DICT, Record +from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion from pyiceberg.types import ( BinaryType, BooleanType, @@ -302,7 +302,7 @@ def _(partition_field_type: PrimitiveType) -> PrimitiveType: return partition_field_type -def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType: +def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType: data_file_partition_type = StructType(*[ NestedField( field_id=field.field_id, @@ -372,7 +372,7 @@ def __setattr__(self, name: str, value: Any) -> None: value = FileFormat[value] super().__setattr__(name, value) - def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None: + def __init__(self, format_version: TableVersion = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None: super().__init__( *data, **{"struct": DATA_FILE_TYPE[format_version], **named_data}, @@ -408,7 +408,7 @@ def __eq__(self, other: Any) -> bool: MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()} -def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema: +def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema: return Schema(*[ NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields @@ -719,9 +719,9 @@ def content(self) -> ManifestContent: ... @property @abstractmethod - def version(self) -> Literal[1, 2]: ... + def version(self) -> TableVersion: ... - def _with_partition(self, format_version: Literal[1, 2]) -> Schema: + def _with_partition(self, format_version: TableVersion) -> Schema: data_file_type = data_file_with_partition( format_version=format_version, partition_type=self._spec.partition_type(self._schema) ) @@ -807,7 +807,7 @@ def content(self) -> ManifestContent: return ManifestContent.DATA @property - def version(self) -> Literal[1, 2]: + def version(self) -> TableVersion: return 1 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: @@ -834,7 +834,7 @@ def content(self) -> ManifestContent: return ManifestContent.DATA @property - def version(self) -> Literal[1, 2]: + def version(self) -> TableVersion: return 2 def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: @@ -847,7 +847,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry: def write_manifest( - format_version: Literal[1, 2], spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int + format_version: TableVersion, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int ) -> ManifestWriter: if format_version == 1: return ManifestWriterV1(spec, schema, output_file, snapshot_id) @@ -858,14 +858,14 @@ def write_manifest( class ManifestListWriter(ABC): - _format_version: Literal[1, 2] + _format_version: TableVersion _output_file: OutputFile _meta: Dict[str, str] _manifest_files: List[ManifestFile] _commit_snapshot_id: int _writer: AvroOutputFile[ManifestFile] - def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]): + def __init__(self, format_version: TableVersion, output_file: OutputFile, meta: Dict[str, Any]): self._format_version = format_version self._output_file = output_file self._meta = meta @@ -957,7 +957,7 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile: def write_manifest_list( - format_version: Literal[1, 2], + format_version: TableVersion, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: Optional[int], diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 787bdb860b..5f67c05c75 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -121,6 +121,7 @@ KeyDefaultDict, Properties, Record, + TableVersion, ) from pyiceberg.types import ( IcebergType, @@ -293,7 +294,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ return self - def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction: + def upgrade_table_version(self, format_version: TableVersion) -> Transaction: """Set the table to a certain version. Args: @@ -1023,7 +1024,7 @@ def scan( ) @property - def format_version(self) -> Literal[1, 2]: + def format_version(self) -> TableVersion: return self.metadata.format_version def schema(self) -> Schema: diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py index e57bf3490c..4bed386c77 100644 --- a/pyiceberg/typedef.py +++ b/pyiceberg/typedef.py @@ -26,6 +26,7 @@ Dict, Generic, List, + Literal, Optional, Protocol, Set, @@ -37,6 +38,7 @@ from uuid import UUID from pydantic import BaseModel, ConfigDict, RootModel +from typing_extensions import TypeAlias if TYPE_CHECKING: from pyiceberg.types import StructType @@ -199,3 +201,6 @@ def __repr__(self) -> str: def record_fields(self) -> List[str]: """Return values of all the fields of the Record class except those specified in skip_fields.""" return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name] + + +TableVersion: TypeAlias = Literal[1, 2] diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 3e789cb854..8bb03cd80e 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=redefined-outer-name,arguments-renamed,fixme from tempfile import TemporaryDirectory -from typing import Dict, Literal +from typing import Dict import fastavro import pytest @@ -39,7 +39,7 @@ from pyiceberg.schema import Schema from pyiceberg.table.snapshots import Operation, Snapshot, Summary from pyiceberg.transforms import IdentityTransform -from pyiceberg.typedef import Record +from pyiceberg.typedef import Record, TableVersion from pyiceberg.types import IntegerType, NestedField @@ -308,7 +308,7 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None: @pytest.mark.parametrize("format_version", [1, 2]) def test_write_manifest( - generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion ) -> None: io = load_file_io() snapshot = Snapshot( @@ -478,7 +478,7 @@ def test_write_manifest( @pytest.mark.parametrize("format_version", [1, 2]) def test_write_manifest_list( - generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2] + generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion ) -> None: io = load_file_io()