diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py
index 9256f35f2..2d2ccb9cf 100644
--- a/cli/medperf/commands/mlcube/mlcube.py
+++ b/cli/medperf/commands/mlcube/mlcube.py
@@ -87,6 +87,13 @@ def submit(
         help="Identifier to download the image file. See the description above",
     ),
     image_hash: str = typer.Option("", "--image-hash", help="hash of image file"),
+    stages_file: str = typer.Option(
+        "",
+        "--stages-file",
+        "-s",
+        help="Identifier to download the stages file. See the description above"
+    ),
+    stages_hash: str = typer.Option("", "--stages-hash", help="Hash of the stages file"),
     operational: bool = typer.Option(
         False,
         "--operational",
@@ -99,6 +106,7 @@ def submit(
         - parameters_file\n
         - additional_file\n
         - image_file\n
+        - stages_file\n
     are expected to be given in the following format:
     where `source_prefix` instructs the client how to download the resource, and `resource_identifier`
     is the identifier used to download the asset. The following are supported:\n
@@ -117,6 +125,8 @@ def submit(
         "image_tarball_hash": image_hash,
         "additional_files_tarball_url": additional_file,
         "additional_files_tarball_hash": additional_hash,
+        "stages_url": stages_file,
+        "stages_hash": stages_hash,
         "state": "OPERATION" if operational else "DEVELOPMENT",
     }
     SubmitCube.run(mlcube_info)
diff --git a/cli/medperf/comms/entity_resources/resources.py b/cli/medperf/comms/entity_resources/resources.py
index 09dc7c0b8..b59324e3a 100644
--- a/cli/medperf/comms/entity_resources/resources.py
+++ b/cli/medperf/comms/entity_resources/resources.py
@@ -95,6 +95,12 @@ def get_cube_params(url: str, cube_path: str, expected_hash: str = None):
     return _get_regular_file(url, output_path, expected_hash)
 
 
+def get_cube_stages(url: str, cube_path: str, expected_hash: str = None):
+    """Downloads and writes a cube stages.yaml file"""
+    output_path = os.path.join(cube_path, config.stages_filename)
+    return _get_regular_file(url, output_path, expected_hash)
+
+
 def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str:
     """Retrieves and stores the image file from the server. Stores images
     on a shared location, and retrieves a cached image by hash if found locally.
diff --git a/cli/medperf/config.py b/cli/medperf/config.py
index 2c5b520be..21242a304 100644
--- a/cli/medperf/config.py
+++ b/cli/medperf/config.py
@@ -150,6 +150,7 @@
 # MLCube assets conventions
 cube_filename = "mlcube.yaml"
 params_filename = "parameters.yaml"
+stages_filename = "stages.yaml"
 workspace_path = "workspace"
 additional_path = "workspace/additional_files"
 image_path = "workspace/.image"
diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py
index 714342c53..edfb9b1f2 100644
--- a/cli/medperf/entities/cube.py
+++ b/cli/medperf/entities/cube.py
@@ -39,6 +39,8 @@ class Cube(Entity, DeployableSchema):
     image_hash: Optional[str]
     additional_files_tarball_url: Optional[str] = Field(None, alias="tarball_url")
     additional_files_tarball_hash: Optional[str] = Field(None, alias="tarball_hash")
+    stages_url: Optional[str]
+    stages_hash: Optional[str]
     metadata: dict = {}
     user_metadata: dict = {}
 
@@ -74,6 +76,9 @@ def __init__(self, *args, **kwargs):
         self.params_path = None
         if self.git_parameters_url:
             self.params_path = os.path.join(self.path, config.params_filename)
+        self.stages_path = None
+        if self.stages_url:
+            self.stages_path = os.path.join(self.path, config.stages_filename)
 
     @property
     def local_id(self):
@@ -136,6 +141,15 @@ def download_additional(self):
             )
             self.additional_files_tarball_hash = file_hash
 
+    def download_stages(self):
+        url = self.stages_url
+        if url:
+            path, file_hash = resources.get_cube_stages(
+                url, self.path, self.stages_hash
+            )
+            self.stages_path = path
+            self.stages_hash = file_hash
+
     def download_image(self):
         url = self.image_tarball_url
         tarball_hash = self.image_tarball_hash
@@ -216,6 +230,11 @@ def download_config_files(self):
         except InvalidEntityError as e:
             raise InvalidEntityError(f"MLCube {self.name} parameters file: {e}")
 
+        try:
+            self.download_stages()
+        except InvalidEntityError as e:
+            raise InvalidEntityError(f"MLCube {self.name} stages file: {e}")
+
     def download_run_files(self):
         try:
             self.download_additional()
diff --git a/cli/medperf/tests/comms/entity_resources/test_resources.py b/cli/medperf/tests/comms/entity_resources/test_resources.py
index 437e518be..609d05d5a 100644
--- a/cli/medperf/tests/comms/entity_resources/test_resources.py
+++ b/cli/medperf/tests/comms/entity_resources/test_resources.py
@@ -120,61 +120,36 @@ def test_get_additional_files_will_download_if_folder_exists_and_hash_valid_but_
     assert spy.call_count == 2
 
 
-class TestGetCube:
-    def test_get_cube_does_not_download_if_folder_exists_and_hash_valid(
-        self, mocker, fs
+@pytest.mark.parametrize("method", [
+    resources.get_cube,
+    resources.get_cube_params,
+    resources.get_cube_stages,
+])
+class TestGet:
+    def test_get_does_not_download_if_folder_exists_and_hash_valid(
+        self, mocker, fs, method
     ):
         # Arrange
         cube_path = "cube/1"
         spy = mocker.spy(resources, "download_resource")
-        _, exp_hash = resources.get_cube(url, cube_path)
+        _, exp_hash = method(url, cube_path)
 
         # Act
-        resources.get_cube(url, cube_path, exp_hash)
+        method(url, cube_path, exp_hash)
 
         # Assert
         spy.assert_called_once()  # second time shouldn't download
 
-    def test_get_cube_does_will_download_if_folder_exists_and_hash_outdated(
-        self, mocker, fs
-    ):
-        # Arrange
-        cube_path = "cube/1"
-        spy = mocker.spy(resources, "download_resource")
-        resources.get_cube(url, cube_path)
-
-        # Act
-        resources.get_cube(url, cube_path, "incorrect hash")
-
-        # Assert
-        assert spy.call_count == 2
-
-
-class TestGetCubeParams:
-    def test_get_cube_params_does_not_download_if_folder_exists_and_hash_valid(
-        self, mocker, fs
-    ):
-        # Arrange
-        cube_path = "cube/1"
-        spy = mocker.spy(resources, "download_resource")
-        _, exp_hash = resources.get_cube_params(url, cube_path)
-
-        # Act
-        resources.get_cube_params(url, cube_path, exp_hash)
-
-        # Assert
-        spy.assert_called_once()  # second time shouldn't download
-
-    def test_get_cube_params_does_will_download_if_folder_exists_and_hash_outdated(
-        self, mocker, fs
+    def test_get_does_will_download_if_folder_exists_and_hash_outdated(
+        self, mocker, fs, method
     ):
         # Arrange
         cube_path = "cube/1"
         spy = mocker.spy(resources, "download_resource")
-        resources.get_cube_params(url, cube_path)
+        method(url, cube_path)
 
         # Act
-        resources.get_cube_params(url, cube_path, "incorrect hash")
+        method(url, cube_path, "incorrect hash")
 
         # Assert
         assert spy.call_count == 2
diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py
index 89e7cc5a9..0d8e00512 100644
--- a/cli/medperf/tests/entities/test_cube.py
+++ b/cli/medperf/tests/entities/test_cube.py
@@ -68,7 +68,8 @@ def set_common_attributes(self, setup):
             self.cube_path, config.additional_path, config.tarball_filename
         )
         self.img_path = os.path.join(self.cube_path, config.image_path, "img.tar.gz")
-        self.config_files_paths = [self.manifest_path, self.params_path]
+        self.stages_path = os.path.join(self.cube_path, config.stages_filename)
+        self.config_files_paths = [self.manifest_path, self.params_path, self.stages_path]
         self.run_files_paths = [self.add_path, self.img_path]
 
     @pytest.mark.parametrize("setup", [{"remote": [DEFAULT_CUBE]}], indirect=True)
diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py
index c3bde6feb..fc324c0a7 100644
--- a/cli/medperf/tests/entities/utils.py
+++ b/cli/medperf/tests/entities/utils.py
@@ -112,16 +112,20 @@ def setup_cube_comms_downloads(mocker, fs):
     add_file = config.tarball_filename
     img_path = config.image_path
     img_file = "img.tar.gz"
+    stages_path = ""
+    stages_file = config.stages_filename
 
     get_cube_fn = generate_cubefile_fn(fs, cube_path, cube_file)
     get_params_fn = generate_cubefile_fn(fs, params_path, params_file)
     get_add_fn = generate_cubefile_fn(fs, add_path, add_file)
     get_img_fn = generate_cubefile_fn(fs, img_path, img_file)
+    get_stages_fn = generate_cubefile_fn(fs, stages_path, stages_file)
 
     mocker.patch(PATCH_RESOURCES.format("get_cube"), side_effect=get_cube_fn)
     mocker.patch(PATCH_RESOURCES.format("get_cube_params"), side_effect=get_params_fn)
     mocker.patch(PATCH_RESOURCES.format("get_cube_additional"), side_effect=get_add_fn)
     mocker.patch(PATCH_RESOURCES.format("get_cube_image"), side_effect=get_img_fn)
+    mocker.patch(PATCH_RESOURCES.format("get_cube_stages"), side_effect=get_stages_fn)
 
 
 # Setup Dataset
diff --git a/cli/medperf/tests/mocks/cube.py b/cli/medperf/tests/mocks/cube.py
index 9c1acbb8a..9d8a8f0a3 100644
--- a/cli/medperf/tests/mocks/cube.py
+++ b/cli/medperf/tests/mocks/cube.py
@@ -18,5 +18,7 @@ class TestCube(Cube):
         "https://test.com/additional_files.tar.gz"
     )
     additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH
+    stages_url: Optional[str] = "https://test.com/stages.yaml"
+    stages_hash: Optional[str] = EMPTY_FILE_HASH
     state: str = "OPERATION"
     is_valid = True
diff --git a/server/mlcube/migrations/0003_auto_20241025_1703.py b/server/mlcube/migrations/0003_auto_20241025_1703.py
new file mode 100644
index 000000000..bb055acc6
--- /dev/null
+++ b/server/mlcube/migrations/0003_auto_20241025_1703.py
@@ -0,0 +1,23 @@
+# Generated by Django 3.2.20 on 2024-10-25 17:03
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('mlcube', '0002_alter_mlcube_unique_together'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='mlcube',
+            name='stages_hash',
+            field=models.CharField(blank=True, max_length=100),
+        ),
+        migrations.AddField(
+            model_name='mlcube',
+            name='stages_url',
+            field=models.CharField(blank=True, max_length=256),
+        ),
+    ]
diff --git a/server/mlcube/models.py b/server/mlcube/models.py
index 8835b31cb..039fc18ac 100644
--- a/server/mlcube/models.py
+++ b/server/mlcube/models.py
@@ -20,6 +20,8 @@ class MlCube(models.Model):
     image_hash = models.CharField(max_length=100, blank=True)
     additional_files_tarball_url = models.CharField(max_length=256, blank=True)
     additional_files_tarball_hash = models.CharField(max_length=100, blank=True)
+    stages_url = models.CharField(max_length=256, blank=True)
+    stages_hash = models.CharField(max_length=100, blank=True)
     owner = models.ForeignKey(User, on_delete=models.PROTECT)
     state = models.CharField(
         choices=MLCUBE_STATE, max_length=100, default="DEVELOPMENT"