diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 8ff1fc6c2d..4097b58dce 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -29,7 +29,11 @@ ) from daft.datatype import DataType from daft.expressions import ExpressionsProjection -from daft.filesystem import _resolve_paths_and_filesystem +from daft.filesystem import ( + _resolve_paths_and_filesystem, + canonicalize_protocol, + get_protocol_from_path, +) from daft.logical.schema import Schema from daft.runners.partitioning import ( TableParseCSVOptions, @@ -359,6 +363,15 @@ def write_tabular( from daft.utils import ARROW_VERSION [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config) + if isinstance(path, pathlib.Path): + path_str = str(path) + else: + path_str = path + + protocol = get_protocol_from_path(path_str) + canonicalized_protocol = canonicalize_protocol(protocol) + + is_local_fs = canonicalized_protocol == "file" tables_to_write: list[MicroPartition] part_keys_postfix_per_table: list[str | None] @@ -413,9 +426,6 @@ def write_tabular( if pf is not None and len(pf) > 0: full_path = f"{full_path}/{pf}" - # TODO: For overwriting behavior, check here if dir exists to determine to delete files - # fs.create_dir(full_path) - arrow_table = tab.to_arrow() size_bytes = arrow_table.nbytes @@ -439,6 +449,9 @@ def file_visitor(written_file, i=i): kwargs["min_rows_per_group"] = rows_per_row_group kwargs["max_rows_per_group"] = rows_per_row_group + if ARROW_VERSION >= (8, 0, 0) and not is_local_fs: + kwargs["create_dir"] = False + pads.write_dataset( arrow_table, base_dir=full_path,