diff --git a/ldb/dataset.py b/ldb/dataset.py index 5b966712..712ba7e9 100644 --- a/ldb/dataset.py +++ b/ldb/dataset.py @@ -314,18 +314,25 @@ def check_datasets_for_data_objects( error: bool = True, ) -> Iterator[str]: # collect data objects from datasets - ds_version_idents = get_all_dataset_version_identifiers(ldb_dir) - objects_in_datasets = { - d - for version_id in ds_version_idents.keys() - for d in get_collection(ldb_dir, version_id) - } + collection_identifiers = get_all_dataset_version_identifiers(ldb_dir) + obj_to_identifiers = defaultdict(list) + for collection_id, ds_identifiers in collection_identifiers.items(): + if ds_identifiers: + for data_obj_hash in load_data_file( + get_hash_path( + ldb_dir / InstanceDir.COLLECTIONS, + collection_id, + ), + ): + obj_to_identifiers[data_obj_hash].extend(ds_identifiers) for data_obj_hash in data_object_hashes: - if data_obj_hash in objects_in_datasets: + obj_ds_identifiers = obj_to_identifiers.get(data_obj_hash) + if obj_ds_identifiers is not None: if error: + ds_ident_str = "\n".join(f" {i}" for i in obj_ds_identifiers) raise LDBException( - f"Data object id:{data_obj_hash} is contained in a saved " - "dataset", + f"Data object id:{data_obj_hash} is contained in saved " + f"datasets:\n{ds_ident_str}", ) else: yield data_obj_hash @@ -342,15 +349,21 @@ def get_all_dataset_version_identifiers(ldb_dir: Path) -> Dict[str, List[str]]: "ab3...": [], } """ + ds_version_dir = os.path.join(ldb_dir, InstanceDir.DATASET_VERSIONS) + collection_dir = os.path.join(ldb_dir, InstanceDir.COLLECTIONS) result = defaultdict(list) for dataset in iter_datasets(ldb_dir): for i, version_id in dataset.numbered_versions().items(): - result[version_id].append( + version_path = get_hash_path( + Path(ds_version_dir), + version_id, + ) + version_obj = DatasetVersion.parse(load_data_file(version_path)) + result[version_obj.collection].append( format_dataset_identifier(dataset.name, i), ) - ds_version_dir = str(ldb_dir / InstanceDir.DATASET_VERSIONS) - for parent in os.listdir(ds_version_dir): - for filename in os.listdir(os.path.join(ds_version_dir, parent)): + for parent in os.listdir(collection_dir): + for filename in os.listdir(os.path.join(collection_dir, parent)): result[ # pylint: disable=pointless-statement f"{parent}{filename}" ] diff --git a/ldb/ds.py b/ldb/ds.py index 7b2e97d7..65d67c3e 100644 --- a/ldb/ds.py +++ b/ldb/ds.py @@ -3,10 +3,20 @@ from pathlib import Path from typing import Iterable, Iterator -from ldb.dataset import iter_datasets +from ldb.dataset import ( + Dataset, + DatasetVersion, + get_all_dataset_version_identifiers, + iter_datasets, +) from ldb.exceptions import DatasetNotFoundError from ldb.path import InstanceDir -from ldb.utils import format_dataset_identifier, parse_dataset_identifier +from ldb.utils import ( + format_dataset_identifier, + get_hash_path, + load_data_file, + parse_dataset_identifier, +) @dataclass @@ -46,6 +56,8 @@ def print_ds_listings( def delete_datasets(ldb_dir: Path, ds_identifiers: Iterable[str]) -> None: ds_dir = os.path.join(ldb_dir, InstanceDir.DATASETS) + ds_version_dir = os.path.join(ldb_dir, InstanceDir.DATASET_VERSIONS) + collection_dir = os.path.join(ldb_dir, InstanceDir.COLLECTIONS) ds_info = [] for ds_ident in ds_identifiers: name, version = parse_dataset_identifier(ds_ident) @@ -63,11 +75,36 @@ def delete_datasets(ldb_dir: Path, ds_identifiers: Iterable[str]) -> None: f"Dataset not found: {ds_ident}", ) + collection_identifiers = get_all_dataset_version_identifiers(ldb_dir) for ds_ident, path in ds_info: try: - os.unlink(path) + dataset = Dataset.parse(load_data_file(Path(path))) except FileNotFoundError as exc: raise DatasetNotFoundError( f"Dataset not found: {ds_ident}", ) from exc + for i, ds_version in dataset.numbered_versions().items(): + version_path = get_hash_path( + Path(ds_version_dir), + ds_version, + ) + version_obj = DatasetVersion.parse(load_data_file(version_path)) + ds_version_ident = format_dataset_identifier(dataset.name, i) + refs = collection_identifiers[version_obj.collection] + refs.remove(ds_version_ident) + if not refs: + collection_path = str( + get_hash_path( + Path(collection_dir), + version_obj.collection, + ), + ) + os.unlink(collection_path) + try: + os.rmdir( + os.path.split(collection_path.rstrip(os.path.sep))[0], + ) + except OSError: + pass + os.unlink(path) print(f"Deleted {ds_ident}")