diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index 07c35d48..107ffac1 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -5,7 +5,8 @@ import logging import os import posixpath -from typing import Any, Dict, Optional, Type, TypeVar, Union, overload +from collections import defaultdict +from typing import Any, Dict, List, Optional, Type, TypeVar, Union, overload from fsspec import AbstractFileSystem from typing_extensions import Literal @@ -16,8 +17,14 @@ MlemProjectNotFound, WrongMetaType, ) -from mlem.core.meta_io import Location, get_meta_path -from mlem.core.objects import MlemData, MlemModel, MlemObject, find_object +from mlem.core.meta_io import MLEM_EXT, Location, get_meta_path +from mlem.core.objects import ( + MlemData, + MlemLink, + MlemModel, + MlemObject, + find_object, +) from mlem.utils.path import make_posix logger = logging.getLogger(__name__) @@ -214,3 +221,28 @@ def find_meta_location(location: Location) -> Location: path = posixpath.relpath(path, location.project) location.update_path(path) return location + + +def list_objects( + path: str = ".", fs: Optional[AbstractFileSystem] = None, recursive=True +) -> Dict[Type[MlemObject], List[MlemObject]]: + loc = Location.resolve(path, fs=fs) + result = defaultdict(list) + postfix = f"/**{MLEM_EXT}" if recursive else f"/*{MLEM_EXT}" + for filepath in loc.fs.glob(loc.fullpath + postfix, detail=False): + meta = load_meta( + filepath, fs=loc.fs, load_value=False, follow_links=False + ) + type_ = meta.__class__ + if isinstance(meta, MlemLink): + type_ = meta.link_cls + else: + parent = meta.__parent__ + if ( + parent is not None + and parent != MlemObject + and issubclass(parent, MlemObject) + ): + type_ = parent + result[type_].append(meta) + return result diff --git a/tests/core/test_metadata.py b/tests/core/test_metadata.py index c8572220..42c49cee 100644 --- a/tests/core/test_metadata.py +++ b/tests/core/test_metadata.py @@ -12,9 +12,11 @@ from sklearn.tree import DecisionTreeClassifier from mlem.api import init +from mlem.contrib.heroku.meta import HerokuEnv from mlem.core.meta_io import MLEM_EXT -from mlem.core.metadata import load, load_meta, save -from mlem.core.objects import MlemModel +from mlem.core.metadata import list_objects, load, load_meta, save +from mlem.core.objects import MlemData, MlemEnv, MlemLink, MlemModel +from mlem.utils.path import make_posix from tests.conftest import ( MLEM_TEST_REPO, MLEM_TEST_REPO_NAME, @@ -179,3 +181,75 @@ def test_loading_from_s3(model, s3_storage_fs, s3_tmp_path): assert isinstance(loaded, DecisionTreeClassifier) train, _ = load_iris(return_X_y=True) loaded.predict(train) + + +def test_ls_local(filled_mlem_project): + objects = list_objects(filled_mlem_project) + + assert len(objects) == 1 + assert MlemModel in objects + models = objects[MlemModel] + assert len(models) == 2 + model, lnk = models + if isinstance(model, MlemLink): + model, lnk = lnk, model + + assert isinstance(model, MlemModel) + assert isinstance(lnk, MlemLink) + assert ( + posixpath.join(make_posix(filled_mlem_project), lnk.path) + == model.loc.fullpath + ) + + +@pytest.mark.parametrize("recursive,count", [[True, 3], [False, 1]]) +def test_ls_local_recursive(tmpdir, recursive, count): + path = str(tmpdir) + meta = HerokuEnv() + meta.dump(posixpath.join(path, "env")) + meta.dump(posixpath.join(path, "subdir", "env")) + meta.dump(posixpath.join(path, "subdir", "subsubdir", "env")) + objects = list_objects(path, recursive=recursive) + assert len(objects) == 1 + assert MlemEnv in objects + assert len(objects[MlemEnv]) == count + + +def test_ls_no_project(tmpdir): + assert not list_objects(str(tmpdir)) + + +@long +@need_test_repo_auth +def test_ls_remote(current_test_branch): + objects = list_objects( + os.path.join(MLEM_TEST_REPO, f"tree/{current_test_branch}/simple") + ) + assert len(objects) == 2 + assert MlemModel in objects + models = objects[MlemModel] + assert len(models) == 2 + model, lnk = models + if isinstance(model, MlemLink): + model, lnk = lnk, model + + assert isinstance(model, MlemModel) + assert isinstance(lnk, MlemLink) + + assert MlemData in objects + assert len(objects[MlemData]) == 4 + + +@long +def test_ls_remote_s3(s3_tmp_path): + path = s3_tmp_path("ls_remote_s3") + init(path) + meta = HerokuEnv() + meta.dump(posixpath.join(path, "env")) + meta.dump(posixpath.join(path, "subdir", "env")) + meta.dump(posixpath.join(path, "subdir", "subsubdir", "env")) + objects = list_objects(path) + assert MlemEnv in objects + envs = objects[MlemEnv] + assert len(envs) == 3 + assert all(o == meta for o in envs)