Skip to content

Commit

Permalink
upload as image
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Nov 1, 2024
1 parent 91c65b6 commit b00c2bf
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test:

.PHONY: test-e2e
test-e2e:
poetry run pytest --e2e -s -x -rA
poetry run pytest --e2e -s -x -rA -v

.PHONY: test-e2e-model-registry
test-e2e-model-registry:
Expand Down
23 changes: 12 additions & 11 deletions e2e/test_cli.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/bin/bash
#!/usr/bin/env bash

SCRIPT_DIR="$(dirname "$(realpath "$BASH_SOURCE")")"
set -e

echo "Preparing venv ..."
Expand All @@ -15,17 +14,19 @@ echo "Running E2E test for CLI ..."
omlmd push localhost:5001/mmortari/mlartifact:v1 README.md --empty-metadata --plain-http
omlmd push localhost:5001/mmortari/mlartifact:v1 README.md --metadata tests/data/md.json --plain-http

omlmd pull localhost:5001/mmortari/mlartifact:v1 -o tmp/a --plain-http
file_count=$(find "tmp/a" -type f | wc -l)
if [ "$file_count" -eq 3 ]; then
echo "Expected 3 files in $DIR, ok."
DIR="tmp/a"
omlmd pull localhost:5001/mmortari/mlartifact:v1 -o "$DIR" --plain-http
file_count=$(find "$DIR" -type f | wc -l)
if [ "$file_count" -eq 2 ]; then
echo "Expected 2 files in $DIR, ok."
else
echo "I was expecting 3 files in $DIR, FAIL."
echo "I was expecting 2 files in $DIR, FAIL."
exit 1
fi

omlmd pull localhost:5001/mmortari/mlartifact:v1 -o tmp/b --media-types "application/x-mlmodel" --plain-http
file_count=$(find "tmp/b" -type f | wc -l)
DIR="tmp/b"
omlmd pull localhost:5001/mmortari/mlartifact:v1 -o "$DIR" --media-types "application/x-mlmodel" --plain-http
file_count=$(find "$DIR" -type f | wc -l)
if [ "$file_count" -eq 1 ]; then
echo "Expected 1 files in $DIR, ok."
else
Expand All @@ -38,7 +39,7 @@ omlmd crawl localhost:5001/mmortari/mlartifact:v1 localhost:5001/mmortari/mlarti
omlmd crawl --plain-http \
localhost:5001/mmortari/mlartifact:v1 \
localhost:5001/mmortari/mlartifact:v1 \
localhost:5001/mmortari/mlartifact:v1 \
| jq "max_by(.config.customProperties.accuracy).reference"
localhost:5001/mmortari/mlartifact:v1 |
jq "max_by(.config.customProperties.accuracy).reference"

deactivate
10 changes: 9 additions & 1 deletion omlmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def crawl(plain_http: bool, targets: tuple[str]):
required=True,
type=click.Path(path_type=Path, exists=True, resolve_path=True),
)
@click.option(
"--as-artifact",
is_flag=True,
help="Push as an artifact (default is as a blob)",
)
@cloup.option_group(
"Metadata options",
cloup.option(
Expand All @@ -88,6 +93,7 @@ def push(
plain_http: bool,
target: str,
path: Path,
as_artifact: bool,
metadata: Path | None,
empty_metadata: bool,
):
Expand All @@ -96,4 +102,6 @@ def push(
if empty_metadata:
logger.warning(f"Pushing to {target} with empty metadata.")
md = deserialize_mdfile(metadata) if metadata else {}
click.echo(Helper.from_default_registry(plain_http).push(target, path, **md))
click.echo(
Helper.from_default_registry(plain_http).push(target, path, as_artifact, **md)
)
6 changes: 5 additions & 1 deletion omlmd/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from oras.defaults import default_blob_media_type

FILENAME_METADATA_JSON = "model_metadata.omlmd.json"
MIME_APPLICATION_CONFIG = "application/x-config"
MIME_APPLICATION_MLMODEL = "application/x-mlmodel"
MIME_APPLICATION_MLMETADATA = "application/x-mlmetadata+json"
MIME_BLOB = default_blob_media_type
MIME_MANIFEST_CONFIG = "application/vnd.oci.image.config.v1+json"
79 changes: 70 additions & 9 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json
import logging
import os
import platform
import tarfile
import urllib.request
from collections.abc import Sequence
from dataclasses import dataclass, field
Expand All @@ -10,8 +13,10 @@

from .constants import (
FILENAME_METADATA_JSON,
MIME_APPLICATION_CONFIG,
MIME_APPLICATION_MLMETADATA,
MIME_APPLICATION_MLMODEL,
MIME_BLOB,
MIME_MANIFEST_CONFIG,
)
from .listener import Event, Listener, PushEvent
from .model_metadata import ModelMetadata
Expand All @@ -20,6 +25,18 @@
logger = logging.getLogger(__name__)


def get_arch() -> str:
mac = platform.machine()
if mac == "x86_64":
return "amd64"
if mac == "arm64":
return "arm64"
if mac == "aarch64":
return "arm64"
msg = f"Unsupported architecture: {platform.machine()}"
raise NotImplementedError(msg)


def download_file(uri: str):
file_name = os.path.basename(uri)
urllib.request.urlretrieve(uri, file_name)
Expand All @@ -41,6 +58,7 @@ def push(
self,
target: str,
path: Path | str,
as_artifact: bool = False,
**kwargs,
):
owns_meta = True
Expand All @@ -52,8 +70,7 @@ def push(
owns_meta = False
logger.warning("Reusing intermediate metadata files.")
logger.debug(f"{meta_path}")
with open(meta_path, "r") as f:
model_metadata = ModelMetadata.from_json(f.read())
model_metadata = ModelMetadata.from_dict(json.loads(meta_path.read_bytes()))
elif meta_path.exists():
err = dedent(f"""
OMLMD intermediate metadata files found at '{meta_path}'.
Expand All @@ -65,13 +82,51 @@ def push(
raise RuntimeError(err)
else:
model_metadata = ModelMetadata.from_dict(kwargs)
meta_path.write_text(model_metadata.to_json())
meta_path.write_text(json.dumps(model_metadata.to_dict()))

owns_model_tar = False
owns_md_tar = False
manifest_path = path.parent / "manifest.json"
model_tar = None
meta_tar = None
if not as_artifact:
manifest_path.write_text(
json.dumps(
{
"architecture": get_arch(),
"os": "linux",
}
)
)
config = f"{manifest_path}:{MIME_MANIFEST_CONFIG}"
model_tar = path.parent / f"{path.stem}.tar"
meta_tar = path.parent / f"{meta_path.stem}.tar"
if not model_tar.exists():
owns_model_tar = True
with tarfile.open(model_tar, "w") as tf:
tf.add(path, arcname=path.name)
if not meta_tar.exists():
owns_md_tar = True
with tarfile.open(meta_tar, "w:gz") as tf:
tf.add(meta_path, arcname=meta_path.name)
files = [
f"{model_tar}:{MIME_BLOB}",
f"{meta_tar}:{MIME_BLOB}+gzip",
]
else:
manifest_path.write_text(
json.dumps(
{
"artifactType": MIME_APPLICATION_MLMODEL,
}
)
)
config = f"{manifest_path}:{MIME_APPLICATION_MLMODEL}"
files = [
f"{path}:{MIME_APPLICATION_MLMODEL}",
f"{meta_path}:{MIME_APPLICATION_MLMETADATA}",
]

config = f"{meta_path}:{MIME_APPLICATION_CONFIG}"
files = [
f"{path}:{MIME_APPLICATION_MLMODEL}",
config,
]
try:
# print(target, files, model_metadata.to_annotations_dict())
result = self._registry.push(
Expand All @@ -88,6 +143,12 @@ def push(
finally:
if owns_meta:
meta_path.unlink()
if owns_model_tar:
assert isinstance(model_tar, Path)
model_tar.unlink()
if owns_md_tar:
assert isinstance(meta_tar, Path)
meta_tar.unlink()

def pull(
self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None
Expand Down
4 changes: 1 addition & 3 deletions omlmd/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


class Listener(ABC):
"""
TODO: not yet settled for multi-method or current single update method.
"""
# TODO: not yet settled for multi-method or current single update method.

@abstractmethod
def update(self, source: t.Any, event: Event) -> None:
Expand Down
13 changes: 0 additions & 13 deletions omlmd/model_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ class ModelMetadata:
model_format_name: str | None = None
model_format_version: str | None = None

def to_json(self) -> str:
return json.dumps(self.to_dict(), indent=4)

def to_dict(self) -> dict[str, t.Any]:
return asdict(self)

Expand All @@ -38,16 +35,6 @@ def to_annotations_dict(self) -> dict[str, str]:
) # post-fix "+json" for OCI annotation which is a str representing a json
return result

@staticmethod
def from_json(json_str: str) -> "ModelMetadata":
data = json.loads(json_str)
return ModelMetadata(**data)

@staticmethod
def from_yaml(yaml_str: str) -> "ModelMetadata":
data = yaml.safe_load(yaml_str)
return ModelMetadata(**data)

@staticmethod
def from_dict(data: dict[str, t.Any]) -> "ModelMetadata":
known_keys = {f.name for f in fields(ModelMetadata)}
Expand Down
4 changes: 3 additions & 1 deletion tests/test_e2e_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def update(self, source: Helper, event: Event) -> None:
assert mv
assert mv.description == "Lorem ipsum"
assert mv.author == "John Doe"
assert mv.custom_properties == {"accuracy": 0.987}
assert mv.custom_properties == {
"accuracy": accuracy_value,
}

ma = model_registry.get_model_artifact("mnist", v)
assert ma
Expand Down
50 changes: 40 additions & 10 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import io
import json
import subprocess
import tarfile
import tempfile
import typing as t
from hashlib import sha256
from pathlib import Path

import pytest

from omlmd.constants import MIME_APPLICATION_MLMODEL
from omlmd.constants import MIME_BLOB
from omlmd.helpers import Helper
from omlmd.listener import Event, Listener
from omlmd.model_metadata import ModelMetadata, deserialize_mdfile
from omlmd.provider import OMLMDRegistry


def untar(tar: Path, out: Path):
out.write_bytes(
t.cast(io.BufferedReader, tarfile.open(tar, "r").extractfile(tar.stem)).read()
)


def test_call_push_using_md_from_file(mocker):
helper = Helper()
mocker.patch.object(helper, "push", return_value=None)
Expand Down Expand Up @@ -100,12 +108,33 @@ def test_push_pull_chunked(tmp_path, target):

omlmd.push(target, temp, **md)
omlmd.pull(target, tmp_path)
assert len(list(tmp_path.iterdir())) == 3
assert tmp_path.joinpath(temp.name).stat().st_size == base_size
files = list(tmp_path.iterdir())
print(files)
assert len(files) == 2
print(tmp_path)
out = tmp_path.joinpath(temp.name)
untar(out.with_suffix(".tar"), out)
assert temp.stat().st_size == base_size
finally:
temp.unlink()


@pytest.mark.e2e
def test_e2e_push_pull_as_artifact(tmp_path, target):
omlmd = Helper()
omlmd.push(
target,
Path(__file__).parent / ".." / "README.md",
as_artifact=True,
name="mnist",
description="Lorem ipsum",
author="John Doe",
accuracy=0.987,
)
omlmd.pull(target, tmp_path)
assert len(list(tmp_path.iterdir())) == 2


@pytest.mark.e2e
def test_e2e_push_pull(tmp_path, target):
omlmd = Helper()
Expand All @@ -118,7 +147,7 @@ def test_e2e_push_pull(tmp_path, target):
accuracy=0.987,
)
omlmd.pull(target, tmp_path)
assert len(list(tmp_path.iterdir())) == 3
assert len(list(tmp_path.iterdir())) == 2


@pytest.mark.e2e
Expand All @@ -132,7 +161,7 @@ def test_e2e_push_pull_with_filters(tmp_path, target):
author="John Doe",
accuracy=0.987,
)
omlmd.pull(target, tmp_path, media_types=[MIME_APPLICATION_MLMODEL])
omlmd.pull(target, tmp_path, media_types=[MIME_BLOB])
assert len(list(tmp_path.iterdir())) == 1


Expand All @@ -155,10 +184,11 @@ def test_e2e_push_pull_column(tmp_path, target):

omlmd.push(target, temp, **md)
omlmd.pull(target, tmp_path)
with open(tmp_path.joinpath(temp.name), "r") as f:
pulled = f.read()
assert pulled == content
pulled_sha = sha256(pulled.encode("utf-8")).hexdigest()
assert pulled_sha == content_sha
out = tmp_path.joinpath(temp.name)
untar(out.with_suffix(".tar"), out)
pulled = out.read_text()
assert pulled == content
pulled_sha = sha256(pulled.encode("utf-8")).hexdigest()
assert pulled_sha == content_sha
finally:
temp.unlink()
6 changes: 3 additions & 3 deletions tests/test_omlmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

def test_dry_run_model_metadata_json_yaml_conversions():
metadata = ModelMetadata(name="Example Model", author="John Doe")
json_str = metadata.to_json()
json_str = json.dumps(metadata.to_dict(), indent=4)
yaml_str = yaml.dump(metadata.to_dict(), default_flow_style=False)

print("JSON representation:\n", json_str)
print("YAML representation:\n", yaml_str)

metadata_from_json = ModelMetadata.from_json(json_str)
metadata_from_yaml = ModelMetadata.from_yaml(yaml_str)
metadata_from_json = ModelMetadata(**json.loads(json_str))
metadata_from_yaml = ModelMetadata(**yaml.safe_load(yaml_str))

print("Metadata from JSON:\n", metadata_from_json)
print("Metadata from YAML:\n", metadata_from_yaml)
Expand Down

0 comments on commit b00c2bf

Please sign in to comment.