From a892309936effa7ec575195ad3be70193e82d704 Mon Sep 17 00:00:00 2001 From: Honah J Date: Thu, 4 Apr 2024 01:02:33 -0400 Subject: [PATCH] Add CreateTableTransaction API and implement it in Glue and Rest (#498) --- mkdocs/docs/api.md | 19 ++ pyiceberg/catalog/__init__.py | 298 +++++++++++++++++-------- pyiceberg/catalog/dynamodb.py | 4 +- pyiceberg/catalog/glue.py | 136 ++++++----- pyiceberg/catalog/hive.py | 4 +- pyiceberg/catalog/noop.py | 18 ++ pyiceberg/catalog/rest.py | 76 ++++++- pyiceberg/catalog/sql.py | 4 +- pyiceberg/table/__init__.py | 160 +++++++++++-- pyiceberg/table/metadata.py | 2 +- tests/catalog/integration_test_glue.py | 92 +++++++- tests/catalog/test_base.py | 3 +- tests/catalog/test_glue.py | 48 ++++ tests/integration/test_writes.py | 54 +++++ 14 files changed, 723 insertions(+), 195 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 828dd18621..c8620af732 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -165,6 +165,25 @@ catalog.create_table( ) ``` +To create a table with some subsequent changes atomically in a transaction: + +```python +with catalog.create_table_transaction( + identifier="docs_example.bids", + schema=schema, + location="s3://pyiceberg", + partition_spec=partition_spec, + sort_order=sort_order, +) as txn: + with txn.update_schema() as update_schema: + update_schema.add_column(path="new_column", field_type=StringType()) + + with txn.update_spec() as update_spec: + update_spec.add_identity("symbol") + + txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") +``` + ## Load a table ### Catalog table diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index f2b46fcde7..f104aa94da 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -45,9 +45,11 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, + CreateTableTransaction, + StagedTable, Table, ) -from pyiceberg.table.metadata import TableMetadata +from pyiceberg.table.metadata import TableMetadata, TableMetadataV1, new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import ( EMPTY_DICT, @@ -285,9 +287,6 @@ def __init__(self, name: str, **properties: str): self.name = name self.properties = properties - def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[str] = None) -> FileIO: - return load_file_io({**self.properties, **properties}, location) - @abstractmethod def create_table( self, @@ -315,6 +314,30 @@ def create_table( TableAlreadyExistsError: If a table with the name already exists. """ + @abstractmethod + def create_table_transaction( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> CreateTableTransaction: + """Create a CreateTableTransaction. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): Table's schema. + location (str | None): Location for the table. Optional Argument. + partition_spec (PartitionSpec): PartitionSpec for the table. + sort_order (SortOrder): SortOrder for the table. + properties (Properties): Table properties that can be a string based dictionary. + + Returns: + CreateTableTransaction: createTableTransaction instance. 
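
For contrast with the transaction shown in the docs hunk above, the same changes issued through the non-transactional API take one catalog commit per step, so a failure partway leaves a partially configured table behind. A minimal sketch, reusing the `catalog`, `schema`, and `StringType` names from that example:

```python
# Non-transactional equivalent: each block below is its own catalog commit.
tbl = catalog.create_table(identifier="docs_example.bids", schema=schema)  # commit 1

with tbl.update_schema() as update_schema:  # commit 2
    update_schema.add_column(path="new_column", field_type=StringType())
```

With `create_table_transaction`, all of this collapses into a single commit guarded by an `AssertCreate` requirement.
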
+ """ + def create_table_if_not_exists( self, identifier: Union[str, Identifier], @@ -360,6 +383,17 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: NoSuchTableError: If a table with the name does not exist. """ + @abstractmethod + def table_exists(self, identifier: Union[str, Identifier]) -> bool: + """Check if a table exists. + + Args: + identifier (str | Identifier): Table identifier. + + Returns: + bool: True if the table exists, False otherwise. + """ + @abstractmethod def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: """Register a new table using existing metadata. @@ -386,6 +420,19 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None: NoSuchTableError: If a table with the name does not exist. """ + @abstractmethod + def purge_table(self, identifier: Union[str, Identifier]) -> None: + """Drop a table and purge all data and metadata files. + + Note: This method only logs warning rather than raise exception when encountering file deletion failure. + + Args: + identifier (str | Identifier): Table identifier. + + Raises: + NoSuchTableError: If a table with the name does not exist, or the identifier is invalid. + """ + @abstractmethod def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: """Rename a fully classified table name. @@ -501,6 +548,20 @@ def update_namespace_properties( ValueError: If removals and updates have overlapping keys. """ + def identifier_to_tuple_without_catalog(self, identifier: Union[str, Identifier]) -> Identifier: + """Convert an identifier to a tuple and drop this catalog's name from the first element. + + Args: + identifier (str | Identifier): Table identifier. + + Returns: + Identifier: a tuple of strings with this catalog's name removed + """ + identifier_tuple = Catalog.identifier_to_tuple(identifier) + if len(identifier_tuple) >= 3 and identifier_tuple[0] == self.name: + identifier_tuple = identifier_tuple[1:] + return identifier_tuple + @staticmethod def identifier_to_tuple(identifier: Union[str, Identifier]) -> Identifier: """Parse an identifier to a tuple. 
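
The `identifier_to_tuple_without_catalog` helper added above strips the leading element only when it matches this catalog's own name and a namespace plus table name remain. A quick example, assuming a catalog named `prod`:

```python
# Assuming catalog.name == "prod":
catalog.identifier_to_tuple_without_catalog("prod.db.tbl")  # -> ("db", "tbl")
catalog.identifier_to_tuple_without_catalog("db.tbl")       # -> ("db", "tbl"), nothing to strip
catalog.identifier_to_tuple_without_catalog("prod.tbl")     # -> ("prod", "tbl"), too short to strip
```
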
@@ -539,46 +600,6 @@ def namespace_from(identifier: Union[str, Identifier]) -> Identifier: """ return Catalog.identifier_to_tuple(identifier)[:-1] - @staticmethod - def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> None: - if updates and removals: - overlap = set(removals) & set(updates.keys()) - if overlap: - raise ValueError(f"Updates and deletes have an overlap: {overlap}") - - @staticmethod - def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: - if isinstance(schema, Schema): - return schema - try: - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) # type: ignore - return schema - except ModuleNotFoundError: - pass - raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema") - - def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: - if not location: - return self._get_default_warehouse_location(database_name, table_name) - return location - - def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str: - database_properties = self.load_namespace_properties(database_name) - if database_location := database_properties.get(LOCATION): - database_location = database_location.rstrip("/") - return f"{database_location}/{table_name}" - - if warehouse_path := self.properties.get(WAREHOUSE_LOCATION): - warehouse_path = warehouse_path.rstrip("/") - return f"{warehouse_path}/{database_name}.db/{table_name}" - - raise ValueError("No default path is set, please specify a location when creating a table") - @staticmethod def identifier_to_database( identifier: Union[str, Identifier], err: Union[Type[ValueError], Type[NoSuchNamespaceError]] = ValueError @@ -600,31 +621,52 @@ def identifier_to_database_and_table( return tuple_identifier[0], tuple_identifier[1] - def identifier_to_tuple_without_catalog(self, identifier: Union[str, Identifier]) -> Identifier: - """Convert an identifier to a tuple and drop this catalog's name from the first element. + def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[str] = None) -> FileIO: + return load_file_io({**self.properties, **properties}, location) - Args: - identifier (str | Identifier): Table identifier. + @staticmethod + def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: + if isinstance(schema, Schema): + return schema + try: + import pyarrow as pa - Returns: - Identifier: a tuple of strings with this catalog's name removed - """ - identifier_tuple = Catalog.identifier_to_tuple(identifier) - if len(identifier_tuple) >= 3 and identifier_tuple[0] == self.name: - identifier_tuple = identifier_tuple[1:] - return identifier_tuple + from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow - def purge_table(self, identifier: Union[str, Identifier]) -> None: - """Drop a table and purge all data and metadata files. + if isinstance(schema, pa.Schema): + schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) # type: ignore + return schema + except ModuleNotFoundError: + pass + raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema") - Note: This method only logs warning rather than raise exception when encountering file deletion failure. 
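
The removals in this and the surrounding hunks are moves, not deletions: the concrete `purge_table`, `table_exists`, and the schema/location/property helpers migrate into the new `MetastoreCatalog` base class below, leaving `Catalog` as a pure interface. Callers are unaffected; a hedged usage sketch with a hypothetical identifier:

```python
# Both methods remain available on every Catalog; "examples.bids" is a hypothetical table.
if catalog.table_exists("examples.bids"):
    catalog.purge_table("examples.bids")  # drops the table, then deletes its data/metadata files
```
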
+ def __repr__(self) -> str: + """Return the string representation of the Catalog class.""" + return f"{self.name} ({self.__class__})" - Args: - identifier (str | Identifier): Table identifier. - Raises: - NoSuchTableError: If a table with the name does not exist, or the identifier is invalid. - """ +class MetastoreCatalog(Catalog, ABC): + def create_table_transaction( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> CreateTableTransaction: + return CreateTableTransaction( + self._create_staged_table(identifier, schema, location, partition_spec, sort_order, properties) + ) + + def table_exists(self, identifier: Union[str, Identifier]) -> bool: + try: + self.load_table(identifier) + return True + except NoSuchTableError: + return False + + def purge_table(self, identifier: Union[str, Identifier]) -> None: identifier_tuple = self.identifier_to_tuple_without_catalog(identifier) table = self.load_table(identifier_tuple) self.drop_table(identifier_tuple) @@ -646,12 +688,88 @@ def purge_table(self, identifier: Union[str, Identifier]) -> None: delete_files(io, prev_metadata_files, PREVIOUS_METADATA) delete_files(io, {table.metadata_location}, METADATA) - def table_exists(self, identifier: Union[str, Identifier]) -> bool: - try: - self.load_table(identifier) - return True - except NoSuchTableError: - return False + def _create_staged_table( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> StagedTable: + """Create a table and return the table instance without committing the changes. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): Table's schema. + location (str | None): Location for the table. Optional Argument. + partition_spec (PartitionSpec): PartitionSpec for the table. + sort_order (SortOrder): SortOrder for the table. + properties (Properties): Table properties that can be a string based dictionary. + + Returns: + StagedTable: the created staged table instance. 
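
A `StagedTable` returned here carries fully built metadata, a metadata location, and a `FileIO`, but nothing has been written or registered yet; `create_table_transaction` above simply wraps it. A sketch of the lifecycle, assuming `identifier` and `schema` are already defined:

```python
txn = catalog.create_table_transaction(identifier, schema)  # wraps a StagedTable, no commit yet
txn.set_properties(owner="data-eng")  # buffered as a TableUpdate; "owner" is a hypothetical key
table = txn.commit_transaction()      # single commit: AssertCreate plus all buffered updates
```
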
+ """ + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + + database_name, table_name = self.identifier_to_database_and_table(identifier) + + location = self._resolve_table_location(location, database_name, table_name) + metadata_location = self._get_metadata_location(location=location) + metadata = new_table_metadata( + location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties + ) + io = load_file_io(properties=self.properties, location=metadata_location) + return StagedTable( + identifier=(self.name, database_name, table_name), + metadata=metadata, + metadata_location=metadata_location, + io=io, + catalog=self, + ) + + def _get_updated_props_and_update_summary( + self, current_properties: Properties, removals: Optional[Set[str]], updates: Properties + ) -> Tuple[PropertiesUpdateSummary, Properties]: + self._check_for_overlap(updates=updates, removals=removals) + updated_properties = dict(current_properties) + + removed: Set[str] = set() + updated: Set[str] = set() + + if removals: + for key in removals: + if key in updated_properties: + updated_properties.pop(key) + removed.add(key) + if updates: + for key, value in updates.items(): + updated_properties[key] = value + updated.add(key) + + expected_to_change = (removals or set()).difference(removed) + properties_update_summary = PropertiesUpdateSummary( + removed=list(removed or []), updated=list(updated or []), missing=list(expected_to_change) + ) + + return properties_update_summary, updated_properties + + def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: + if not location: + return self._get_default_warehouse_location(database_name, table_name) + return location + + def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str: + database_properties = self.load_namespace_properties(database_name) + if database_location := database_properties.get(LOCATION): + database_location = database_location.rstrip("/") + return f"{database_location}/{table_name}" + + if warehouse_path := self.properties.get(WAREHOUSE_LOCATION): + warehouse_path = warehouse_path.rstrip("/") + return f"{warehouse_path}/{database_name}.db/{table_name}" + + raise ValueError("No default path is set, please specify a location when creating a table") @staticmethod def _write_metadata(metadata: TableMetadata, io: FileIO, metadata_path: str) -> None: @@ -691,32 +809,22 @@ def _parse_metadata_version(metadata_location: str) -> int: else: return -1 - def _get_updated_props_and_update_summary( - self, current_properties: Properties, removals: Optional[Set[str]], updates: Properties - ) -> Tuple[PropertiesUpdateSummary, Properties]: - self._check_for_overlap(updates=updates, removals=removals) - updated_properties = dict(current_properties) - - removed: Set[str] = set() - updated: Set[str] = set() - - if removals: - for key in removals: - if key in updated_properties: - updated_properties.pop(key) - removed.add(key) - if updates: - for key, value in updates.items(): - updated_properties[key] = value - updated.add(key) + @staticmethod + def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> None: + if updates and removals: + overlap = set(removals) & set(updates.keys()) + if overlap: + raise ValueError(f"Updates and deletes have an overlap: {overlap}") - expected_to_change = (removals or set()).difference(removed) - properties_update_summary = PropertiesUpdateSummary( - removed=list(removed or []), 
updated=list(updated or []), missing=list(expected_to_change) - ) + @staticmethod + def _empty_table_metadata() -> TableMetadata: + """Return an empty TableMetadata instance. - return properties_update_summary, updated_properties + It is used to build a TableMetadata from a sequence of initial TableUpdates. + It is a V1 TableMetadata because there will be a UpgradeFormatVersionUpdate in + initial changes to bump the metadata to the target version. - def __repr__(self) -> str: - """Return the string representation of the Catalog class.""" - return f"{self.name} ({self.__class__})" + Returns: + TableMetadata: An empty TableMetadata instance. + """ + return TableMetadataV1(location="", last_column_id=-1, schema=Schema()) diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py index 266dd6353d..bc5cbede11 100644 --- a/pyiceberg/catalog/dynamodb.py +++ b/pyiceberg/catalog/dynamodb.py @@ -33,7 +33,7 @@ METADATA_LOCATION, PREVIOUS_METADATA_LOCATION, TABLE_TYPE, - Catalog, + MetastoreCatalog, PropertiesUpdateSummary, ) from pyiceberg.exceptions import ( @@ -79,7 +79,7 @@ ITEM = "Item" -class DynamoDbCatalog(Catalog): +class DynamoDbCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: str): super().__init__(name, **properties) session = boto3.Session( diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index adec150d84..e7532677aa 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -45,7 +45,7 @@ METADATA_LOCATION, PREVIOUS_METADATA_LOCATION, TABLE_TYPE, - Catalog, + MetastoreCatalog, PropertiesUpdateSummary, ) from pyiceberg.exceptions import ( @@ -62,8 +62,13 @@ from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec from pyiceberg.schema import Schema, SchemaVisitor, visit from pyiceberg.serializers import FromInputFile -from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata -from pyiceberg.table.metadata import TableMetadata, new_table_metadata +from pyiceberg.table import ( + CommitTableRequest, + CommitTableResponse, + Table, + update_table_metadata, +) +from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties from pyiceberg.types import ( @@ -273,7 +278,7 @@ def add_glue_catalog_id(params: Dict[str, str], **kwargs: Any) -> None: event_system.register("provide-client-params.glue", add_glue_catalog_id) -class GlueCatalog(Catalog): +class GlueCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: Any): super().__init__(name, **properties) @@ -384,20 +389,18 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. 
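
The rewrite that follows routes `GlueCatalog.create_table` through the shared `_create_staged_table` helper, so the direct and transactional create paths build identical metadata. `create_table` then writes the metadata file and registers the Glue table immediately, while `create_table_transaction` defers both steps to `commit_transaction`.
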
""" - schema: Schema = self._convert_schema_if_needed(schema) # type: ignore - - database_name, table_name = self.identifier_to_database_and_table(identifier) - - location = self._resolve_table_location(location, database_name, table_name) - metadata_location = self._get_metadata_location(location=location) - metadata = new_table_metadata( - location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties + staged_table = self._create_staged_table( + identifier=identifier, + schema=schema, + location=location, + partition_spec=partition_spec, + sort_order=sort_order, + properties=properties, ) - io = load_file_io(properties=self.properties, location=metadata_location) - self._write_metadata(metadata, io, metadata_location) - - table_input = _construct_table_input(table_name, metadata_location, properties, metadata) database_name, table_name = self.identifier_to_database_and_table(identifier) + + self._write_metadata(staged_table.metadata, staged_table.io, staged_table.metadata_location) + table_input = _construct_table_input(table_name, staged_table.metadata_location, properties, staged_table.metadata) self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input) return self.load_table(identifier=identifier) @@ -435,46 +438,71 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons ) database_name, table_name = self.identifier_to_database_and_table(identifier_tuple) - current_glue_table = self._get_glue_table(database_name=database_name, table_name=table_name) - glue_table_version_id = current_glue_table.get("VersionId") - if not glue_table_version_id: - raise CommitFailedException(f"Cannot commit {database_name}.{table_name} because Glue table version id is missing") - current_table = self._convert_glue_to_iceberg(glue_table=current_glue_table) - base_metadata = current_table.metadata - - # Validate the update requirements - for requirement in table_request.requirements: - requirement.validate(base_metadata) - - updated_metadata = update_table_metadata(base_metadata, table_request.updates) - if updated_metadata == base_metadata: - # no changes, do nothing - return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) - - # write new metadata - new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 - new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) - self._write_metadata(updated_metadata, current_table.io, new_metadata_location) - - update_table_input = _construct_table_input( - table_name=table_name, - metadata_location=new_metadata_location, - properties=current_table.properties, - metadata=updated_metadata, - glue_table=current_glue_table, - prev_metadata_location=current_table.metadata_location, - ) + try: + current_glue_table = self._get_glue_table(database_name=database_name, table_name=table_name) + # Update the table + glue_table_version_id = current_glue_table.get("VersionId") + if not glue_table_version_id: + raise CommitFailedException( + f"Cannot commit {database_name}.{table_name} because Glue table version id is missing" + ) + current_table = self._convert_glue_to_iceberg(glue_table=current_glue_table) + base_metadata = current_table.metadata + + # Validate the update requirements + for requirement in table_request.requirements: + requirement.validate(base_metadata) + + updated_metadata = 
update_table_metadata(base_metadata=base_metadata, updates=table_request.updates) + if updated_metadata == base_metadata: + # no changes, do nothing + return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) + + # write new metadata + new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 + new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) + self._write_metadata(updated_metadata, current_table.io, new_metadata_location) + + update_table_input = _construct_table_input( + table_name=table_name, + metadata_location=new_metadata_location, + properties=current_table.properties, + metadata=updated_metadata, + glue_table=current_glue_table, + prev_metadata_location=current_table.metadata_location, + ) - # Pass `version_id` to implement optimistic locking: it ensures updates are rejected if concurrent - # modifications occur. See more details at https://iceberg.apache.org/docs/latest/aws/#optimistic-locking - self._update_glue_table( - database_name=database_name, - table_name=table_name, - table_input=update_table_input, - version_id=glue_table_version_id, - ) + # Pass `version_id` to implement optimistic locking: it ensures updates are rejected if concurrent + # modifications occur. See more details at https://iceberg.apache.org/docs/latest/aws/#optimistic-locking + self._update_glue_table( + database_name=database_name, + table_name=table_name, + table_input=update_table_input, + version_id=glue_table_version_id, + ) + + return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) + except NoSuchTableError: + # Create the table + updated_metadata = update_table_metadata( + base_metadata=self._empty_table_metadata(), updates=table_request.updates, enforce_validation=True + ) + new_metadata_version = 0 + new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version) + self._write_metadata( + updated_metadata, self._load_file_io(updated_metadata.properties, new_metadata_location), new_metadata_location + ) + + create_table_input = _construct_table_input( + table_name=table_name, + metadata_location=new_metadata_location, + properties=updated_metadata.properties, + metadata=updated_metadata, + ) + + self._create_glue_table(database_name=database_name, table_name=table_name, table_input=create_table_input) - return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) + return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) def load_table(self, identifier: Union[str, Identifier]) -> Table: """Load the table's metadata and returns the table instance. 
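
`_commit_table` is now create-or-update: a `NoSuchTableError` raised while fetching the current Glue table routes into a create branch that replays the transaction's updates against an empty metadata skeleton. Condensed from the hunk above:

```python
# Create branch of GlueCatalog._commit_table (condensed from the diff above):
updated_metadata = update_table_metadata(
    base_metadata=self._empty_table_metadata(),  # empty V1 skeleton
    updates=table_request.updates,               # starts with the transaction's initial changes
    enforce_validation=True,                     # re-parse the result to validate it
)
```

Because the skeleton is V1, the initial changes always include an `UpgradeFormatVersionUpdate` to reach the requested format version.
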
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 18bbcfe084..359bdef595 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -58,7 +58,7 @@ LOCATION, METADATA_LOCATION, TABLE_TYPE, - Catalog, + MetastoreCatalog, PropertiesUpdateSummary, ) from pyiceberg.exceptions import ( @@ -230,7 +230,7 @@ def primitive(self, primitive: PrimitiveType) -> str: return HIVE_PRIMITIVE_TYPES[type(primitive)] -class HiveCatalog(Catalog): +class HiveCatalog(MetastoreCatalog): _client: _HiveClient def __init__(self, name: str, **properties: str): diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index e294390e61..1dfeb952f9 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -28,6 +28,7 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, + CreateTableTransaction, Table, ) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder @@ -49,9 +50,23 @@ def create_table( ) -> Table: raise NotImplementedError + def create_table_transaction( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> CreateTableTransaction: + raise NotImplementedError + def load_table(self, identifier: Union[str, Identifier]) -> Table: raise NotImplementedError + def table_exists(self, identifier: Union[str, Identifier]) -> bool: + raise NotImplementedError + def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: """Register a new table using existing metadata. @@ -70,6 +85,9 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: def drop_table(self, identifier: Union[str, Identifier]) -> None: raise NotImplementedError + def purge_table(self, identifier: Union[str, Identifier]) -> None: + raise NotImplementedError + def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table: raise NotImplementedError diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index 81a9b09f87..53e3f6a123 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -61,6 +61,8 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, + CreateTableTransaction, + StagedTable, Table, TableIdentifier, ) @@ -135,7 +137,7 @@ def _retry_hook(retry_state: RetryCallState) -> None: class TableResponse(IcebergBaseModel): - metadata_location: str = Field(alias="metadata-location") + metadata_location: Optional[str] = Field(alias="metadata-location") metadata: TableMetadata config: Properties = Field(default_factory=dict) @@ -460,7 +462,18 @@ def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylin def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response: TableResponse) -> Table: return Table( identifier=(self.name,) + identifier_tuple if self.name else identifier_tuple, - metadata_location=table_response.metadata_location, + metadata_location=table_response.metadata_location, # type: ignore + metadata=table_response.metadata, + io=self._load_file_io( + {**table_response.metadata.properties, **table_response.config}, table_response.metadata_location + ), + catalog=self, + ) + + def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_response: TableResponse) -> StagedTable: + return StagedTable( + identifier=(self.name,) + 
identifier_tuple if self.name else identifier_tuple, + metadata_location=table_response.metadata_location, # type: ignore metadata=table_response.metadata, io=self._load_file_io( {**table_response.metadata.properties, **table_response.config}, table_response.metadata_location @@ -490,8 +503,7 @@ def _config_headers(self, session: Session) -> None: def _extract_headers_from_properties(self) -> Dict[str, str]: return {key[len(HEADER_PREFIX) :]: value for key, value in self.properties.items() if key.startswith(HEADER_PREFIX)} - @retry(**_RETRY_ARGS) - def create_table( + def _create_table( self, identifier: Union[str, Identifier], schema: Union[Schema, "pa.Schema"], @@ -499,7 +511,8 @@ def create_table( partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, - ) -> Table: + stage_create: bool = False, + ) -> TableResponse: iceberg_schema = self._convert_schema_if_needed(schema) fresh_schema = assign_fresh_schema_ids(iceberg_schema) fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, iceberg_schema, fresh_schema) @@ -512,6 +525,7 @@ def create_table( table_schema=fresh_schema, partition_spec=fresh_partition_spec, write_order=fresh_sort_order, + stage_create=stage_create, properties=properties, ) serialized_json = request.model_dump_json().encode(UTF8) @@ -524,9 +538,51 @@ def create_table( except HTTPError as exc: self._handle_non_200_response(exc, {409: TableAlreadyExistsError}) - table_response = TableResponse(**response.json()) + return TableResponse(**response.json()) + + @retry(**_RETRY_ARGS) + def create_table( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> Table: + table_response = self._create_table( + identifier=identifier, + schema=schema, + location=location, + partition_spec=partition_spec, + sort_order=sort_order, + properties=properties, + stage_create=False, + ) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) + @retry(**_RETRY_ARGS) + def create_table_transaction( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> CreateTableTransaction: + table_response = self._create_table( + identifier=identifier, + schema=schema, + location=location, + partition_spec=partition_spec, + sort_order=sort_order, + properties=properties, + stage_create=True, + ) + staged_table = self._response_to_staged_table(self.identifier_to_tuple(identifier), table_response) + return CreateTableTransaction(staged_table) + @retry(**_RETRY_ARGS) def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: """Register a new table using existing metadata. @@ -720,6 +776,14 @@ def update_namespace_properties( @retry(**_RETRY_ARGS) def table_exists(self, identifier: Union[str, Identifier]) -> bool: + """Check if a table exists. + + Args: + identifier (str | Identifier): Table identifier. + + Returns: + bool: True if the table exists, False otherwise. 
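
On REST, staging happens server-side: `create_table_transaction` above posts the create request with the stage-create flag set, the server answers with staged metadata (possibly without a metadata location yet, which is why `TableResponse.metadata_location` became optional), and the final commit carries the `AssertCreate` requirement. Condensed from the methods above:

```python
# Client-side REST flow (condensed from create_table_transaction above):
table_response = self._create_table(identifier=identifier, schema=schema, stage_create=True)
staged_table = self._response_to_staged_table(self.identifier_to_tuple(identifier), table_response)
return CreateTableTransaction(staged_table)
```
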
+ """ identifier_tuple = self.identifier_to_tuple_without_catalog(identifier) response = self._session.head( self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier_tuple)) diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index d44d4996b6..978109b2a3 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -43,7 +43,7 @@ from pyiceberg.catalog import ( METADATA_LOCATION, - Catalog, + MetastoreCatalog, PropertiesUpdateSummary, ) from pyiceberg.exceptions import ( @@ -93,7 +93,7 @@ class IcebergNamespaceProperties(SqlCatalogBaseTable): property_value: Mapped[str] = mapped_column(String(1000), nullable=False) -class SqlCatalog(Catalog): +class SqlCatalog(MetastoreCatalog): def __init__(self, name: str, **properties: str): super().__init__(name, **properties) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 5f67c05c75..4e968eb616 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -76,6 +76,7 @@ from pyiceberg.partitioning import ( INITIAL_PARTITION_SPEC_ID, PARTITION_FIELD_ID_START, + UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec, _PartitionNameGenerator, @@ -111,7 +112,7 @@ Summary, update_snapshot_summaries, ) -from pyiceberg.table.sorting import SortOrder +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform from pyiceberg.typedef import ( EMPTY_DICT, @@ -144,7 +145,6 @@ from pyiceberg.catalog import Catalog - ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 @@ -402,6 +402,59 @@ def commit_transaction(self) -> Table: return self._table +class CreateTableTransaction(Transaction): + def _initial_changes(self, table_metadata: TableMetadata) -> None: + """Set the initial changes that can reconstruct the initial table metadata when creating the CreateTableTransaction.""" + self._updates += ( + AssignUUIDUpdate(uuid=table_metadata.table_uuid), + UpgradeFormatVersionUpdate(format_version=table_metadata.format_version), + ) + + schema: Schema = table_metadata.schema() + self._updates += ( + AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id, initial_change=True), + SetCurrentSchemaUpdate(schema_id=-1), + ) + + spec: PartitionSpec = table_metadata.spec() + if spec.is_unpartitioned(): + self._updates += (AddPartitionSpecUpdate(spec=UNPARTITIONED_PARTITION_SPEC, initial_change=True),) + else: + self._updates += (AddPartitionSpecUpdate(spec=spec, initial_change=True),) + self._updates += (SetDefaultSpecUpdate(spec_id=-1),) + + sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id) + if sort_order is None or sort_order.is_unsorted: + self._updates += (AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER, initial_change=True),) + else: + self._updates += (AddSortOrderUpdate(sort_order=sort_order, initial_change=True),) + self._updates += (SetDefaultSortOrderUpdate(sort_order_id=-1),) + + self._updates += ( + SetLocationUpdate(location=table_metadata.location), + SetPropertiesUpdate(updates=table_metadata.properties), + ) + + def __init__(self, table: StagedTable): + super().__init__(table, autocommit=False) + self._initial_changes(table.metadata) + + def commit_transaction(self) -> Table: + """Commit the changes to the catalog. + + In the case of a CreateTableTransaction, the only requirement is AssertCreate. + Returns: + The table with the updates applied. 
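
`_initial_changes` above relies on two conventions: `initial_change=True` tells the metadata builder to replace rather than append (so the staged schema, spec, and sort order are not duplicated when the updates are replayed), and an id of `-1` means "the one just added". From that sequence:

```python
# From _initial_changes: add the schema, then select it via the -1 sentinel.
self._updates += (
    AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id, initial_change=True),
    SetCurrentSchemaUpdate(schema_id=-1),  # -1 resolves to the schema added just above
)
```

Note that `initial_change` is excluded from serialization (`exclude=True`), so it never reaches a REST server; it only steers local metadata reconstruction.
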
+ """ + self._requirements = (AssertCreate(),) + return super().commit_transaction() + + +class AssignUUIDUpdate(IcebergBaseModel): + action: Literal['assign-uuid'] = Field(default="assign-uuid") + uuid: uuid.UUID + + class UpgradeFormatVersionUpdate(IcebergBaseModel): action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version") format_version: int = Field(alias="format-version") @@ -413,6 +466,8 @@ class AddSchemaUpdate(IcebergBaseModel): # This field is required: https://github.com/apache/iceberg/pull/7445 last_column_id: int = Field(alias="last-column-id") + initial_change: bool = Field(default=False, exclude=True) + class SetCurrentSchemaUpdate(IcebergBaseModel): action: Literal['set-current-schema'] = Field(default="set-current-schema") @@ -425,6 +480,8 @@ class AddPartitionSpecUpdate(IcebergBaseModel): action: Literal['add-spec'] = Field(default="add-spec") spec: PartitionSpec + initial_change: bool = Field(default=False, exclude=True) + class SetDefaultSpecUpdate(IcebergBaseModel): action: Literal['set-default-spec'] = Field(default="set-default-spec") @@ -437,6 +494,8 @@ class AddSortOrderUpdate(IcebergBaseModel): action: Literal['add-sort-order'] = Field(default="add-sort-order") sort_order: SortOrder = Field(alias="sort-order") + initial_change: bool = Field(default=False, exclude=True) + class SetDefaultSortOrderUpdate(IcebergBaseModel): action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order") @@ -491,6 +550,7 @@ class RemovePropertiesUpdate(IcebergBaseModel): TableUpdate = Annotated[ Union[ + AssignUUIDUpdate, UpgradeFormatVersionUpdate, AddSchemaUpdate, SetCurrentSchemaUpdate, @@ -527,6 +587,9 @@ def is_added_snapshot(self, snapshot_id: int) -> bool: def is_added_schema(self, schema_id: int) -> bool: return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate)) + def is_added_partition_spec(self, spec_id: int) -> bool: + return any(update.spec.spec_id == spec_id for update in self._updates if isinstance(update, AddPartitionSpecUpdate)) + def is_added_sort_order(self, sort_order_id: int) -> bool: return any( update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate) @@ -549,8 +612,27 @@ def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, conte raise NotImplementedError(f"Unsupported table update: {update}") +@_apply_table_update.register(AssignUUIDUpdate) +def _(update: AssignUUIDUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if update.uuid == base_metadata.table_uuid: + return base_metadata + + context.add_update(update) + return base_metadata.model_copy(update={"table_uuid": update.uuid}) + + +@_apply_table_update.register(SetLocationUpdate) +def _(update: SetLocationUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + context.add_update(update) + return base_metadata.model_copy(update={"location": update.location}) + + @_apply_table_update.register(UpgradeFormatVersionUpdate) -def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: +def _( + update: UpgradeFormatVersionUpdate, + base_metadata: TableMetadata, + context: _TableMetadataUpdateContext, +) -> TableMetadata: if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: raise ValueError(f"Unsupported table format version: {update.format_version}") elif update.format_version < 
base_metadata.format_version: @@ -595,13 +677,13 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta if update.last_column_id < base_metadata.last_column_id: raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") + metadata_updates: Dict[str, Any] = { + "last_column_id": update.last_column_id, + "schemas": [update.schema_] if update.initial_change else base_metadata.schemas + [update.schema_], + } + context.add_update(update) - return base_metadata.model_copy( - update={ - "last_column_id": update.last_column_id, - "schemas": base_metadata.schemas + [update.schema_], - } - ) + return base_metadata.model_copy(update=metadata_updates) @_apply_table_update.register(SetCurrentSchemaUpdate) @@ -627,18 +709,19 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta @_apply_table_update.register(AddPartitionSpecUpdate) def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: for spec in base_metadata.partition_specs: - if spec.spec_id == update.spec.spec_id: + if spec.spec_id == update.spec.spec_id and not update.initial_change: raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}") + + metadata_updates: Dict[str, Any] = { + "partition_specs": [update.spec] if update.initial_change else base_metadata.partition_specs + [update.spec], + "last_partition_id": max( + max([field.field_id for field in update.spec.fields], default=0), + base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1, + ), + } + context.add_update(update) - return base_metadata.model_copy( - update={ - "partition_specs": base_metadata.partition_specs + [update.spec], - "last_partition_id": max( - max(field.field_id for field in update.spec.fields), - base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1, - ), - } - ) + return base_metadata.model_copy(update=metadata_updates) @_apply_table_update.register(SetDefaultSpecUpdate) @@ -646,6 +729,8 @@ def _(update: SetDefaultSpecUpdate, base_metadata: TableMetadata, context: _Tabl new_spec_id = update.spec_id if new_spec_id == -1: new_spec_id = max(spec.spec_id for spec in base_metadata.partition_specs) + if not context.is_added_partition_spec(new_spec_id): + raise ValueError("Cannot set current partition spec to last added one when no partition spec has been added") if new_spec_id == base_metadata.default_spec_id: return base_metadata found_spec_id = False @@ -736,13 +821,17 @@ def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableM context.add_update(update) return base_metadata.model_copy( update={ - "sort_orders": base_metadata.sort_orders + [update.sort_order], + "sort_orders": [update.sort_order] if update.initial_change else base_metadata.sort_orders + [update.sort_order], } ) @_apply_table_update.register(SetDefaultSortOrderUpdate) -def _(update: SetDefaultSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: +def _( + update: SetDefaultSortOrderUpdate, + base_metadata: TableMetadata, + context: _TableMetadataUpdateContext, +) -> TableMetadata: new_sort_order_id = update.sort_order_id if new_sort_order_id == -1: # The last added sort order should be in base_metadata.sort_orders at this point @@ -761,12 +850,15 @@ def _(update: SetDefaultSortOrderUpdate, base_metadata: TableMetadata, context: return base_metadata.model_copy(update={"default_sort_order_id": new_sort_order_id}) -def 
update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: +def update_table_metadata( + base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...], enforce_validation: bool = False +) -> TableMetadata: """Update the table metadata with the given updates in one transaction. Args: base_metadata: The base metadata to be updated. updates: The updates in one transaction. + enforce_validation: Whether to trigger validation after applying the updates. Returns: The metadata with the updates applied. @@ -777,7 +869,10 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda for update in updates: new_metadata = _apply_table_update(update, new_metadata, context) - return new_metadata.model_copy(deep=True) + if enforce_validation: + return TableMetadataUtil.parse_obj(new_metadata.model_dump()) + else: + return new_metadata.model_copy(deep=True) class ValidatableTableRequirement(IcebergBaseModel): @@ -1287,6 +1382,25 @@ def from_metadata(cls, metadata_location: str, properties: Properties = EMPTY_DI ) +class StagedTable(Table): + def refresh(self) -> Table: + raise ValueError("Cannot refresh a staged table") + + def scan( + self, + row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + selected_fields: Tuple[str, ...] = ("*",), + case_sensitive: bool = True, + snapshot_id: Optional[int] = None, + options: Properties = EMPTY_DICT, + limit: Optional[int] = None, + ) -> DataScan: + raise ValueError("Cannot scan a staged table") + + def to_daft(self) -> daft.DataFrame: + raise ValueError("Cannot convert a staged table to a Daft DataFrame") + + def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression: """Accept an expression in the form of a BooleanExpression or a string. diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 3e1acf95f1..2e20c5092f 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -412,7 +412,7 @@ def to_v2(self) -> TableMetadataV2: """The table’s current schema. (Deprecated: use schemas and current-schema-id instead).""" - partition_spec: List[Dict[str, Any]] = Field(alias="partition-spec") + partition_spec: List[Dict[str, Any]] = Field(alias="partition-spec", default_factory=list) """The table’s current partition spec, stored as only fields. 
Note that this is used by writers to partition data, but is not used when reading because reads use the specs stored in diff --git a/tests/catalog/integration_test_glue.py b/tests/catalog/integration_test_glue.py index a685b7da7b..5cd60225c4 100644 --- a/tests/catalog/integration_test_glue.py +++ b/tests/catalog/integration_test_glue.py @@ -24,7 +24,7 @@ import pytest from botocore.exceptions import ClientError -from pyiceberg.catalog import Catalog +from pyiceberg.catalog import Catalog, MetastoreCatalog from pyiceberg.catalog.glue import GlueCatalog from pyiceberg.exceptions import ( NamespaceAlreadyExistsError, @@ -35,6 +35,7 @@ ) from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema +from pyiceberg.table import _dataframe_to_data_files from pyiceberg.types import IntegerType from tests.conftest import clean_up, get_bucket_name, get_s3_path @@ -120,7 +121,7 @@ def test_create_table( assert table.identifier == (CATALOG_NAME,) + identifier metadata_location = table.metadata_location.split(get_bucket_name())[1][1:] s3.head_object(Bucket=get_bucket_name(), Key=metadata_location) - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 table.append( pa.Table.from_pylist( @@ -184,7 +185,7 @@ def test_create_table_with_default_location( assert table.identifier == (CATALOG_NAME,) + identifier metadata_location = table.metadata_location.split(get_bucket_name())[1][1:] s3.head_object(Bucket=get_bucket_name(), Key=metadata_location) - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 def test_create_table_with_invalid_database(test_catalog: Catalog, table_schema_nested: Schema, table_name: str) -> None: @@ -217,7 +218,7 @@ def test_load_table(test_catalog: Catalog, table_schema_nested: Schema, table_na assert table.identifier == loaded_table.identifier assert table.metadata_location == loaded_table.metadata_location assert table.metadata == loaded_table.metadata - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 def test_list_tables(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_list: List[str]) -> None: @@ -239,7 +240,7 @@ def test_rename_table( new_table_name = f"rename-{table_name}" identifier = (database_name, table_name) table = test_catalog.create_table(identifier, table_schema_nested) - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 assert table.identifier == (CATALOG_NAME,) + identifier new_identifier = (new_database_name, new_table_name) test_catalog.rename_table(identifier, new_identifier) @@ -385,7 +386,7 @@ def test_commit_table_update_schema( table = test_catalog.create_table(identifier, table_schema_nested) original_table_metadata = table.metadata - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 assert original_table_metadata.current_schema_id == 0 assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [ @@ -410,7 +411,7 @@ def test_commit_table_update_schema( updated_table_metadata = table.metadata - assert test_catalog._parse_metadata_version(table.metadata_location) == 1 + 
assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 1 assert updated_table_metadata.current_schema_id == 1 assert len(updated_table_metadata.schemas) == 2 new_schema = next(schema for schema in updated_table_metadata.schemas if schema.schema_id == 1) @@ -466,7 +467,7 @@ def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Sch test_catalog.create_namespace(namespace=database_name) table = test_catalog.create_table(identifier=identifier, schema=table_schema_nested, properties={"test_a": "test_a"}) - assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 transaction = table.transaction() transaction.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") @@ -474,5 +475,78 @@ def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Sch transaction.commit_transaction() updated_table_metadata = table.metadata - assert test_catalog._parse_metadata_version(table.metadata_location) == 1 + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 1 assert updated_table_metadata.properties == {"test_a": "test_aa", "test_c": "test_c"} + + +@pytest.mark.parametrize("format_version", [1, 2]) +def test_create_table_transaction( + test_catalog: Catalog, + s3: boto3.client, + table_schema_nested: Schema, + table_name: str, + database_name: str, + athena: AthenaQueryHelper, + format_version: int, +) -> None: + identifier = (database_name, table_name) + test_catalog.create_namespace(database_name) + + with test_catalog.create_table_transaction( + identifier, + table_schema_nested, + get_s3_path(get_bucket_name(), database_name, table_name), + properties={"format-version": format_version}, + ) as txn: + df = pa.Table.from_pylist( + [ + { + "foo": "foo_val", + "bar": 1, + "baz": False, + "qux": ["x", "y"], + "quux": {"key": {"subkey": 2}}, + "location": [{"latitude": 1.1}], + "person": {"name": "some_name", "age": 23}, + } + ], + schema=schema_to_pyarrow(txn.table_metadata.schema()), + ) + + with txn.update_snapshot().fast_append() as update_snapshot: + data_files = _dataframe_to_data_files( + table_metadata=txn.table_metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=txn._table.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) + + table = test_catalog.load_table(identifier) + assert table.identifier == (CATALOG_NAME,) + identifier + metadata_location = table.metadata_location.split(get_bucket_name())[1][1:] + s3.head_object(Bucket=get_bucket_name(), Key=metadata_location) + assert MetastoreCatalog._parse_metadata_version(table.metadata_location) == 0 + + assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [ + { + "Data": [ + {"VarCharValue": "foo"}, + {"VarCharValue": "bar"}, + {"VarCharValue": "baz"}, + {"VarCharValue": "qux"}, + {"VarCharValue": "quux"}, + {"VarCharValue": "location"}, + {"VarCharValue": "person"}, + ] + }, + { + "Data": [ + {"VarCharValue": "foo_val"}, + {"VarCharValue": "1"}, + {"VarCharValue": "false"}, + {"VarCharValue": "[x, y]"}, + {"VarCharValue": "{key={subkey=2}}"}, + {"VarCharValue": "[{latitude=1.1, longitude=null}]"}, + {"VarCharValue": "{name=some_name, age=23}"}, + ] + }, + ] diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 5f78eb3bc4..8ea04e3fca 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -34,6 +34,7 @@ from pyiceberg.catalog import ( Catalog, + 
MetastoreCatalog, PropertiesUpdateSummary, ) from pyiceberg.exceptions import ( @@ -65,7 +66,7 @@ DEFAULT_WAREHOUSE_LOCATION = "file:///tmp/warehouse" -class InMemoryCatalog(Catalog): +class InMemoryCatalog(MetastoreCatalog): """ An in-memory catalog implementation that uses in-memory data-structures to store the namespaces and tables. diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index d4ed085c51..8aa4918636 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -33,7 +33,9 @@ TableAlreadyExistsError, ) from pyiceberg.io.pyarrow import schema_to_pyarrow +from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema +from pyiceberg.transforms import IdentityTransform from pyiceberg.types import IntegerType from tests.conftest import BUCKET_NAME, TABLE_METADATA_LOCATION_REGEX @@ -758,3 +760,49 @@ def test_commit_overwrite_table_snapshot_properties( assert summary is not None assert summary["snapshot_prop_a"] is None assert summary["snapshot_prop_b"] == "test_prop_b" + + +@mock_aws +@pytest.mark.parametrize("format_version", [1, 2]) +def test_create_table_transaction( + _glue: boto3.client, + _bucket_initialize: None, + moto_endpoint_url: str, + table_schema_nested: Schema, + database_name: str, + table_name: str, + format_version: int, +) -> None: + catalog_name = "glue" + identifier = (database_name, table_name) + test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url, "warehouse": f"s3://{BUCKET_NAME}"}) + test_catalog.create_namespace(namespace=database_name) + + with test_catalog.create_table_transaction( + identifier, + table_schema_nested, + partition_spec=PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="foo")), + properties={"format-version": format_version}, + ) as txn: + with txn.update_schema() as update_schema: + update_schema.add_column(path="b", field_type=IntegerType()) + + with txn.update_spec() as update_spec: + update_spec.add_identity("bar") + + txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") + + table = test_catalog.load_table(identifier) + + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert table.format_version == format_version + assert table.schema().find_field("b").field_type == IntegerType() + assert table.properties == {"test_a": "test_aa", "test_b": "test_b", "test_c": "test_c"} + assert table.spec().last_assigned_field_id == 1001 + assert table.spec().fields_by_source_id(1)[0].name == "foo" + assert table.spec().fields_by_source_id(1)[0].field_id == 1000 + assert table.spec().fields_by_source_id(1)[0].transform == IdentityTransform() + assert table.spec().fields_by_source_id(2)[0].name == "bar" + assert table.spec().fields_by_source_id(2)[0].field_id == 1001 + assert table.spec().fields_by_source_id(2)[0].transform == IdentityTransform() diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index 0186e662dc..e8ad6b08fa 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -680,6 +680,60 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None snapshot_update.append_data_file(data_file) +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [2]) +def test_create_table_transaction(session_catalog: Catalog, format_version: int) -> None: + if format_version == 1: + pytest.skip( + "There is a bug 
in the REST catalog (maybe server side) that prevents create and commit a staged version 1 table" + ) + + identifier = f"default.arrow_create_table_transaction{format_version}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + pa_table = pa.Table.from_pydict( + { + 'foo': ['a', None, 'z'], + }, + schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]), + ) + + pa_table_with_column = pa.Table.from_pydict( + { + 'foo': ['a', None, 'z'], + 'bar': [19, None, 25], + }, + schema=pa.schema([ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ]), + ) + + with session_catalog.create_table_transaction( + identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)} + ) as txn: + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io): + snapshot_update.append_data_file(data_file) + + with txn.update_schema() as schema_txn: + schema_txn.union_by_name(pa_table_with_column.schema) + + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io + ): + snapshot_update.append_data_file(data_file) + + tbl = session_catalog.load_table(identifier=identifier) + assert tbl.format_version == format_version + assert len(tbl.scan().to_arrow()) == 6 + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_table_properties_int_value(