Skip to content

Commit

Permalink
Implement create_table_if_not_exists (#415)
Browse files Browse the repository at this point in the history
* Feat: Add fail_if_exists param to create_table

* create create_table_if_not_exists method

* fix reset test

* fix mypy check
  • Loading branch information
hussein-awala authored Feb 20, 2024
1 parent fd9dc88 commit 9b01248
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 2 deletions.
30 changes: 29 additions & 1 deletion pyiceberg/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
cast,
)

from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchTableError, NotInstalledError
from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchTableError, NotInstalledError, TableAlreadyExistsError
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.manifest import ManifestFile
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
Expand Down Expand Up @@ -315,6 +315,34 @@ def create_table(
TableAlreadyExistsError: If a table with the name already exists.
"""

def create_table_if_not_exists(
self,
identifier: Union[str, Identifier],
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
) -> Table:
"""Create a table if it does not exist.
Args:
identifier (str | Identifier): Table identifier.
schema (Schema): Table's schema.
location (str | None): Location for the table. Optional Argument.
partition_spec (PartitionSpec): PartitionSpec for the table.
sort_order (SortOrder): SortOrder for the table.
properties (Properties): Table properties that can be a string based dictionary.
Returns:
Table: the created table instance if the table does not exist, else the existing
table instance.
"""
try:
return self.create_table(identifier, schema, location, partition_spec, sort_order, properties)
except TableAlreadyExistsError:
return self.load_table(identifier)

@abstractmethod
def load_table(self, identifier: Union[str, Identifier]) -> Table:
"""Load the table's metadata and returns the table instance.
Expand Down
9 changes: 9 additions & 0 deletions tests/catalog/integration_test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def test_create_duplicated_table(test_catalog: Catalog, table_schema_nested: Sch
test_catalog.create_table((database_name, table_name), table_schema_nested)


def test_create_table_if_not_exists_duplicated_table(
test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
) -> None:
test_catalog.create_namespace(database_name)
table1 = test_catalog.create_table((database_name, table_name), table_schema_nested)
table2 = test_catalog.create_table_if_not_exists((database_name, table_name), table_schema_nested)
assert table1.identifier == table2.identifier


def test_load_table(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
Expand Down
9 changes: 9 additions & 0 deletions tests/catalog/integration_test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,15 @@ def test_create_duplicated_table(test_catalog: Catalog, table_schema_nested: Sch
test_catalog.create_table((database_name, table_name), table_schema_nested)


def test_create_table_if_not_exists_duplicated_table(
test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str
) -> None:
test_catalog.create_namespace(database_name)
table1 = test_catalog.create_table((database_name, table_name), table_schema_nested)
table2 = test_catalog.create_table_if_not_exists((database_name, table_name), table_schema_nested)
assert table1.identifier == table2.identifier


def test_load_table(test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
Expand Down
12 changes: 12 additions & 0 deletions tests/catalog/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ def test_create_duplicated_table(
test_catalog.create_table(identifier, table_schema_nested)


@mock_aws
def test_create_table_if_not_exists_duplicated_table(
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
) -> None:
identifier = (database_name, table_name)
test_catalog = DynamoDbCatalog("test_ddb_catalog", **{"warehouse": f"s3://{BUCKET_NAME}", "s3.endpoint": moto_endpoint_url})
test_catalog.create_namespace(namespace=database_name)
table1 = test_catalog.create_table(identifier, table_schema_nested)
table2 = test_catalog.create_table_if_not_exists(identifier, table_schema_nested)
assert table1.identifier == table2.identifier


@mock_aws
def test_load_table(
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
Expand Down
60 changes: 59 additions & 1 deletion tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name,unused-argument
import os
from typing import Any, Dict, cast
from typing import Any, Callable, Dict, cast
from unittest import mock

import pytest
Expand Down Expand Up @@ -560,6 +560,64 @@ def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> Non
assert "Table already exists" in str(e.value)


def test_create_table_if_not_exists_200(
rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
) -> None:
def json_callback() -> Callable[[Any, Any], Dict[str, Any]]:
call_count = 0

def callback(request: Any, context: Any) -> Dict[str, Any]:
nonlocal call_count
call_count += 1

if call_count == 1:
context.status_code = 200
return example_table_metadata_no_snapshot_v1_rest_json
else:
context.status_code = 409
return {
"error": {
"message": "Table already exists: fokko.already_exists in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e",
"type": "AlreadyExistsException",
"code": 409,
}
}

return callback

rest_mock.post(
f"{TEST_URI}v1/namespaces/fokko/tables",
json=json_callback(),
request_headers=TEST_HEADERS,
)
rest_mock.get(
f"{TEST_URI}v1/namespaces/fokko/tables/fokko2",
json=example_table_metadata_no_snapshot_v1_rest_json,
status_code=200,
request_headers=TEST_HEADERS,
)
catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN)
table1 = catalog.create_table(
identifier=("fokko", "fokko2"),
schema=table_schema_simple,
location=None,
partition_spec=PartitionSpec(
PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id"), spec_id=1
),
sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())),
properties={"owner": "fokko"},
)
table2 = catalog.create_table_if_not_exists(
identifier=("fokko", "fokko2"),
schema=table_schema_simple,
location=None,
partition_spec=PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id")),
sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())),
properties={"owner": "fokko"},
)
assert table1 == table2


def test_create_table_419(rest_mock: Mocker, table_schema_simple: Schema) -> None:
rest_mock.post(
f"{TEST_URI}v1/namespaces/fokko/tables",
Expand Down
17 changes: 17 additions & 0 deletions tests/catalog/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,23 @@ def test_create_duplicated_table(catalog: SqlCatalog, table_schema_nested: Schem
catalog.create_table(random_identifier, table_schema_nested)


@pytest.mark.parametrize(
'catalog',
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
],
)
def test_create_table_if_not_exists_duplicated_table(
catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier
) -> None:
database_name, _table_name = random_identifier
catalog.create_namespace(database_name)
table1 = catalog.create_table(random_identifier, table_schema_nested)
table2 = catalog.create_table_if_not_exists(random_identifier, table_schema_nested)
assert table1.identifier == table2.identifier


@pytest.mark.parametrize(
'catalog',
[
Expand Down

0 comments on commit 9b01248

Please sign in to comment.