Skip to content

Commit

Permalink
Add typealias for table version (apache#566)
Browse files Browse the repository at this point in the history
* typealias for table version

* typealias for table version

* typealias for table version

* typealias for table version

* typealias for table version

* typealias for table version replaced in all files
  • Loading branch information
MehulBatra authored Apr 3, 2024
1 parent 474b37b commit ba9ff98
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 18 deletions.
24 changes: 12 additions & 12 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pyiceberg.io import FileIO, InputFile, OutputFile
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.typedef import EMPTY_DICT, Record
from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -302,7 +302,7 @@ def _(partition_field_type: PrimitiveType) -> PrimitiveType:
return partition_field_type


def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
data_file_partition_type = StructType(*[
NestedField(
field_id=field.field_id,
Expand Down Expand Up @@ -372,7 +372,7 @@ def __setattr__(self, name: str, value: Any) -> None:
value = FileFormat[value]
super().__setattr__(name, value)

def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
def __init__(self, format_version: TableVersion = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
super().__init__(
*data,
**{"struct": DATA_FILE_TYPE[format_version], **named_data},
Expand Down Expand Up @@ -408,7 +408,7 @@ def __eq__(self, other: Any) -> bool:
MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()}


def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema:
def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema:
return Schema(*[
NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field
for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields
Expand Down Expand Up @@ -719,9 +719,9 @@ def content(self) -> ManifestContent: ...

@property
@abstractmethod
def version(self) -> Literal[1, 2]: ...
def version(self) -> TableVersion: ...

def _with_partition(self, format_version: Literal[1, 2]) -> Schema:
def _with_partition(self, format_version: TableVersion) -> Schema:
data_file_type = data_file_with_partition(
format_version=format_version, partition_type=self._spec.partition_type(self._schema)
)
Expand Down Expand Up @@ -807,7 +807,7 @@ def content(self) -> ManifestContent:
return ManifestContent.DATA

@property
def version(self) -> Literal[1, 2]:
def version(self) -> TableVersion:
return 1

def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
Expand All @@ -834,7 +834,7 @@ def content(self) -> ManifestContent:
return ManifestContent.DATA

@property
def version(self) -> Literal[1, 2]:
def version(self) -> TableVersion:
return 2

def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
Expand All @@ -847,7 +847,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:


def write_manifest(
format_version: Literal[1, 2], spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
format_version: TableVersion, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
) -> ManifestWriter:
if format_version == 1:
return ManifestWriterV1(spec, schema, output_file, snapshot_id)
Expand All @@ -858,14 +858,14 @@ def write_manifest(


class ManifestListWriter(ABC):
_format_version: Literal[1, 2]
_format_version: TableVersion
_output_file: OutputFile
_meta: Dict[str, str]
_manifest_files: List[ManifestFile]
_commit_snapshot_id: int
_writer: AvroOutputFile[ManifestFile]

def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]):
def __init__(self, format_version: TableVersion, output_file: OutputFile, meta: Dict[str, Any]):
self._format_version = format_version
self._output_file = output_file
self._meta = meta
Expand Down Expand Up @@ -957,7 +957,7 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile:


def write_manifest_list(
format_version: Literal[1, 2],
format_version: TableVersion,
output_file: OutputFile,
snapshot_id: int,
parent_snapshot_id: Optional[int],
Expand Down
5 changes: 3 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
KeyDefaultDict,
Properties,
Record,
TableVersion,
)
from pyiceberg.types import (
IcebergType,
Expand Down Expand Up @@ -293,7 +294,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ

return self

def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction:
def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
"""Set the table to a certain version.
Args:
Expand Down Expand Up @@ -1023,7 +1024,7 @@ def scan(
)

@property
def format_version(self) -> Literal[1, 2]:
def format_version(self) -> TableVersion:
return self.metadata.format_version

def schema(self) -> Schema:
Expand Down
5 changes: 5 additions & 0 deletions pyiceberg/typedef.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Dict,
Generic,
List,
Literal,
Optional,
Protocol,
Set,
Expand All @@ -37,6 +38,7 @@
from uuid import UUID

from pydantic import BaseModel, ConfigDict, RootModel
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from pyiceberg.types import StructType
Expand Down Expand Up @@ -199,3 +201,6 @@ def __repr__(self) -> str:
def record_fields(self) -> List[str]:
"""Return values of all the fields of the Record class except those specified in skip_fields."""
return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]


TableVersion: TypeAlias = Literal[1, 2]
8 changes: 4 additions & 4 deletions tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
from tempfile import TemporaryDirectory
from typing import Dict, Literal
from typing import Dict

import fastavro
import pytest
Expand All @@ -39,7 +39,7 @@
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation, Snapshot, Summary
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record
from pyiceberg.typedef import Record, TableVersion
from pyiceberg.types import IntegerType, NestedField


Expand Down Expand Up @@ -308,7 +308,7 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None:

@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
) -> None:
io = load_file_io()
snapshot = Snapshot(
Expand Down Expand Up @@ -478,7 +478,7 @@ def test_write_manifest(

@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest_list(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
) -> None:
io = load_file_io()

Expand Down

0 comments on commit ba9ff98

Please sign in to comment.