Commit: finalize import/export and add tests

mhmdk0 committed Dec 8, 2024
1 parent 846ba7a commit 7465533
Showing 6 changed files with 403 additions and 73 deletions.
6 changes: 4 additions & 2 deletions cli/medperf/commands/dataset/dataset.py
@@ -10,7 +10,9 @@
 from medperf.commands.dataset.prepare import DataPreparation
 from medperf.commands.dataset.set_operational import DatasetSetOperational
 from medperf.commands.dataset.associate import AssociateDataset
-from medperf.commands.dataset.manage import ImportDataset, ExportDataset
+from medperf.commands.dataset.import_dataset import ImportDataset
+from medperf.commands.dataset.export_dataset import ExportDataset


 app = typer.Typer()

@@ -186,7 +188,7 @@ def import_dataset(
help="Path of the tar.gz file (dataset backup) to be imported.",
),
raw_path: str = typer.Option(
"Folder containing the tar.gz file",
None,
"--raw_dataset_path",
help="New path of the DEVELOPMENT dataset raw data to be saved.",
),
Expand Down
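This hunk replaces a sentinel default string with None, so an omitted --raw_dataset_path flag can be detected explicitly instead of by string comparison. A minimal standalone sketch of the resulting pattern; the command body below is illustrative and not part of this diff:

import typer

app = typer.Typer()


@app.command()
def import_dataset(
    raw_path: str = typer.Option(
        None,
        "--raw_dataset_path",
        help="New path of the DEVELOPMENT dataset raw data to be saved.",
    ),
):
    # None now signals "flag not provided"; the real check lives in
    # ImportDataset.validate_input (changed later in this commit).
    if raw_path is None:
        typer.echo("No --raw_dataset_path given; required for DEVELOPMENT datasets.")


if __name__ == "__main__":
    app()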
60 changes: 60 additions & 0 deletions cli/medperf/commands/dataset/export_dataset.py
@@ -0,0 +1,60 @@
import os
from medperf.entities.dataset import Dataset
from medperf.utils import tar
import medperf.config as config
from medperf.exceptions import ExecutionError
import yaml


class ExportDataset:
    @classmethod
    def run(cls, dataset_id: str, output_path: str):
        export_dataset = cls(dataset_id, output_path)
        export_dataset.prepare()
        export_dataset.create_tar()

    def __init__(self, dataset_id: str, output_path: str):
        self.dataset_id = dataset_id
        self.output_path = os.path.join(output_path, dataset_id) + ".gz"
        self.folders_paths = []
        self.dataset = Dataset.get(self.dataset_id)
        self.dataset_storage = self.dataset.get_storage_path()
        self.dataset_tar_folder = (
            config.dataset_backup_foldername + self.dataset_id
        )  # name of the folder that will contain the backup

    def _prepare_development_dataset(self):
        raw_data_paths = self.dataset.get_raw_paths()

        for folder in raw_data_paths:
            # check that each raw data path exists and is not empty
            if not (os.path.exists(folder) and os.listdir(folder)):
                raise ExecutionError(f"Cannot find raw data paths at '{folder}'")
            self.folders_paths.append(folder)

        data_path, labels_path = raw_data_paths
        self.paths["data"] = os.path.basename(data_path)
        self.paths["labels"] = os.path.basename(labels_path)

    def prepare(self):
        # Record the server name in paths.yaml so imports can compare the local
        # and remote servers; paths.yaml also stores the folder names (what each
        # one points to).
        dataset_path = os.path.join(self.dataset_storage, self.dataset_id)
        self.folders_paths.append(dataset_path)
        self.paths = {"server": config.server, "dataset": self.dataset_id}

        # If the dataset is in development, it'll need the raw paths as well.
        if self.dataset.state == "DEVELOPMENT":
            self._prepare_development_dataset()

        paths_path = os.path.join(config.tmp_folder, config.backup_config_filename)

        # paths.yaml is created in the medperf tmp directory
        with open(paths_path, "w") as f:
            yaml.dump(self.paths, f)

        self.folders_paths.append(paths_path)
        config.tmp_paths.append(paths_path)

    def create_tar(self):
        tar(self.output_path, self.folders_paths, self.dataset_tar_folder)
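For reference, a minimal sketch of the backup config this class writes to paths.yaml. Only the keys come from the code above; every value below is illustrative (config.server is assumed to be a server URL string):

import yaml

# Keys mirror self.paths as built in prepare() and _prepare_development_dataset().
paths = {
    "server": "https://api.medperf.org",  # illustrative stand-in for config.server
    "dataset": "42",                      # the dataset ID
    "data": "raw_data",                   # basename of the raw data folder
    "labels": "raw_labels",               # basename of the raw labels folder
}
print(yaml.dump(paths))
# data: raw_data
# dataset: '42'
# labels: raw_labels
# server: https://api.medperf.org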
cli/medperf/commands/dataset/import_dataset.py
@@ -2,7 +2,6 @@
 from medperf.entities.dataset import Dataset
 from medperf.utils import (
     untar,
-    tar,
     move_folder,
     copy_file,
     remove_path,
@@ -15,7 +14,7 @@

 class ImportDataset:
     @classmethod
-    def run(cls, dataset_id: str, input_path: str, raw_data_path: str):
+    def run(cls, dataset_id: str, input_path: str, raw_data_path):
         import_dataset = cls(dataset_id, input_path, raw_data_path)
         import_dataset.validate_input()
         import_dataset.untar_files()
@@ -24,17 +23,13 @@ def run(cls, dataset_id: str, input_path: str, raw_data_path: str):
         import_dataset.prepare_tarfiles()
         import_dataset.process_tarfiles()

-    def __init__(self, dataset_id: str, input_path: str, raw_data_path: str):
+    def __init__(self, dataset_id: str, input_path: str, raw_data_path):
         self.dataset_id = dataset_id
         self.input_path = input_path
         self.dataset = Dataset.get(self.dataset_id)
         self.dataset_storage = self.dataset.get_storage_path()
         self.dataset_path = os.path.join(self.dataset_storage, self.dataset_id)
-        if self.dataset.state == "DEVELOPMENT":
-            if raw_data_path == "Folder containing the tar.gz file":
-                self.raw_data_path = os.path.dirname(input_path)
-            else:
-                self.raw_data_path = raw_data_path
+        self.raw_data_path = raw_data_path

     def prepare(self):
         if self.dataset.state == "DEVELOPMENT":
@@ -49,7 +44,9 @@ def validate_input(self):
         if not os.path.isfile(self.input_path):
             raise InvalidArgumentError(f"{self.input_path} is not a file.")
         if self.dataset.state == "DEVELOPMENT" and (
-            not os.path.exists(self.raw_data_path) or os.path.isfile(self.raw_data_path)
+            self.raw_data_path is None
+            or not os.path.exists(self.raw_data_path)
+            or os.path.isfile(self.raw_data_path)
         ):
             raise InvalidArgumentError(f"Folder {self.raw_data_path} doesn't exist.")

@@ -82,22 +79,22 @@ def validate(self):

         # Check that the yaml config file exists
         if not os.path.exists(backup_config):
-            raise ExecutionError("Dataset backup is invalid")
+            raise ExecutionError("Dataset backup is invalid: config file doesn't exist")

         self.tarfiles = [os.path.join(self.tarfiles, file) for file in tarfiles_names]
         with open(backup_config) as f:
             self.paths = yaml.safe_load(f)

         print(self.paths)
         # Check that the yaml file paths are valid
         if self.paths["dataset"] not in tarfiles_names:
-            raise ExecutionError("Dataset backup is invalid")
+            raise ExecutionError("Dataset backup is invalid: dataset folders not found")

         # Check that the yaml file paths are valid for development datasets
         if self.dataset.state == "DEVELOPMENT" and (
             self.paths["data"] not in tarfiles_names
             or self.paths["labels"] not in tarfiles_names
         ):
-            raise ExecutionError("Dataset backup is invalid")
+            raise ExecutionError("Dataset backup is invalid: raw data or labels folder not found")

         self._validate_dataset()

@@ -140,7 +137,6 @@ def process_tarfiles(self):
                 move_folder(folder, self.raw_data_path)
             elif os.path.basename(folder) == self.paths["labels"]:
                 move_folder(folder, self.raw_data_path)
-        # move_folder(os.path.dirname(self.tarfiles[0]), self.raw_data_path)
         raw_data_path = os.path.join(self.raw_data_path, self.paths["data"])
         raw_labels_path = os.path.join(self.raw_data_path, self.paths["labels"])
         self.dataset.set_raw_paths(raw_data_path, raw_labels_path)
@@ -155,57 +151,3 @@ def untar_files(self):
             self.tarfiles, config.dataset_backup_foldername + self.dataset_id
         )
         config.tmp_paths.append(self.tarfiles)
-
-
-class ExportDataset:
-    @classmethod
-    def run(cls, dataset_id: str, output_path: str):
-        export_dataset = cls(dataset_id, output_path)
-        export_dataset.prepare()
-        export_dataset.create_tar()
-
-    def __init__(self, dataset_id: str, output_path: str):
-        self.dataset_id = dataset_id
-        self.output_path = os.path.join(output_path, dataset_id) + ".gz"
-        self.folders_paths = []
-        self.dataset = Dataset.get(self.dataset_id)
-        self.dataset_storage = self.dataset.get_storage_path()
-        self.dataset_tar_folder = (
-            config.dataset_backup_foldername + self.dataset_id
-        )  # name of the folder that will contain the backup
-
-    def _prepare_development_dataset(self):
-        raw_data_paths = self.dataset.get_raw_paths()
-        if not raw_data_paths:
-            raise ExecutionError("Cannot find raw data paths")
-        for folder in raw_data_paths:
-            # checks if raw_data_path exists and not empty
-            if not (os.path.exists(folder) and os.listdir(folder)):
-                raise ExecutionError(f"Cannot find raw data paths at '{folder}'")
-            self.folders_paths.append(folder)
-        data_path, labels_path = raw_data_paths
-        self.paths["data"] = os.path.basename(data_path)
-        self.paths["labels"] = os.path.basename(labels_path)
-
-    def prepare(self):
-        # Gets server name to be added in paths.yaml for comparing between local and remote servers
-        # which will save folders names (what each one points to.
-        dataset_path = os.path.join(self.dataset_storage, self.dataset_id)
-        self.folders_paths.append(dataset_path)
-        self.paths = {"server": config.server, "dataset": self.dataset_id}
-
-        # If the dataset is in development, it'll need the raw paths as well.
-        if self.dataset.state == "DEVELOPMENT":
-            self._prepare_development_dataset()
-
-        paths_path = os.path.join(config.tmp_folder, config.backup_config_filename)
-
-        # paths.yaml will be created in medperf tmp directory
-        with open(paths_path, "w") as f:
-            yaml.dump(self.paths, f)
-
-        self.folders_paths.append(paths_path)
-        config.tmp_paths.append(paths_path)
-
-    def create_tar(self):
-        tar(self.output_path, self.folders_paths, self.dataset_tar_folder)
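Taken together, a minimal round-trip usage sketch. The run() signatures match the classes in this commit; the paths and the dataset ID below are illustrative:

from medperf.commands.dataset.export_dataset import ExportDataset
from medperf.commands.dataset.import_dataset import ImportDataset

# Back up dataset 42 to /backups/42.gz (ExportDataset appends ".gz" itself).
ExportDataset.run("42", "/backups")

# Restore it elsewhere; the raw data path is only required when the
# dataset is in the DEVELOPMENT state.
ImportDataset.run("42", "/backups/42.gz", "/data/raw")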
97 changes: 97 additions & 0 deletions cli/medperf/tests/commands/dataset/test_export_dataset.py
@@ -0,0 +1,97 @@
import os
from medperf.exceptions import ExecutionError
import pytest

from medperf.tests.mocks.dataset import TestDataset
from medperf.commands.dataset.export_dataset import ExportDataset


PATCH_EXPORT = "medperf.commands.dataset.export_dataset.{}"


@pytest.fixture
def dataset(mocker):
    dset = TestDataset(id=None, state="DEVELOPMENT")
    return dset


@pytest.fixture
def export_dataset(mocker, dataset):
    mocker.patch(PATCH_EXPORT.format("Dataset.get"), return_value=dataset)
    dataclass = ExportDataset("", "")
    return dataclass


def test_export_fail_if_development_dataset_raw_paths_does_not_exist(
    mocker, export_dataset
):
    # Arrange
    mocker.patch(
        PATCH_EXPORT.format("Dataset.get_raw_paths"), return_value=["test", "test1"]
    )

    # Act & Assert
    with pytest.raises(ExecutionError):
        export_dataset.prepare()


def test_export_fail_if_development_dataset_raw_paths_are_empty(mocker, export_dataset):
    # Arrange
    mocker.patch(
        PATCH_EXPORT.format("Dataset.get_raw_paths"), return_value=["/test", "/test1"]
    )
    os.makedirs("/test")
    os.makedirs("/test1")

    # Act & Assert
    with pytest.raises(ExecutionError):
        export_dataset.prepare()


def test_export_if_development_dataset_length_of_yaml_paths_keys_equal_4(
    mocker, export_dataset, fs
):
    # Arrange
    mocker.patch(
        PATCH_EXPORT.format("Dataset.get_raw_paths"), return_value=["/test", "/test1"]
    )
    os.makedirs("/test")
    os.makedirs("/test1")
    fs.create_file("/test/testfile")
    fs.create_file("/test1/testfile")

    # Act
    export_dataset.prepare()

    # Assert
    assert len(export_dataset.paths.keys()) == 4


def test_export_if_operation_dataset_length_of_yaml_paths_keys_equal_2(export_dataset):
    # Arrange
    export_dataset.dataset.state = "OPERATION"

    # Act
    export_dataset.prepare()

    # Assert
    assert len(export_dataset.paths.keys()) == 2


def test_export_if_tar_gz_file_is_created_at_output_path(export_dataset):
    # Arrange
    export_dataset.dataset.state = "OPERATION"
    export_dataset.dataset_id = "1"
    export_dataset.output_path = f"/test/{export_dataset.dataset_id}.gz"
    os.makedirs("/test/")

    # Act
    export_dataset.create_tar()

    # Assert
    assert os.path.exists(export_dataset.output_path)
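A note on the filesystem setup: these tests create absolute paths such as /test directly, which is only safe on a fake filesystem; the fs fixture used above is pyfakefs's, and the tests that omit it presumably rely on pyfakefs being enabled project-wide (an assumption, e.g. via a conftest). A minimal sketch of that mechanism, assuming pyfakefs is installed:

import os


def test_fake_filesystem_isolates_paths(fs):
    # `fs` swaps in an in-memory filesystem, so writing to "/test" never
    # touches the real disk.
    fs.create_file("/test/testfile")
    assert os.path.exists("/test/testfile")
    assert os.listdir("/test") == ["testfile"]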