Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stages file to MLCube #621

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions cli/medperf/commands/mlcube/mlcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def submit(
help="Identifier to download the image file. See the description above",
),
image_hash: str = typer.Option("", "--image-hash", help="hash of image file"),
stages_file: str = typer.Option(
"",
"--stages-file",
"-s",
help="Identifier to download the stages file. See the description above"
),
stages_hash: str = typer.Option("", "--stages-hash", help="Hash of the stages file"),
operational: bool = typer.Option(
False,
"--operational",
Expand All @@ -99,6 +106,7 @@ def submit(
- parameters_file\n
- additional_file\n
- image_file\n
- stages_file\n
are expected to be given in the following format: <source_prefix:resource_identifier>
where `source_prefix` instructs the client how to download the resource, and `resource_identifier`
is the identifier used to download the asset. The following are supported:\n
Expand All @@ -117,6 +125,8 @@ def submit(
"image_tarball_hash": image_hash,
"additional_files_tarball_url": additional_file,
"additional_files_tarball_hash": additional_hash,
"stages_url": stages_file,
"stages_hash": stages_hash,
"state": "OPERATION" if operational else "DEVELOPMENT",
}
SubmitCube.run(mlcube_info)
Expand Down
6 changes: 6 additions & 0 deletions cli/medperf/comms/entity_resources/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def get_cube_params(url: str, cube_path: str, expected_hash: str = None):
return _get_regular_file(url, output_path, expected_hash)


def get_cube_stages(url: str, cube_path: str, expected_hash: str = None):
"""Downloads and writes a cube stages.yaml file"""
output_path = os.path.join(cube_path, config.stages_filename)
return _get_regular_file(url, output_path, expected_hash)


def get_cube_image(url: str, cube_path: str, hash_value: str = None) -> str:
"""Retrieves and stores the image file from the server. Stores images
on a shared location, and retrieves a cached image by hash if found locally.
Expand Down
1 change: 1 addition & 0 deletions cli/medperf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
# MLCube assets conventions
cube_filename = "mlcube.yaml"
params_filename = "parameters.yaml"
stages_filename = "stages.yaml"
workspace_path = "workspace"
additional_path = "workspace/additional_files"
image_path = "workspace/.image"
Expand Down
19 changes: 19 additions & 0 deletions cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Cube(Entity, DeployableSchema):
image_hash: Optional[str]
additional_files_tarball_url: Optional[str] = Field(None, alias="tarball_url")
additional_files_tarball_hash: Optional[str] = Field(None, alias="tarball_hash")
stages_url: Optional[str]
stages_hash: Optional[str]
metadata: dict = {}
user_metadata: dict = {}

Expand Down Expand Up @@ -74,6 +76,9 @@ def __init__(self, *args, **kwargs):
self.params_path = None
if self.git_parameters_url:
self.params_path = os.path.join(self.path, config.params_filename)
self.stages_path = None
if self.stages_url:
self.stages_path = os.path.join(self.path, config.stages_filename)

@property
def local_id(self):
Expand Down Expand Up @@ -136,6 +141,15 @@ def download_additional(self):
)
self.additional_files_tarball_hash = file_hash

def download_stages(self):
url = self.stages_url
if url:
path, file_hash = resources.get_cube_stages(
url, self.path, self.stages_hash
)
self.stages_path = path
self.stages_hash = file_hash

def download_image(self):
url = self.image_tarball_url
tarball_hash = self.image_tarball_hash
Expand Down Expand Up @@ -216,6 +230,11 @@ def download_config_files(self):
except InvalidEntityError as e:
raise InvalidEntityError(f"MLCube {self.name} parameters file: {e}")

try:
self.download_stages()
except InvalidEntityError as e:
raise InvalidEntityError(f"MLCube {self.name} stages file: {e}")

def download_run_files(self):
try:
self.download_additional()
Expand Down
53 changes: 14 additions & 39 deletions cli/medperf/tests/comms/entity_resources/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,61 +120,36 @@ def test_get_additional_files_will_download_if_folder_exists_and_hash_valid_but_
assert spy.call_count == 2


class TestGetCube:
def test_get_cube_does_not_download_if_folder_exists_and_hash_valid(
self, mocker, fs
@pytest.mark.parametrize("method", [
resources.get_cube,
resources.get_cube_params,
resources.get_cube_stages,
])
class TestGet:
def test_get_does_not_download_if_folder_exists_and_hash_valid(
self, mocker, fs, method
):
# Arrange
cube_path = "cube/1"
spy = mocker.spy(resources, "download_resource")
_, exp_hash = resources.get_cube(url, cube_path)
_, exp_hash = method(url, cube_path)

# Act
resources.get_cube(url, cube_path, exp_hash)
method(url, cube_path, exp_hash)

# Assert
spy.assert_called_once() # second time shouldn't download

def test_get_cube_does_will_download_if_folder_exists_and_hash_outdated(
self, mocker, fs
):
# Arrange
cube_path = "cube/1"
spy = mocker.spy(resources, "download_resource")
resources.get_cube(url, cube_path)

# Act
resources.get_cube(url, cube_path, "incorrect hash")

# Assert
assert spy.call_count == 2


class TestGetCubeParams:
def test_get_cube_params_does_not_download_if_folder_exists_and_hash_valid(
self, mocker, fs
):
# Arrange
cube_path = "cube/1"
spy = mocker.spy(resources, "download_resource")
_, exp_hash = resources.get_cube_params(url, cube_path)

# Act
resources.get_cube_params(url, cube_path, exp_hash)

# Assert
spy.assert_called_once() # second time shouldn't download

def test_get_cube_params_does_will_download_if_folder_exists_and_hash_outdated(
self, mocker, fs
def test_get_does_will_download_if_folder_exists_and_hash_outdated(
self, mocker, fs, method
):
# Arrange
cube_path = "cube/1"
spy = mocker.spy(resources, "download_resource")
resources.get_cube_params(url, cube_path)
method(url, cube_path)

# Act
resources.get_cube_params(url, cube_path, "incorrect hash")
method(url, cube_path, "incorrect hash")

# Assert
assert spy.call_count == 2
3 changes: 2 additions & 1 deletion cli/medperf/tests/entities/test_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def set_common_attributes(self, setup):
self.cube_path, config.additional_path, config.tarball_filename
)
self.img_path = os.path.join(self.cube_path, config.image_path, "img.tar.gz")
self.config_files_paths = [self.manifest_path, self.params_path]
self.stages_path = os.path.join(self.cube_path, config.stages_filename)
self.config_files_paths = [self.manifest_path, self.params_path, self.stages_path]
self.run_files_paths = [self.add_path, self.img_path]

@pytest.mark.parametrize("setup", [{"remote": [DEFAULT_CUBE]}], indirect=True)
Expand Down
4 changes: 4 additions & 0 deletions cli/medperf/tests/entities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,20 @@ def setup_cube_comms_downloads(mocker, fs):
add_file = config.tarball_filename
img_path = config.image_path
img_file = "img.tar.gz"
stages_path = ""
stages_file = config.stages_filename

get_cube_fn = generate_cubefile_fn(fs, cube_path, cube_file)
get_params_fn = generate_cubefile_fn(fs, params_path, params_file)
get_add_fn = generate_cubefile_fn(fs, add_path, add_file)
get_img_fn = generate_cubefile_fn(fs, img_path, img_file)
get_stages_fn = generate_cubefile_fn(fs, stages_path, stages_file)

mocker.patch(PATCH_RESOURCES.format("get_cube"), side_effect=get_cube_fn)
mocker.patch(PATCH_RESOURCES.format("get_cube_params"), side_effect=get_params_fn)
mocker.patch(PATCH_RESOURCES.format("get_cube_additional"), side_effect=get_add_fn)
mocker.patch(PATCH_RESOURCES.format("get_cube_image"), side_effect=get_img_fn)
mocker.patch(PATCH_RESOURCES.format("get_cube_stages"), side_effect=get_stages_fn)


# Setup Dataset
Expand Down
2 changes: 2 additions & 0 deletions cli/medperf/tests/mocks/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ class TestCube(Cube):
"https://test.com/additional_files.tar.gz"
)
additional_files_tarball_hash: Optional[str] = EMPTY_FILE_HASH
stages_url: Optional[str] = "https://test.com/stages.yaml"
stages_hash: Optional[str] = EMPTY_FILE_HASH
state: str = "OPERATION"
is_valid = True
23 changes: 23 additions & 0 deletions server/mlcube/migrations/0003_auto_20241025_1703.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 3.2.20 on 2024-10-25 17:03

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('mlcube', '0002_alter_mlcube_unique_together'),
]

operations = [
migrations.AddField(
model_name='mlcube',
name='stages_hash',
field=models.CharField(blank=True, max_length=100),
),
migrations.AddField(
model_name='mlcube',
name='stages_url',
field=models.CharField(blank=True, max_length=256),
),
]
2 changes: 2 additions & 0 deletions server/mlcube/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class MlCube(models.Model):
image_hash = models.CharField(max_length=100, blank=True)
additional_files_tarball_url = models.CharField(max_length=256, blank=True)
additional_files_tarball_hash = models.CharField(max_length=100, blank=True)
stages_url = models.CharField(max_length=256, blank=True)
stages_hash = models.CharField(max_length=100, blank=True)
owner = models.ForeignKey(User, on_delete=models.PROTECT)
state = models.CharField(
choices=MLCUBE_STATE, max_length=100, default="DEVELOPMENT"
Expand Down
Loading