From 1f7bc2952bebaf0fd9dc7269d620193d7c1f44a7 Mon Sep 17 00:00:00 2001 From: Mikhail Sveshnikov Date: Fri, 9 Sep 2022 01:39:57 +0300 Subject: [PATCH 1/4] Add -c help WIP (#363) * Add -c help to declare WIP * extrapolate for other commands * some field docs and little improvements * cli utils * add simple_parsing * fix tests * fix tests * lazy help * ooopsie * class and fields docstrings * reparsing cli params for nested complex objects rewrite get_field_docstring for 9000x speed fix build_model bug * fix for py37 * support lists in build_model * support lists in cli * nested options WIP * very nested options WIP * all but flat nested WIP * lil refactoring * flat nested stuff DONE * Update mlem/contrib/heroku/build.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/cli/declare.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * get rid of --conf, add mlem abc to declare * fix tests * fix lazyness * fix serialization * fix tests * fix tests * Update bitbucketfs.py * Apply suggestions from code review * fix comments and disable failfast for gh actions * backport docs from mlem.ai * sort import choices * make run_cmd optional instead of bool * docs for torch import * allow --load for groups add server config into docker build * fix windows bugs * suddenly fix dockerhub requests * suddenly fix dockerhub requests Co-authored-by: Alexander Guschin <1aguschin@gmail.com> --- .github/workflows/check-test-release.yml | 2 +- .pylintrc | 2 +- mlem/api/commands.py | 6 +- mlem/cli/apply.py | 158 ++++-- mlem/cli/build.py | 103 +++- mlem/cli/clone.py | 13 +- mlem/cli/config.py | 7 +- mlem/cli/declare.py | 170 +++++- mlem/cli/deployment.py | 39 +- mlem/cli/import_object.py | 2 +- mlem/cli/info.py | 27 +- mlem/cli/init.py | 2 +- mlem/cli/link.py | 17 +- mlem/cli/main.py | 415 ++++++++------- mlem/cli/serve.py | 87 ++- mlem/cli/types.py | 141 ++--- mlem/cli/utils.py | 621 ++++++++++++++++++++++ mlem/contrib/bitbucketfs.py | 4 +- mlem/contrib/callable.py 
| 2 + mlem/contrib/catboost.py | 3 + mlem/contrib/docker/base.py | 63 ++- mlem/contrib/docker/context.py | 44 +- mlem/contrib/docker/dockerfile.j2 | 2 +- mlem/contrib/docker/utils.py | 32 +- mlem/contrib/dvc.py | 12 +- mlem/contrib/fastapi.py | 4 + mlem/contrib/github.py | 10 +- mlem/contrib/gitlabfs.py | 14 +- mlem/contrib/heroku/build.py | 6 +- mlem/contrib/heroku/meta.py | 18 + mlem/contrib/heroku/server.py | 4 +- mlem/contrib/lightgbm.py | 11 +- mlem/contrib/numpy.py | 26 +- mlem/contrib/pandas.py | 12 +- mlem/contrib/pip/base.py | 21 +- mlem/contrib/rabbitmq.py | 9 + mlem/contrib/sklearn.py | 10 +- mlem/contrib/tensorflow.py | 16 +- mlem/contrib/torch.py | 22 +- mlem/contrib/xgboost.py | 12 +- mlem/core/artifacts.py | 8 + mlem/core/base.py | 176 ++++-- mlem/core/data_type.py | 35 +- mlem/core/errors.py | 2 +- mlem/core/meta_io.py | 22 +- mlem/core/metadata.py | 4 +- mlem/core/model.py | 16 +- mlem/core/objects.py | 33 +- mlem/core/requirements.py | 32 +- mlem/runtime/client.py | 4 + mlem/runtime/interface.py | 4 + mlem/utils/entrypoints.py | 38 +- mlem/utils/templates.py | 1 + setup.py | 4 +- tests/cli/conftest.py | 11 +- tests/cli/test_apply.py | 10 +- tests/cli/test_build.py | 77 ++- tests/cli/test_declare.py | 466 +++++++++++++++- tests/cli/test_deployment.py | 13 +- tests/cli/test_main.py | 8 +- tests/cli/test_serve.py | 9 +- tests/cli/test_types.py | 85 +++ tests/contrib/test_bitbucket.py | 6 +- tests/contrib/test_docker/test_context.py | 12 +- tests/contrib/test_docker/test_utils.py | 10 +- tests/contrib/test_gitlab.py | 6 +- tests/core/test_base.py | 169 +++++- tests/core/test_meta_io.py | 4 +- tests/core/test_objects.py | 20 +- tests/utils/test_entrypoints.py | 39 ++ 70 files changed, 2844 insertions(+), 649 deletions(-) create mode 100644 mlem/cli/utils.py create mode 100644 tests/cli/test_types.py create mode 100644 tests/utils/test_entrypoints.py diff --git a/.github/workflows/check-test-release.yml b/.github/workflows/check-test-release.yml index 
af873f97..0baf5058 100644 --- a/.github/workflows/check-test-release.yml +++ b/.github/workflows/check-test-release.yml @@ -58,7 +58,7 @@ jobs: # no HDF5 support installed for tables - os: windows-latest python: "3.9" - fail-fast: true + fail-fast: false steps: - uses: actions/checkout@v3 with: diff --git a/.pylintrc b/.pylintrc index 4ea818e6..75fac659 100644 --- a/.pylintrc +++ b/.pylintrc @@ -389,7 +389,7 @@ ignore-comments=yes ignore-docstrings=yes # Ignore imports when computing similarities. -ignore-imports=no +ignore-imports=yes # Ignore function signatures when computing similarities. ignore-signatures=no diff --git a/mlem/api/commands.py b/mlem/api/commands.py index c37cd862..99e74ebe 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -25,7 +25,7 @@ WrongMethodError, ) from mlem.core.import_objects import ImportAnalyzer, ImportHook -from mlem.core.meta_io import MLEM_DIR, Location, UriResolver, get_fs +from mlem.core.meta_io import MLEM_DIR, Location, get_fs from mlem.core.metadata import load_meta, save from mlem.core.objects import ( MlemBuilder, @@ -367,7 +367,7 @@ def ls( # pylint: disable=too-many-locals include_links: bool = True, ignore_errors: bool = False, ) -> Dict[Type[MlemObject], List[MlemObject]]: - loc = UriResolver.resolve( + loc = Location.resolve( "", project=project, rev=rev, fs=fs, find_project=True ) _validate_ls_project(loc, project) @@ -392,7 +392,7 @@ def import_object( """Try to load an object as MLEM model (or data) and return it, optionally saving to the specified target location """ - loc = UriResolver.resolve(path, project, rev, fs) + loc = Location.resolve(path, project, rev, fs) echo(EMOJI_LOAD + f"Importing object from {loc.uri_repr}") if type_ is not None: type_, modifier = parse_import_type_modifier(type_) diff --git a/mlem/cli/apply.py b/mlem/cli/apply.py index 5fd0c1b7..37fc6c27 100644 --- a/mlem/cli/apply.py +++ b/mlem/cli/apply.py @@ -1,13 +1,14 @@ from json import dumps from typing import List, Optional 
-from typer import Argument, Option +from typer import Argument, Option, Typer from mlem.api import import_object from mlem.cli.main import ( - config_arg, + app, mlem_command, - option_conf, + mlem_group, + mlem_group_callback, option_data_project, option_data_rev, option_external, @@ -20,6 +21,12 @@ option_rev, option_target_project, ) +from mlem.cli.utils import ( + abc_fields_parameters, + config_arg, + for_each_impl, + lazy_class_docstring, +) from mlem.core.data_type import DataAnalyzer from mlem.core.errors import UnsupportedDataBatchLoading from mlem.core.import_objects import ImportHook @@ -32,7 +39,7 @@ @mlem_command("apply", section="runtime") def apply( - model: str = Argument(..., help="Path to model object"), + model: str = Argument(..., metavar="model", help="Path to model object"), data_path: str = Argument(..., metavar="data", help="Path to data object"), project: Optional[str] = option_project, rev: Optional[str] = option_rev, @@ -65,7 +72,8 @@ def apply( external: bool = option_external, json: bool = option_json, ): - """Apply a model to data. Resulting data will be saved as MLEM object to `output` if it is provided, otherwise will be printed + """Apply a model to data. The result will be saved as a MLEM object to `output` if + provided. Otherwise, it will be printed to `stdout`. Examples: Apply local mlem model to local mlem data @@ -120,38 +128,55 @@ def apply( ) -@mlem_command("apply-remote", section="runtime") -def apply_remote( - subtype: str = Argument( - "", - help=f"Type of client. Choices: {list_implementations(Client)}", - show_default=False, - ), - data: str = Argument(..., help="Path to data object"), - project: Optional[str] = option_project, - rev: Optional[str] = option_rev, - output: Optional[str] = Option( - None, "-o", "--output", help="Where to store the outputs." 
- ), - target_project: Optional[str] = option_target_project, - method: str = option_method, - index: bool = option_index, - json: bool = option_json, - load: Optional[str] = option_load("client"), - conf: List[str] = option_conf("client"), - file_conf: List[str] = option_file_conf("client"), -): - """Apply a model (deployed somewhere remotely) to data. Resulting data will be saved as MLEM object to `output` if it is provided, otherwise will be printed +apply_remote = Typer( + name="apply-remote", + help="""Apply a deployed-model (possibly remotely) to data. The results will be saved as +a MLEM object to `output` if provided. Otherwise, it will be printed to +`stdout`. Examples: Apply hosted mlem model to local mlem data $ mlem apply-remote http mydata -c host="0.0.0.0" -c port=8080 --output myprediction - """ - client = config_arg(Client, load, subtype, conf, file_conf) + """, + cls=mlem_group("runtime"), + subcommand_metavar="client", +) +app.add_typer(apply_remote) + + +def _apply_remote( + data, + project, + rev, + index, + method, + output, + target_project, + json, + type_name, + load, + file_conf, + kwargs, +): + client = config_arg( + Client, + load, + type_name, + conf=None, + file_conf=file_conf, + **(kwargs or {}), + ) with set_echo(None if json else ...): result = run_apply_remote( - client, data, project, rev, index, method, output, target_project + client, + data, + project, + rev, + index, + method, + output, + target_project, ) if output is None and json: print( @@ -161,6 +186,79 @@ def apply_remote( ) +option_output = Option( + None, "-o", "--output", help="Where to store the outputs." 
+) + + +@mlem_group_callback(apply_remote, required=["data", "load"]) +def apply_remote_load( + data: str = Option(None, "-d", "--data", help="Path to data object"), + project: Optional[str] = option_project, + rev: Optional[str] = option_rev, + output: Optional[str] = option_output, + target_project: Optional[str] = option_target_project, + method: str = option_method, + index: bool = option_index, + json: bool = option_json, + load: Optional[str] = option_load("client"), +): + return _apply_remote( + data, + project, + rev, + index, + method, + output, + target_project, + json, + None, + load, + None, + None, + ) + + +@for_each_impl(Client) +def create_apply_remote(type_name): + @mlem_command( + type_name, + section="clients", + parent=apply_remote, + dynamic_metavar="__kwargs__", + dynamic_options_generator=abc_fields_parameters(type_name, Client), + hidden=type_name.startswith("_"), + lazy_help=lazy_class_docstring(Client.abs_name, type_name), + no_pass_from_parent=["file_conf"], + ) + def apply_remote_func( + data: str = Option(..., "-d", "--data", help="Path to data object"), + project: Optional[str] = option_project, + rev: Optional[str] = option_rev, + output: Optional[str] = option_output, + target_project: Optional[str] = option_target_project, + method: str = option_method, + index: bool = option_index, + json: bool = option_json, + file_conf: List[str] = option_file_conf("client"), + **__kwargs__, + ): + return _apply_remote( + data, + project, + rev, + index, + method, + output, + target_project, + json, + type_name, + None, + file_conf, + __kwargs__, + ) + + def run_apply_remote( client: Client, data_path: str, diff --git a/mlem/cli/build.py b/mlem/cli/build.py index 4b4b1623..41561ea0 100644 --- a/mlem/cli/build.py +++ b/mlem/cli/build.py @@ -1,49 +1,98 @@ from typing import List, Optional -from typer import Argument +from typer import Option, Typer from mlem.cli.main import ( - config_arg, + app, mlem_command, - option_conf, + mlem_group, + 
mlem_group_callback, option_file_conf, option_load, option_project, option_rev, ) +from mlem.cli.utils import ( + abc_fields_parameters, + config_arg, + for_each_impl, + lazy_class_docstring, +) from mlem.core.metadata import load_meta from mlem.core.objects import MlemBuilder, MlemModel -from mlem.utils.entrypoints import list_implementations + +build = Typer( + name="build", + help=""" + Build models to create re-usable, ship-able entities such as a Docker image or +Python package. + + Examples: + Build docker image from model + $ mlem build mymodel docker -c server.type=fastapi -c image.name=myimage + + Create build docker_dir declaration and build it + $ mlem declare builder docker_dir -c server=fastapi -c target=build build_dock + $ mlem build mymodel --load build_dock + """, + cls=mlem_group("runtime", aliases=["export"]), + subcommand_metavar="builder", +) +app.add_typer(build) -@mlem_command("build", section="runtime", aliases=["export"]) -def build( - model: str = Argument(..., help="Path to model"), - subtype: str = Argument( - "", - help=f"Type of build. 
Choices: {list_implementations(MlemBuilder)}", - show_default=False, - ), +@mlem_group_callback(build, required=["model", "load"]) +def build_load( + model: str = Option(None, "-m", "--model", help="Path to model"), project: Optional[str] = option_project, rev: Optional[str] = option_rev, - load: Optional[str] = option_load("builder"), - conf: List[str] = option_conf("builder"), - file_conf: List[str] = option_file_conf("builder"), + load: str = option_load("builder"), ): - """ - Build/export model - - Examples: - Build docker image from model - $ mlem build mymodel docker -c server.type=fastapi -c image.name=myimage - - Create build docker_dir declaration and build it - $ mlem declare builder docker_dir -c server=fastapi -c target=build build_dock - $ mlem build mymodel --load build_dock - """ from mlem.api.commands import build build( - config_arg(MlemBuilder, load, subtype, conf, file_conf), + config_arg( + MlemBuilder, + load, + None, + conf=None, + file_conf=None, + ), load_meta(model, project, rev, force_type=MlemModel), ) + + +@for_each_impl(MlemBuilder) +def create_build_command(type_name): + @mlem_command( + type_name, + section="builders", + parent=build, + dynamic_metavar="__kwargs__", + dynamic_options_generator=abc_fields_parameters( + type_name, MlemBuilder + ), + hidden=type_name.startswith("_"), + lazy_help=lazy_class_docstring(MlemBuilder.abs_name, type_name), + no_pass_from_parent=["file_conf"], + ) + def build_type( + model: str = Option(..., "-m", "--model", help="Path to model"), + project: Optional[str] = option_project, + rev: Optional[str] = option_rev, + file_conf: List[str] = option_file_conf("builder"), + **__kwargs__ + ): + from mlem.api.commands import build + + build( + config_arg( + MlemBuilder, + None, + type_name, + conf=None, + file_conf=file_conf, + **__kwargs__ + ), + load_meta(model, project, rev, force_type=MlemModel), + ) diff --git a/mlem/cli/clone.py b/mlem/cli/clone.py index 96f8e3cc..3ae9b1a2 100644 --- a/mlem/cli/clone.py 
+++ b/mlem/cli/clone.py @@ -22,14 +22,15 @@ def clone( external: Optional[bool] = option_external, index: Optional[bool] = option_index, ): - """Download MLEM object from `uri` and save it to `target` + """Copy a MLEM Object from `uri` and + saves a copy of it to `target` path. - Examples: - Copy remote model to local directory - $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel + Examples: + Copy remote model to local directory + $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel - Copy remote model to remote MLEM project - $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel --tp s3://mybucket/mymodel + Copy remote model to remote MLEM project + $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel --tp s3://mybucket/mymodel """ from mlem.api.commands import clone diff --git a/mlem/cli/config.py b/mlem/cli/config.py index 50ac9002..a4f04ebe 100644 --- a/mlem/cli/config.py +++ b/mlem/cli/config.py @@ -7,7 +7,7 @@ from mlem.cli.main import app, mlem_command, mlem_group, option_project from mlem.config import CONFIG_FILE_NAME, get_config_cls from mlem.constants import MLEM_DIR -from mlem.core.base import get_recursively, set_recursively, smart_split +from mlem.core.base import SmartSplitDict, get_recursively, smart_split from mlem.core.errors import MlemError from mlem.core.meta_io import get_fs, get_uri from mlem.ui import EMOJI_OK, echo @@ -45,8 +45,9 @@ def config_set( with fs.open(posixpath.join(project, MLEM_DIR, CONFIG_FILE_NAME)) as f: new_conf = safe_load(f) or {} - new_conf[section] = new_conf.get(section, {}) - set_recursively(new_conf[section], smart_split(name, "."), value) + conf = SmartSplitDict(new_conf.get(section, {})) + conf[name] = value + new_conf[section] = conf.build() if validate: config_cls = get_config_cls(section) config_cls(**new_conf[section]) diff --git 
a/mlem/cli/declare.py b/mlem/cli/declare.py index d840fb22..ef47b819 100644 --- a/mlem/cli/declare.py +++ b/mlem/cli/declare.py @@ -1,40 +1,158 @@ -from typing import List, Optional +from typing import Type -from typer import Argument, Option +from typer import Argument, Typer +from yaml import safe_dump -from ..core.base import build_mlem_object +from ..core.base import MlemABC, build_mlem_object, load_impl_ext +from ..core.meta_io import Location from ..core.objects import MlemObject +from ..utils.entrypoints import list_abstractions, list_implementations from .main import ( + app, mlem_command, + mlem_group, option_external, option_index, option_project, +) +from .utils import ( + abc_fields_parameters, + lazy_class_docstring, wrap_build_error, ) - -@mlem_command("declare", section="object") -def declare( - object_type: str = Argument(..., help="Type of metafile to create"), - subtype: str = Argument("", help="Subtype of MLEM object"), - conf: Optional[List[str]] = Option( - None, - "-c", - "--conf", - help="Values for object fields in format `field.nested.name=value`", - ), - path: str = Argument(..., help="Where to save object"), - project: str = option_project, - external: bool = option_external, - index: bool = option_index, -): - """Creates new mlem object metafile from conf args and config files +declare = Typer( + name="declare", + help="""Declares a new MLEM Object metafile from config args and config files. 
Examples: Create heroku deployment - $ mlem declare env heroku production -c api_key=<...> - """ - cls = MlemObject.__type_map__[object_type] - with wrap_build_error(subtype, cls): - meta = build_mlem_object(cls, subtype, conf, []) - meta.dump(path, project=project, index=index, external=external) + $ mlem declare env heroku production --api_key <...> + """, + cls=mlem_group("object"), + subcommand_metavar="subtype", +) +app.add_typer(declare) + + +def create_declare_mlem_object(type_name, cls: Type[MlemObject]): + if cls.__is_root__: + typer = Typer( + name=type_name, help=cls.__doc__, cls=mlem_group("Mlem Objects") + ) + declare.add_typer(typer) + + for subtype in list_implementations(MlemObject, cls): + create_declare_mlem_object_subcommand( + typer, subtype, type_name, cls + ) + + +def create_declare_mlem_object_subcommand( + parent: Typer, subtype: str, type_name: str, parent_cls +): + @mlem_command( + subtype, + section="Mlem Objects", + parent=parent, + dynamic_metavar="__kwargs__", + dynamic_options_generator=abc_fields_parameters(subtype, parent_cls), + hidden=subtype.startswith("_"), + lazy_help=lazy_class_docstring(type_name, subtype), + ) + def subtype_command( + path: str = Argument(..., help="Where to save object"), + project: str = option_project, + external: bool = option_external, + index: bool = option_index, + **__kwargs__, + ): + subtype_cls = load_impl_ext(type_name, subtype) + cls = subtype_cls.__type_map__[subtype] + with wrap_build_error(subtype, cls): + meta = build_mlem_object( + cls, subtype, str_conf=None, file_conf=[], **__kwargs__ + ) + meta.dump(path, project=project, index=index, external=external) + + +for meta_type in list_implementations(MlemObject): + create_declare_mlem_object(meta_type, MlemObject.__type_map__[meta_type]) + + +def create_declare_mlem_abc(abs_name: str): + try: + root_cls = MlemABC.abs_types[abs_name] + except KeyError: + root_cls = None + + typer = Typer( + name=abs_name, + help=root_cls.__doc__ + if root_cls + 
else f"Create `{abs_name}` configuration", + cls=mlem_group("Subtypes"), + ) + declare.add_typer(typer) + + for subtype in list_implementations(abs_name): + if root_cls is None: + try: + impl = load_impl_ext(abs_name, subtype) + root_cls = impl.__parent__ # type: ignore[assignment] + except ImportError: + pass + create_declare_mlem_abc_subcommand(typer, subtype, abs_name, root_cls) + + +def create_declare_mlem_abc_subcommand( + parent: Typer, subtype: str, abs_name: str, root_cls +): + @mlem_command( + subtype, + section="Subtypes", + parent=parent, + dynamic_metavar="__kwargs__", + dynamic_options_generator=abc_fields_parameters(subtype, root_cls) + if root_cls + else None, + hidden=subtype.startswith("_"), + lazy_help=lazy_class_docstring(abs_name, subtype), + ) + def subtype_command( + path: str = Argument(..., help="Where to save object"), + project: str = option_project, + **__kwargs__, + ): + with wrap_build_error(subtype, root_cls): + obj = build_mlem_object( + root_cls, subtype, str_conf=None, file_conf=[], **__kwargs__ + ) + location = Location.resolve( + path=path, project=project, rev=None, fs=None + ) + with location.fs.open(location.fullpath, "w") as f: + safe_dump(obj.dict(), f) + + +_internal = { + "artifact", + "data_reader", + "data_type", + "data_writer", + "deploy_state", + "import", + "interface", + "meta", + "model_io", + "model_type", + "requirement", + "resolver", + "storage", +} +for abs_name in list_abstractions(include_hidden=False): + if abs_name in {"builder", "env", "deployment"}: + continue + if abs_name in _internal: + continue + create_declare_mlem_abc(abs_name) diff --git a/mlem/cli/deployment.py b/mlem/cli/deployment.py index 9ce50d1a..3547cab2 100644 --- a/mlem/cli/deployment.py +++ b/mlem/cli/deployment.py @@ -8,6 +8,7 @@ app, mlem_command, mlem_group, + option_conf, option_data_project, option_data_rev, option_external, @@ -27,7 +28,7 @@ deployment = Typer( name="deployment", - help="Manage deployments", + help="A set of 
commands to set up and manage deployments.", cls=mlem_group("runtime", aliases=["deploy"]), ) app.add_typer(deployment) @@ -46,24 +47,20 @@ def deploy_run( project: Optional[str] = option_project, external: bool = option_external, index: bool = option_index, - conf: Optional[List[str]] = Option( - None, - "-c", - "--conf", - help="Configuration for new deployment meta if it does not exist", - ), + conf: Optional[List[str]] = option_conf(), ): - """Deploy a model to target environment. Can use existing deployment declaration or create a new one on-the-fly - - Examples: - Create new deployment - $ mlem declare env heroku staging -c api_key=... - $ mlem deploy run service_name -m model -t staging -c name=my_service - - Deploy existing meta - $ mlem declare env heroku staging -c api_key=... - $ mlem declare deployment heroku service_name -c app_name=my_service -c model=model -c env=staging - $ mlem deploy run service_name + """Deploy a model to a target environment. Can use an existing deployment + declaration or create a new one on-the-fly. + + Examples: + Create new deployment + $ mlem declare env heroku staging -c api_key=... + $ mlem deploy run service_name -m model -t staging -c name=my_service + + Deploy existing meta + $ mlem declare env heroku staging -c api_key=... + $ mlem declare deployment heroku service_name -c app_name=my_service -c model=model -c env=staging + $ mlem deploy run service_name """ from mlem.api.commands import deploy @@ -83,7 +80,7 @@ def deploy_remove( path: str = Argument(..., help="Path to deployment meta"), project: Optional[str] = option_project, ): - """Stop and destroy deployed instance + """Stop and destroy deployed instance. Examples: $ mlem deployment remove service_name @@ -97,7 +94,7 @@ def deploy_status( path: str = Argument(..., help="Path to deployment meta"), project: Optional[str] = option_project, ): - """Print status of deployed service + """Print status of deployed service. 
Examples: $ mlem deployment status service_name @@ -126,7 +123,7 @@ def deploy_apply( index: bool = option_index, json: bool = option_json, ): - """Apply method of deployed service + """Apply a deployed model to data. Examples: $ mlem deployment apply service_name diff --git a/mlem/cli/import_object.py b/mlem/cli/import_object.py index 782bfa84..7eae2def 100644 --- a/mlem/cli/import_object.py +++ b/mlem/cli/import_object.py @@ -29,7 +29,7 @@ def import_object( index: bool = option_index, external: bool = option_external, ): - """Create MLEM model or data metadata from file/dir + """Create a `.mlem` metafile for a model or data in any file or directory. Examples: Create MLEM data from local csv diff --git a/mlem/cli/info.py b/mlem/cli/info.py index cef6ac6f..9aea8315 100644 --- a/mlem/cli/info.py +++ b/mlem/cli/info.py @@ -4,13 +4,8 @@ from typer import Argument, Option -from mlem.cli.main import ( - Choices, - mlem_command, - option_json, - option_project, - option_rev, -) +from mlem.cli.main import mlem_command, option_json, option_project, option_rev +from mlem.cli.utils import Choices from mlem.core.metadata import load_meta from mlem.core.objects import MLEM_EXT, MlemLink, MlemObject from mlem.ui import echo, set_echo @@ -57,14 +52,15 @@ def ls( ), rev: Optional[str] = option_rev, links: bool = Option( - True, "+l/-l", "--links/--no-links", help="Include links" + True, "+l/-l", "--links/--no-links", help="Whether to include links" ), json: bool = option_json, ignore_errors: bool = Option( False, "-i", "--ignore-errors", help="Ignore corrupted objects" ), ): - """List MLEM objects of in project + """List MLEM objects inside a MLEM project (location should be [initialized](/doc/command-reference/init)). + Examples: $ mlem list https://github.com/iterative/example-mlem @@ -114,14 +110,15 @@ def pretty_print( ), json: bool = option_json, ): - """Print specified MLEM object + """Display all details about a specific MLEM Object from an existing MLEM + project. 
- Examples: - Print local object - $ mlem pprint mymodel + Examples: + Print local object + $ mlem pprint mymodel - Print remote object - $ mlem pprint https://github.com/iterative/example-mlem/models/logreg + Print remote object + $ mlem pprint https://github.com/iterative/example-mlem/models/logreg """ with set_echo(None if json else ...): meta = load_meta( diff --git a/mlem/cli/init.py b/mlem/cli/init.py index 1f881e38..8160f21b 100644 --- a/mlem/cli/init.py +++ b/mlem/cli/init.py @@ -7,7 +7,7 @@ def init( path: str = Argument(".", help="Where to init project", show_default=False) ): - """Initialize MLEM project + """Initialize a MLEM project. Examples: $ mlem init diff --git a/mlem/cli/link.py b/mlem/cli/link.py index 2cdf7c35..691fa70c 100644 --- a/mlem/cli/link.py +++ b/mlem/cli/link.py @@ -12,7 +12,9 @@ @mlem_command("link", section="object") def link( - source: str = Argument(..., help="URI to object you are crating link to"), + source: str = Argument( + ..., help="URI of the object you are creating a link to" + ), target: str = Argument(..., help="Path to save link object"), source_project: Optional[str] = Option( None, @@ -36,14 +38,15 @@ def link( help="Which path to linked object to specify: absolute or relative.", ), ): - """Create link for MLEM object + """Create a link (read alias) for an existing MLEM Object, including from + remote MLEM projects. 
- Examples: - Add alias to local object - $ mlem link my_model latest + Examples: + Add alias to local object + $ mlem link my_model latest - Add remote object to your project without copy - $ mlem link models/logreg --source-project https://github.com/iteartive/example-mlem remote_model + Add remote object to your project without copy + $ mlem link models/logreg --source-project https://github.com/iteartive/example-mlem remote_model """ from mlem.api.commands import link diff --git a/mlem/cli/main.py b/mlem/cli/main.py index d4ab0c3b..ebba1d01 100644 --- a/mlem/cli/main.py +++ b/mlem/cli/main.py @@ -1,27 +1,39 @@ -import contextlib +import inspect import logging -import typing as t from collections import defaultdict -from enum import Enum, EnumMeta from functools import partial, wraps from gettext import gettext -from typing import List, Optional, Tuple, Type +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Union, +) +import click import typer -from click import Abort, ClickException, Command, HelpFormatter, pass_context -from click.exceptions import Exit -from pydantic import BaseModel, MissingError, ValidationError, parse_obj_as -from pydantic.error_wrappers import ErrorWrapper +from click import Abort, ClickException, Command, HelpFormatter, Parameter +from click.exceptions import Exit, MissingParameter +from pydantic import ValidationError from typer import Context, Option, Typer from typer.core import TyperCommand, TyperGroup -from yaml import safe_load from mlem import LOCAL_CONFIG, version +from mlem.cli.utils import ( + FILE_CONF_PARAM_NAME, + LOAD_PARAM_NAME, + NOT_SET, + CallContext, + _extract_examples, + _format_validation_error, + get_extra_keys, +) from mlem.constants import MLEM_DIR, PREDICT_METHOD_NAME -from mlem.core.base import MlemABC, build_mlem_object from mlem.core.errors import MlemError -from mlem.core.metadata import load_meta -from mlem.core.objects import MlemObject from mlem.telemetry 
import telemetry from mlem.ui import EMOJI_FAIL, EMOJI_MLEM, bold, cli_echo, color, echo @@ -46,7 +58,7 @@ def __init__( self.aliases = aliases self.rich_help_panel = section.capitalize() - def collect_usage_pieces(self, ctx: Context) -> t.List[str]: + def collect_usage_pieces(self, ctx: Context) -> List[str]: return [p.lower() for p in super().collect_usage_pieces(ctx)] def get_help(self, ctx: Context) -> str: @@ -77,9 +89,20 @@ def __init__( section: str = "other", aliases: List[str] = None, help: Optional[str] = None, + dynamic_options_generator: Callable[ + [CallContext], Iterable[Parameter] + ] = None, + dynamic_metavar: str = None, + lazy_help: Optional[Callable[[], str]] = None, + pass_from_parent: Optional[List[str]] = None, **kwargs, ): + self.dynamic_metavar = dynamic_metavar + self.dynamic_options_generator = dynamic_options_generator examples, help = _extract_examples(help) + self._help = help + self.lazy_help = lazy_help + self.pass_from_parent = pass_from_parent super().__init__( name=name, section=section, @@ -89,20 +112,90 @@ def __init__( **kwargs, ) + def make_context( + self, + info_name: Optional[str], + args: List[str], + parent: Optional[Context] = None, + **extra: Any, + ) -> Context: + args_copy = args[:] + ctx = super().make_context(info_name, args, parent, **extra) + if not self.dynamic_options_generator: + return ctx + extra_args = ctx.args + params = ctx.params.copy() + while extra_args: + ctx.params = params + ctx.args = args_copy[:] + with ctx.scope(cleanup=False): + self.parse_args(ctx, args_copy[:]) + params.update(ctx.params) + + if ctx.args == extra_args: + break + extra_args = ctx.args + + return ctx + + def invoke(self, ctx: Context) -> Any: + ctx.params = {k: v for k, v in ctx.params.items() if v != NOT_SET} + return super().invoke(ctx) + + def get_params(self, ctx) -> List["Parameter"]: + regular_options = super().get_params(ctx) + res: List[Parameter] = ( + list( + self.dynamic_options_generator( + CallContext( + ctx.params, + 
get_extra_keys(ctx.args), + [o.name for o in regular_options], + ) + ) + ) + if self.dynamic_options_generator is not None + else [] + ) + regular_options + + if self.dynamic_metavar is not None: + kw_param = [p for p in res if p.name == self.dynamic_metavar] + if len(kw_param) > 0: + res.remove(kw_param[0]) + if self.pass_from_parent is not None: + res = [ + o + for o in res + if o.name not in self.pass_from_parent + or o.name not in ctx.parent.params + or ctx.parent.params[o.name] is None + ] + return res + + @property + def help(self): + if self.lazy_help: + return self.lazy_help() + return self._help + + @help.setter + def help(self, value): + self._help = value + class MlemGroup(MlemMixin, TyperGroup): order = ["common", "object", "runtime", "other"] def __init__( self, - name: t.Optional[str] = None, - commands: t.Optional[ - t.Union[t.Dict[str, Command], t.Sequence[Command]] + name: Optional[str] = None, + commands: Optional[ + Union[Dict[str, Command], Sequence[Command]] ] = None, section: str = "other", aliases: List[str] = None, help: str = None, - **attrs: t.Any, + **attrs: Any, ) -> None: examples, help = _extract_examples(help) super().__init__( @@ -152,7 +245,7 @@ def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None: ): formatter.write_dl(sections[section]) - def get_command(self, ctx: Context, cmd_name: str) -> t.Optional[Command]: + def get_command(self, ctx: Context, cmd_name: str) -> Optional[Command]: cmd = super().get_command(ctx, cmd_name) if cmd is not None: return cmd @@ -175,25 +268,29 @@ def __init__(self, *args, **kwargs): return MlemGroupSection -class ChoicesMeta(EnumMeta): - def __call__(cls, *names, module=None, qualname=None, type=None, start=1): - if len(names) == 1: - return super().__call__(names[0]) - return super().__call__( - "Choice", - names, - module=module, - qualname=qualname, - type=type, - start=start, +def mlem_group_callback(group: Typer, required: Optional[List[str]] = None): + def decorator(f): + 
@wraps(f) + def inner(*args, **kwargs): + ctx = click.get_current_context() + if ctx.invoked_subcommand is not None: + return None + if required is not None: + for req in required: + if req not in kwargs or kwargs[req] is None: + param = [ + p + for p in ctx.command.get_params(ctx) + if p.name == req + ][0] + raise MissingParameter(ctx=ctx, param=param) + return f(*args, **kwargs) + + return group.callback(invoke_without_command=True)( + wrap_mlem_cli_call(inner, None) ) - -class Choices(str, Enum, metaclass=ChoicesMeta): - def _generate_next_value_( # pylint: disable=no-self-argument - name, start, count, last_values - ): - return name + return decorator app = Typer( @@ -241,16 +338,12 @@ def mlem_callback( ctx.obj = {"traceback": traceback or LOCAL_CONFIG.DEBUG} -def _extract_examples( - help_str: Optional[str], -) -> Tuple[Optional[str], Optional[str]]: - if help_str is None: - return None, None - try: - examples = help_str.index("Examples:") - except ValueError: - return None, help_str - return help_str[examples + len("Examples:") + 1 :], help_str[:examples] +def get_cmd_name(ctx: Context): + pieces = [] + while ctx.parent is not None: + pieces.append(ctx.info_name) + ctx = ctx.parent + return " ".join(reversed(pieces)) def mlem_command( @@ -259,68 +352,106 @@ def mlem_command( aliases=None, options_metavar="[options]", parent=app, + mlem_cls=None, + dynamic_metavar=None, + dynamic_options_generator=None, + lazy_help=None, + pass_from_parent: Optional[List[str]] = None, + no_pass_from_parent: Optional[List[str]] = None, **kwargs, ): def decorator(f): - if len(args) > 0: - cmd_name = args[0] + context_settings = kwargs.get("context_settings", {}) + if dynamic_options_generator: + context_settings.update( + {"allow_extra_args": True, "ignore_unknown_options": True} + ) + if no_pass_from_parent is not None: + _pass_from_parent = [ + a + for a in inspect.getfullargspec(f).args + if a not in no_pass_from_parent + ] else: - cmd_name = kwargs.get("name", f.__name__) - 
- @parent.command( + _pass_from_parent = pass_from_parent + call = wrap_mlem_cli_call(f, _pass_from_parent) + return parent.command( *args, options_metavar=options_metavar, + context_settings=context_settings, **kwargs, - cls=partial(MlemCommand, section=section, aliases=aliases), - ) - @wraps(f) - @pass_context - def inner(ctx, *iargs, **ikwargs): - res = {} - error = None - try: - with cli_echo(): - res = f(*iargs, **ikwargs) or {} - res = {f"cmd_{cmd_name}_{k}": v for k, v in res.items()} - except (ClickException, Exit, Abort) as e: - error = f"{e.__class__.__module__}.{e.__class__.__name__}" + cls=partial( + mlem_cls or MlemCommand, + section=section, + aliases=aliases, + dynamic_options_generator=dynamic_options_generator, + dynamic_metavar=dynamic_metavar, + lazy_help=lazy_help, + pass_from_parent=pass_from_parent, + ), + )(call) + + return decorator + + +def wrap_mlem_cli_call(f, pass_from_parent: Optional[List[str]]): + @wraps(f) + def inner(*iargs, **ikwargs): + res = {} + error = None + ctx = click.get_current_context() + cmd_name = get_cmd_name(ctx) + try: + if pass_from_parent is not None: + ikwargs.update( + { + o: ctx.parent.params[o] + for o in pass_from_parent + if o in ctx.parent.params + and (o not in ikwargs or ikwargs[o] is None) + } + ) + with cli_echo(): + res = f(*iargs, **ikwargs) or {} + res = {f"cmd_{cmd_name}_{k}": v for k, v in res.items()} + except (ClickException, Exit, Abort) as e: + error = f"{e.__class__.__module__}.{e.__class__.__name__}" + raise + except MlemError as e: + error = f"{e.__class__.__module__}.{e.__class__.__name__}" + if ctx.obj["traceback"]: raise - except MlemError as e: - error = f"{e.__class__.__module__}.{e.__class__.__name__}" - if ctx.obj["traceback"]: - raise - with cli_echo(): - echo(EMOJI_FAIL + color(str(e), col=typer.colors.RED)) - raise typer.Exit(1) - except ValidationError as e: - error = f"{e.__class__.__module__}.{e.__class__.__name__}" - if ctx.obj["traceback"]: - raise - msgs = 
"\n".join(_format_validation_error(e)) - with cli_echo(): - echo(EMOJI_FAIL + color("Error:\n", "red") + msgs) - raise typer.Exit(1) - except Exception as e: # pylint: disable=broad-except - error = f"{e.__class__.__module__}.{e.__class__.__name__}" - if ctx.obj["traceback"]: - raise - with cli_echo(): - echo( - EMOJI_FAIL - + color( - "Unexpected error: " + str(e), col=typer.colors.RED - ) - ) - echo( - "Please report it here: " + with cli_echo(): + echo(EMOJI_FAIL + color(str(e), col=typer.colors.RED)) + raise typer.Exit(1) + except ValidationError as e: + error = f"{e.__class__.__module__}.{e.__class__.__name__}" + if ctx.obj["traceback"]: + raise + msgs = "\n".join(_format_validation_error(e)) + with cli_echo(): + echo(EMOJI_FAIL + color("Error:\n", "red") + msgs) + raise typer.Exit(1) + except Exception as e: # pylint: disable=broad-except + error = f"{e.__class__.__module__}.{e.__class__.__name__}" + if ctx.obj["traceback"]: + raise + with cli_echo(): + echo( + EMOJI_FAIL + + color( + "Unexpected error: " + str(e), col=typer.colors.RED ) - raise typer.Exit(1) - finally: + ) + echo( + "Please report it here: " + ) + raise typer.Exit(1) + finally: + if error is not None or ctx.invoked_subcommand is None: telemetry.send_cli_call(cmd_name, error=error, **res) - return inner - - return decorator + return inner option_project = Option( @@ -369,7 +500,10 @@ def inner(ctx, *iargs, **ikwargs): def option_load(type_: str = None): type_ = type_ + " " if type_ is not None else "" return Option( - None, "-l", "--load", help=f"File to load {type_}config from" + None, + "-l", + f"--{LOAD_PARAM_NAME}", + help=f"File to load {type_}config from", ) @@ -388,91 +522,6 @@ def option_file_conf(type_: str = None): return Option( None, "-f", - "--file_conf", + f"--{FILE_CONF_PARAM_NAME}", help=f"File with options {type_}in format `field.name=path_to_config`", ) - - -def _iter_errors( - errors: t.Sequence[t.Any], model: Type, loc: Optional[Tuple] = None -): - for error in errors: - 
if isinstance(error, ErrorWrapper): - - if loc: - error_loc = loc + error.loc_tuple() - else: - error_loc = error.loc_tuple() - - if isinstance(error.exc, ValidationError): - yield from _iter_errors( - error.exc.raw_errors, error.exc.model, error_loc - ) - else: - yield error_loc, model, error.exc - - -def _format_validation_error(error: ValidationError) -> List[str]: - res = [] - for loc, model, exc in _iter_errors(error.raw_errors, error.model): - path = ".".join(loc_part for loc_part in loc if loc_part != "__root__") - field_name = loc[-1] - if field_name not in model.__fields__: - res.append( - f"Unknown field '{field_name}'. Fields available: {', '.join(model.__fields__)}" - ) - continue - field_type = model.__fields__[field_name].type_ - if ( - isinstance(exc, MissingError) - and isinstance(field_type, type) - and issubclass(field_type, BaseModel) - ): - msgs = [ - str(EMOJI_FAIL + f"field `{path}.{f.name}`: {exc}") - for f in field_type.__fields__.values() - if f.required - ] - if msgs: - res.extend(msgs) - else: - res.append(str(EMOJI_FAIL + f"field `{path}`: {exc}")) - else: - res.append(str(EMOJI_FAIL + f"field `{path}`: {exc}")) - return res - - -@contextlib.contextmanager -def wrap_build_error(subtype, model: Type[MlemABC]): - try: - yield - except ValidationError as e: - msgs = "\n".join(_format_validation_error(e)) - raise typer.BadParameter( - f"Error on constructing {subtype} {model.abs_name}:\n{msgs}" - ) from e - - -def config_arg( - model: Type[MlemABC], - load: Optional[str], - subtype: str, - conf: Optional[List[str]], - file_conf: Optional[List[str]], -): - obj: MlemABC - if load is not None: - if issubclass(model, MlemObject): - obj = load_meta(load, force_type=model) - else: - with open(load, "r", encoding="utf8") as of: - obj = parse_obj_as(model, safe_load(of)) - else: - if not subtype: - raise typer.BadParameter( - f"Cannot configure {model.abs_name}: either subtype or --load should be provided" - ) - with wrap_build_error(subtype, 
model): - obj = build_mlem_object(model, subtype, conf, file_conf) - - return obj diff --git a/mlem/cli/serve.py b/mlem/cli/serve.py index 0ca81c51..5d6ff54e 100644 --- a/mlem/cli/serve.py +++ b/mlem/cli/serve.py @@ -1,42 +1,95 @@ from typing import List, Optional -from typer import Argument +from typer import Option, Typer from mlem.cli.main import ( - config_arg, + app, mlem_command, - option_conf, + mlem_group, + mlem_group_callback, option_file_conf, option_load, option_project, option_rev, ) +from mlem.cli.utils import ( + abc_fields_parameters, + config_arg, + for_each_impl, + lazy_class_docstring, +) from mlem.core.metadata import load_meta from mlem.core.objects import MlemModel from mlem.runtime.server import Server -from mlem.utils.entrypoints import list_implementations + +serve = Typer( + name="serve", + help="""Deploy the model locally using a server implementation and expose its methods as +endpoints. + + Examples: + $ mlem serve fastapi https://github.com/iterative/example-mlem/models/logreg + """, + cls=mlem_group("runtime"), + subcommand_metavar="server", +) +app.add_typer(serve) -@mlem_command("serve", section="runtime") -def serve( - model: str = Argument(..., help="Model to create service from"), - subtype: str = Argument( - "", help=f"Server type. 
Choices: {list_implementations(Server)}" +@mlem_group_callback(serve, required=["model", "load"]) +def serve_load( + model: str = Option( + None, "-m", "--model", help="Model to create service from" ), project: Optional[str] = option_project, rev: Optional[str] = option_rev, load: Optional[str] = option_load("server"), - conf: List[str] = option_conf("server"), - file_conf: List[str] = option_file_conf("server"), ): - """Serve selected model - - Examples: - $ mlem serve https://github.com/iterative/example-mlem/models/logreg fastapi - """ from mlem.api.commands import serve serve( load_meta(model, project, rev, force_type=MlemModel), - config_arg(Server, load, subtype, conf, file_conf), + config_arg( + Server, + load, + None, + conf=None, + file_conf=None, + ), + ) + + +@for_each_impl(Server) +def create_serve_command(type_name): + @mlem_command( + type_name, + section="servers", + parent=serve, + dynamic_metavar="__kwargs__", + dynamic_options_generator=abc_fields_parameters(type_name, Server), + hidden=type_name.startswith("_"), + lazy_help=lazy_class_docstring(Server.abs_name, type_name), + no_pass_from_parent=["file_conf"], ) + def serve_command( + model: str = Option( + ..., "-m", "--model", help="Model to create service from" + ), + project: Optional[str] = option_project, + rev: Optional[str] = option_rev, + file_conf: List[str] = option_file_conf("server"), + **__kwargs__ + ): + from mlem.api.commands import serve + + serve( + load_meta(model, project, rev, force_type=MlemModel), + config_arg( + Server, + None, + type_name, + conf=None, + file_conf=file_conf, + **__kwargs__ + ), + ) diff --git a/mlem/cli/types.py b/mlem/cli/types.py index c3d56381..af1f094e 100644 --- a/mlem/cli/types.py +++ b/mlem/cli/types.py @@ -1,59 +1,64 @@ -from typing import Optional, Type +from typing import Iterator, Optional, Type from pydantic import BaseModel from typer import Argument from mlem.cli.main import mlem_command +from mlem.cli.utils import CliTypeField, 
iterate_type_fields, parse_type_field from mlem.core.base import MlemABC, load_impl_ext +from mlem.core.errors import MlemError from mlem.core.objects import MlemObject from mlem.ui import EMOJI_BASE, bold, color, echo -from mlem.utils.entrypoints import list_implementations +from mlem.utils.entrypoints import list_abstractions, list_implementations -def explain_type(cls: Type[BaseModel], prefix="", force_not_req=False): - for name, field in sorted( - cls.__fields__.items(), key=lambda x: not x[1].required - ): - if issubclass(cls, MlemObject) and name in MlemObject.__fields__: - continue - if issubclass(cls, MlemABC) and name in cls.__config__.exclude: - continue - fullname = name if not prefix else f"{prefix}.{name}" - module = field.type_.__module__ - type_name = getattr(field.type_, "__name__", str(field.type_)) - if module != "builtins" and "." not in type_name: - type_name = f"{module}.{type_name}" - type_name = color(type_name, "yellow") - - if field.required and not force_not_req: - req = color("[required] ", "grey") - else: - req = color("[not required] ", "white") - if not field.required: - default = field.default - if isinstance(default, str): - default = f'"{default}"' - default = f" = {default}" - else: - default = "" - if ( - isinstance(field.type_, type) - and issubclass(field.type_, MlemABC) - and field.type_.__is_root__ - ): - echo( - req - + color(fullname, "green") - + ": One of " - + color(f"mlem types {field.type_.abs_name}", "yellow") +def _add_examples(generator: Iterator[CliTypeField], parent_help=None): + for field in generator: + field.help = parent_help or field.help + yield field + if field.is_list or field.is_mapping: + key = ".key" if field.is_mapping else ".0" + yield from _add_examples( + parse_type_field( + path=field.path + key, + type_=field.type_, + help_=field.help, + is_list=False, + is_mapping=False, + required=False, + allow_none=False, + default=None, + ), + parent_help=f"Element of {field.path}", ) - elif 
isinstance(field.type_, type) and issubclass( - field.type_, BaseModel - ): - echo(req + color(fullname, "green") + ": " + type_name) - explain_type(field.type_, fullname, not field.required) - else: - echo(req + color(fullname, "green") + ": " + type_name + default) + + +def type_fields_with_collection_examples(cls): + yield from _add_examples(iterate_type_fields(cls)) + + +def explain_type(cls: Type[BaseModel]): + echo( + color("Type ", "") + + color(cls.__module__ + ".", "yellow") + + color(cls.__name__, "green") + ) + if issubclass(cls, MlemABC): + echo(color("MlemABC parent type: ", "") + color(cls.abs_name, "green")) + echo(color("MlemABC type: ", "") + color(cls.__get_alias__(), "green")) + if issubclass(cls, MlemObject): + echo( + color("MlemObject type name: ", "") + + color(cls.object_type, "green") + ) + echo((cls.__doc__ or "Class docstring missing").strip()) + fields = list(type_fields_with_collection_examples(cls)) + if not fields: + echo("No fields") + else: + echo("Fields:") + for field in fields: + echo(field.to_text()) @mlem_command("types", hidden=True) @@ -64,34 +69,50 @@ def list_types( ), sub_type: Optional[str] = Argument(None, help="Type of `meta` subtype"), ): - """List MLEM types implementations available in current env. - If subtype is not provided, list ABCs + """List different implementations available for a particular MLEM type. If a + subtype is not provided, list all available MLEM types. 
- Examples: - List ABCs - $ mlem types + Examples: + List ABCs + $ mlem types - List available server implementations - $ mlem types server + List available server implementations + $ mlem types server """ if abc is None: for at in MlemABC.abs_types.values(): echo(EMOJI_BASE + bold(at.abs_name) + ":") echo( - f"\tBase class: {at.__module__}.{at.__name__}\n\t{at.__doc__.strip()}" + f"\tBase class: {at.__module__}.{at.__name__}\n\t{(at.__doc__ or 'Class docstring missing').strip()}" ) elif abc == MlemObject.abs_name: if sub_type is None: - echo(list(MlemObject.non_abstract_subtypes().keys())) + echo("\n".join(MlemObject.non_abstract_subtypes().keys())) else: - echo( - list_implementations( - MlemObject, MlemObject.non_abstract_subtypes()[sub_type] + mlem_object_type = MlemObject.non_abstract_subtypes()[sub_type] + if mlem_object_type.__is_root__: + echo( + "\n".join( + list_implementations( + MlemObject, mlem_object_type, include_hidden=False + ) + ) ) - ) + else: + explain_type(mlem_object_type) else: if sub_type is None: - echo(list_implementations(abc)) + abcs = list_abstractions(include_hidden=False) + if abc not in abcs: + raise MlemError( + f"Unknown abc \"{abc}\". Known abcs: {' '.join(abcs)}" + ) + echo("\n".join(list_implementations(abc, include_hidden=False))) else: - cls = load_impl_ext(abc, sub_type, True) + try: + cls = load_impl_ext(abc, sub_type, True) + except ValueError as e: + raise MlemError( + f"Unknown implementation \"{sub_type}\" of abc \"{abc}\". 
Known implementations: {' '.join(list_implementations(abc, include_hidden=False))}" + ) from e explain_type(cls) diff --git a/mlem/cli/utils.py b/mlem/cli/utils.py new file mode 100644 index 00000000..f06d28e6 --- /dev/null +++ b/mlem/cli/utils.py @@ -0,0 +1,621 @@ +import ast +import contextlib +import inspect +from dataclasses import dataclass +from enum import Enum, EnumMeta +from functools import lru_cache +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type + +import typer +from click import Context, MissingParameter +from pydantic import ( + BaseModel, + MissingError, + ValidationError, + create_model, + parse_obj_as, +) +from pydantic.error_wrappers import ErrorWrapper +from pydantic.fields import ( + MAPPING_LIKE_SHAPES, + SHAPE_LIST, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_TUPLE, + SHAPE_TUPLE_ELLIPSIS, + ModelField, +) +from pydantic.typing import display_as_type, get_args, is_union +from typer.core import TyperOption +from typing_extensions import get_origin +from yaml import safe_load + +from mlem import LOCAL_CONFIG +from mlem.core.base import ( + MlemABC, + build_mlem_object, + load_impl_ext, + smart_split, +) +from mlem.core.errors import ExtensionRequirementError, MlemObjectNotFound +from mlem.core.meta_io import Location +from mlem.core.metadata import load_meta +from mlem.core.objects import MlemObject +from mlem.ui import EMOJI_FAIL, color +from mlem.utils.entrypoints import list_implementations +from mlem.utils.module import lstrip_lines + +LIST_LIKE_SHAPES = ( + SHAPE_LIST, + SHAPE_TUPLE, + SHAPE_SET, + SHAPE_TUPLE_ELLIPSIS, + SHAPE_SEQUENCE, +) + + +class ChoicesMeta(EnumMeta): + def __call__(cls, *names, module=None, qualname=None, type=None, start=1): + if len(names) == 1: + return super().__call__(names[0]) + return super().__call__( + "Choice", + names, + module=module, + qualname=qualname, + type=type, + start=start, + ) + + +class Choices(str, Enum, metaclass=ChoicesMeta): + def _generate_next_value_( # pylint: 
disable=no-self-argument + name, start, count, last_values + ): + return name + + +class CliTypeField(BaseModel): + """A descriptor of model field to build cli option""" + + path: str + """a.dotted.path from schema root""" + required: bool + allow_none: bool + type_: Type + help: str + default: Any + is_list: bool + is_mapping: bool + mapping_key_type: Optional[Type] + + @property + def type_repr(self): + type_name = self.type_.__name__ + if self.is_list: + return f"List[{type_name}]" + if self.is_mapping: + return f"Dict[{self.mapping_key_type.__name__}, {type_name}]" + return type_name + + def to_text(self): + req = ( + color("[required]", "") + if self.required + else color("[not required]", "white") + ) + if not self.required: + default = self.default + if isinstance(default, str): + default = f'"{default}"' + default = f" = {default}" + else: + default = "" + return ( + req + + " " + + color(self.path, "green") + + ": " + + self.type_repr + + default + + "\n\t" + + self.help + ) + + +@lru_cache() +def get_attribute_docstrings(cls) -> Dict[str, str]: + """Parses cls source to find all classfields followed by docstring expr""" + res = {} + tree = ast.parse(lstrip_lines(inspect.getsource(cls))) + class_def = tree.body[0] + assert isinstance(class_def, ast.ClassDef) + field: Optional[str] = None + for statement in class_def.body: + if isinstance(statement, ast.AnnAssign) and isinstance( + statement.target, ast.Name + ): + field = statement.target.id + continue + if ( + isinstance(statement, ast.Assign) + and len(statement.targets) == 1 + and isinstance(statement.targets[0], ast.Name) + ): + field = statement.targets[0].id + continue + if field is not None and isinstance(statement, ast.Expr): + if isinstance(statement.value, ast.Constant) and isinstance( + statement.value.value, str + ): + res[field] = statement.value.value + if isinstance(statement.value, ast.Str): + res[field] = statement.value.s + field = None + return res + + +@lru_cache() +def 
get_field_help(cls: Type, field_name: str): + """Parses all class mro to find classfield docstring""" + for base_cls in cls.mro(): + if base_cls is object: + continue + try: + docsting = get_attribute_docstrings(base_cls).get(field_name) + if docsting: + return docsting + except OSError: + pass + return "Field docstring missing" + + +def _get_type_name_alias(type_): + if not isinstance(type_, type): + type_ = get_origin(type_) + return type_.__name__ if type_ is not None else "any" + + +def anything(type_): + """Creates special type that is named as original type or collection type + It returns original object on creation and is needed for nice typename in cli option help""" + return type( + _get_type_name_alias(type_), (), {"__new__": lambda cls, value: value} + ) + + +def optional(type_): + """Creates special type that is named as original type or collection type + It allows use string `None` to indicate None value""" + return type( + _get_type_name_alias(type_), + (), + { + "__new__": lambda cls, value: None + if value == "None" + else type_(value) + }, + ) + + +def parse_type_field( + path: str, + type_: Type, + help_: str, + is_list: bool, + is_mapping: bool, + required: bool, + allow_none: bool, + default: Any, +) -> Iterator[CliTypeField]: + """Recursively creates CliTypeFields from field description""" + if is_list or is_mapping: + # collection + yield CliTypeField( + required=required, + allow_none=allow_none, + path=path, + type_=type_, + default=default, + help=help_, + is_list=is_list, + is_mapping=is_mapping, + mapping_key_type=str, + ) + return + + if ( + isinstance(type_, type) + and issubclass(type_, MlemABC) + and type_.__is_root__ + ): + # mlem abstraction: substitute default and extend help + if isinstance(default, type_): + default = default.__get_alias__() + yield CliTypeField( + required=required, + allow_none=allow_none, + path=path, + type_=type_, + help=f"{help_}. One of {list_implementations(type_, include_hidden=False)}. 
Run 'mlem types {type_.abs_name} ' for list of nested fields for each subtype", + default=default, + is_list=is_list, + is_mapping=is_mapping, + mapping_key_type=str, + ) + return + if isinstance(type_, type) and issubclass(type_, BaseModel): + # BaseModel (including MlemABC non-root classes): reqursively get nested + yield from iterate_type_fields(type_, path, not required) + return + # probably primitive field + yield CliTypeField( + required=required, + allow_none=allow_none, + path=path, + type_=type_, + default=default, + help=help_, + is_list=is_list, + is_mapping=is_mapping, + mapping_key_type=str, + ) + + +def iterate_type_fields( + cls: Type[BaseModel], path: str = "", force_not_req: bool = False +) -> Iterator[CliTypeField]: + """Recursively get CliTypeFields from BaseModel""" + field: ModelField + for name, field in sorted( + cls.__fields__.items(), key=lambda x: not x[1].required + ): + name = field.alias or name + if issubclass(cls, MlemObject) and name in MlemObject.__fields__: + # Skip base MlemObject fields + continue + if ( + issubclass(cls, MlemABC) + and name in cls.__config__.exclude + or field.field_info.exclude + ): + # Skip excluded fields + continue + if name == "__root__": + fullname = path + else: + fullname = name if not path else f"{path}.{name}" + + field_type = field.type_ + # field.type_ is element type for collections/mappings + + if not isinstance(field_type, type): + # Handle generics. Probably will break in complex cases + origin = get_origin(field_type) + if is_union(origin): + # get first type for union + generic_args = get_args(field_type) + field_type = generic_args[0] + if origin is list or origin is dict: + # replace with dynamic __root__: Dict/List model + field_type = create_model( + display_as_type(field_type), __root__=(field_type, ...) 
+ ) + if field_type is Any: + field_type = anything(field_type) + + if not isinstance(field_type, type): + # skip too complicated stuff + continue + + yield from parse_type_field( + path=fullname, + type_=field_type, + help_=get_field_help(cls, name), + is_list=field.shape in LIST_LIKE_SHAPES, + is_mapping=field.shape in MAPPING_LIKE_SHAPES, + required=not force_not_req and bool(field.required), + allow_none=field.allow_none, + default=field.default, + ) + + +@dataclass +class CallContext: + params: Dict[str, Any] + extra_keys: List[str] + regular_options: List[str] + + +def _options_from_model( + cls: Type[BaseModel], + ctx: CallContext, + path="", + force_not_set: bool = False, +) -> Iterator[TyperOption]: + """Generate additional cli options from model field""" + for field in iterate_type_fields(cls, path=path): + path = field.path + if path in ctx.regular_options: + # add dot if path shadows existing parameter + # it will be ignored on model building + path = f".{path}" + + if field.is_list: + yield from _options_from_list(path, field, ctx) + continue + if field.is_mapping: + yield from _options_from_mapping(path, field, ctx) + continue + if issubclass(field.type_, MlemABC) and field.type_.__is_root__: + yield from _options_from_mlem_abc( + ctx, field, path, force_not_set=force_not_set + ) + continue + + yield _option_from_field(field, path, force_not_set=force_not_set) + + +def _options_from_mlem_abc( + ctx: CallContext, + field: CliTypeField, + path: str, + force_not_set: bool = False, +): + """Generate str option for mlem abc type. 
+ If param is already set, also generate respective implementation fields""" + assert issubclass(field.type_, MlemABC) and field.type_.__is_root__ + if path in ctx.params and ctx.params[path] != NOT_SET: + yield from _options_from_model( + load_impl_ext(field.type_.abs_name, ctx.params[path]), + ctx, + path, + ) + yield _option_from_field( + field, path, override_type=str, force_not_set=force_not_set + ) + + +def _options_from_mapping(path: str, field: CliTypeField, ctx: CallContext): + """Generate options for mapping and example element. + If some keys are already set, also generate options for them""" + mapping_keys = [ + key[len(path) + 1 :].split(".", maxsplit=1)[0] + for key in ctx.extra_keys + if key.startswith(path + ".") + ] + for key in mapping_keys: + yield from _options_from_collection_element( + f"{path}.{key}", field, ctx + ) + + override_type = Dict[str, field.type_] # type: ignore[name-defined] + yield _option_from_field( + field, path, override_type=override_type, force_not_set=True + ) + yield from _options_from_collection_element( + f"{path}.key", field, ctx, force_not_set=True + ) + + +def _options_from_list(path: str, field: CliTypeField, ctx: CallContext): + """Generate option for list and example element. 
+ If some indexes are already set, also generate options for them""" + index = 0 + next_path = f"{path}.{index}" + while any(p.startswith(next_path) for p in ctx.params) and any( + v != NOT_SET for p, v in ctx.params.items() if p.startswith(next_path) + ): + yield from _options_from_collection_element(next_path, field, ctx) + index += 1 + next_path = f"{path}.{index}" + + override_type = List[field.type_] # type: ignore[name-defined] + yield _option_from_field( + field, path, override_type=override_type, force_not_set=True + ) + yield from _options_from_collection_element( + f"{path}.{index}", field, ctx, force_not_set=True + ) + + +def _options_from_collection_element( + path: str, + field: CliTypeField, + ctx: CallContext, + force_not_set: bool = False, +) -> Iterator[TyperOption]: + """Generate options for collection/mapping values""" + if issubclass(field.type_, MlemABC) and field.type_.__is_root__: + yield from _options_from_mlem_abc( + ctx, field, path, force_not_set=force_not_set + ) + return + if issubclass(field.type_, BaseModel): + yield from _options_from_model( + field.type_, ctx, path, force_not_set=force_not_set + ) + return + yield _option_from_field(field, path, force_not_set=force_not_set) + + +NOT_SET = "__NOT_SET__" +FILE_CONF_PARAM_NAME = "file_conf" +LOAD_PARAM_NAME = "load" + + +class SetViaFileTyperOption(TyperOption): + def process_value(self, ctx: Context, value: Any) -> Any: + try: + return super().process_value(ctx, value) + except MissingParameter: + if ( + LOAD_PARAM_NAME in ctx.params + or FILE_CONF_PARAM_NAME in ctx.params + and any( + smart_split(v, "=", 1)[0] == self.name + for v in ctx.params[FILE_CONF_PARAM_NAME] + ) + ): + return NOT_SET + raise + + +def _option_from_field( + field: CliTypeField, + path: str, + override_type: Type = None, + force_not_set: bool = False, +) -> TyperOption: + """Create cli option from field descriptor""" + type_ = override_type or field.type_ + if force_not_set: + type_ = anything(type_) + elif 
field.allow_none: + type_ = optional(type_) + option = SetViaFileTyperOption( + param_decls=[f"--{path}", path.replace(".", "_")], + type=type_ if not force_not_set else anything(type_), + required=field.required and not force_not_set, + default=field.default + if not field.is_list and not field.is_mapping and not force_not_set + else NOT_SET, + help=field.help, + show_default=not field.required, + ) + option.name = path + return option + + +def abc_fields_parameters(type_name: str, mlem_abc: Type[MlemABC]): + """Create a dynamic options generator that adds implementation fields""" + + def generator(ctx: CallContext): + try: + cls = load_impl_ext(mlem_abc.abs_name, type_name=type_name) + except ImportError: + return + yield from _options_from_model(cls, ctx) + + return generator + + +def get_extra_keys(args): + return [a[2:] for a in args if a.startswith("--")] + + +def lazy_class_docstring(abs_name: str, type_name: str): + def load_docstring(): + try: + return load_impl_ext(abs_name, type_name).__doc__ + except ExtensionRequirementError as e: + return f"Help unavailbale: {e}" + + return load_docstring + + +def for_each_impl(mlem_abc: Type[MlemABC]): + def inner(f): + for type_name in list_implementations(mlem_abc): + f(type_name) + return f + + return inner + + +def _iter_errors( + errors: Sequence[Any], model: Type, loc: Optional[Tuple] = None +): + for error in errors: + if isinstance(error, ErrorWrapper): + + if loc: + error_loc = loc + error.loc_tuple() + else: + error_loc = error.loc_tuple() + + if isinstance(error.exc, ValidationError): + yield from _iter_errors( + error.exc.raw_errors, error.exc.model, error_loc + ) + else: + yield error_loc, model, error.exc + + +def _format_validation_error(error: ValidationError) -> List[str]: + res = [] + for loc, model, exc in _iter_errors(error.raw_errors, error.model): + path = ".".join(loc_part for loc_part in loc if loc_part != "__root__") + field_name = loc[-1] + if field_name not in model.__fields__: + 
res.append( + f"Unknown field '{field_name}'. Fields available: {', '.join(model.__fields__)}" + ) + continue + field_type = model.__fields__[field_name].type_ + if ( + isinstance(exc, MissingError) + and isinstance(field_type, type) + and issubclass(field_type, BaseModel) + ): + msgs = [ + str(EMOJI_FAIL + f"field `{path}.{f.name}`: {exc}") + for f in field_type.__fields__.values() + if f.required + ] + if msgs: + res.extend(msgs) + else: + res.append(str(EMOJI_FAIL + f"field `{path}`: {exc}")) + else: + res.append(str(EMOJI_FAIL + f"field `{path}`: {exc}")) + return res + + +@contextlib.contextmanager +def wrap_build_error(subtype, model: Type[MlemABC]): + try: + yield + except ValidationError as e: + if LOCAL_CONFIG.DEBUG: + raise + msgs = "\n".join(_format_validation_error(e)) + raise typer.BadParameter( + f"Error on constructing {subtype} {model.abs_name}:\n{msgs}" + ) from e + + +def config_arg( + model: Type[MlemABC], + load: Optional[str], + subtype: Optional[str], + conf: Optional[List[str]], + file_conf: Optional[List[str]], + **kwargs, +): + if load is not None: + if issubclass(model, MlemObject): + try: + return load_meta(load, force_type=model) + except MlemObjectNotFound: + pass + with Location.resolve(load).open("r", encoding="utf8") as of: + return parse_obj_as(model, safe_load(of)) + if not subtype: + raise typer.BadParameter( + f"Cannot configure {model.abs_name}: either subtype or --load should be provided" + ) + with wrap_build_error(subtype, model): + return build_mlem_object(model, subtype, conf, file_conf, kwargs) + + +def _extract_examples( + help_str: Optional[str], +) -> Tuple[Optional[str], Optional[str]]: + if help_str is None: + return None, None + try: + examples = help_str.index("Examples:") + except ValueError: + return None, help_str + return help_str[examples + len("Examples:") + 1 :], help_str[:examples] diff --git a/mlem/contrib/bitbucketfs.py b/mlem/contrib/bitbucketfs.py index 3e5ad75f..b25da96a 100644 --- 
a/mlem/contrib/bitbucketfs.py +++ b/mlem/contrib/bitbucketfs.py @@ -222,11 +222,13 @@ def _mathch_path_with_ref(repo, path): class BitBucketResolver(CloudGitResolver): + """Resolve bitbucket URIs""" + type: ClassVar = "bitbucket" FS = BitBucketFileSystem PROTOCOL = "bitbucket" - # TODO: support on-prem gitlab (other hosts) + # TODO: https://github.com/iterative/mlem/issues/388 PREFIXES = [BITBUCKET_ORG, PROTOCOL + "://"] versioning_support = True diff --git a/mlem/contrib/callable.py b/mlem/contrib/callable.py index f5b3fcbe..fb08bb54 100644 --- a/mlem/contrib/callable.py +++ b/mlem/contrib/callable.py @@ -191,6 +191,8 @@ def persistent_load(self, pid: str) -> Any: class CallableModelType(ModelType, ModelHook): + """ModelType implementation for arbitrary callables""" + type: ClassVar = "callable" priority: ClassVar = LOW_PRIORITY_VALUE diff --git a/mlem/contrib/catboost.py b/mlem/contrib/catboost.py index 655a2bee..fcdab7fb 100644 --- a/mlem/contrib/catboost.py +++ b/mlem/contrib/catboost.py @@ -25,8 +25,11 @@ class CatBoostModelIO(ModelIO): type: ClassVar[str] = "catboost_io" classifier_file_name: ClassVar = "clf.cb" + """filename for catboost classifier""" regressor_file_name: ClassVar = "rgr.cb" + """filename for catboost classifier""" model_type: CBType = CBType.regressor + """type of catboost model""" def dump(self, storage: Storage, path, model) -> Artifacts: with tempfile.TemporaryDirectory() as tmpdir: diff --git a/mlem/contrib/docker/base.py b/mlem/contrib/docker/base.py index 134978ba..e4b9d997 100644 --- a/mlem/contrib/docker/base.py +++ b/mlem/contrib/docker/base.py @@ -127,14 +127,12 @@ def delete_image( class RemoteRegistry(DockerRegistry): - """DockerRegistry implementation for official Docker Registry (as in https://docs.docker.com/registry/) - - :param host: adress of the registry""" + """DockerRegistry implementation for official Docker Registry (as in https://docs.docker.com/registry/)""" type: ClassVar = "remote" - host: Optional[ - str - ] = 
None # TODO: https://github.com/iterative/mlem/issues/38 credentials + # TODO: https://github.com/iterative/mlem/issues/38 credentials + host: Optional[str] = None + """address of the registry""" def login(self, client): """ @@ -227,11 +225,10 @@ def delete_image( class DockerDaemon(MlemABC): - """Class that represents docker daemon - - :param host: adress of the docker daemon (empty string for local)""" + """Class that represents docker daemon""" host: str # TODO: https://github.com/iterative/mlem/issues/38 credentials + """adress of the docker daemon (empty string for local)""" @contextlib.contextmanager def client(self) -> Iterator[docker.DockerClient]: @@ -242,19 +239,18 @@ def client(self) -> Iterator[docker.DockerClient]: class DockerImage(BaseModel): """:class:`.Image.Params` implementation for docker images - full uri for image looks like registry.host/repository/name:tag - - :param name: name of the image - :param tag: tag of the image - :param repository: repository of the image - :param registry: :class:`.DockerRegistry` instance with this image - :param image_id: docker internal id of this image""" + full uri for image looks like registry.host/repository/name:tag""" name: str + """name of the image""" tag: str = "latest" + """tag of the image""" repository: Optional[str] = None + """repository of the image""" registry: DockerRegistry = DockerRegistry() + """DockerRegistry instance with this image""" image_id: Optional[str] = None + """internal docker id of this image""" @property def fullname(self): @@ -278,10 +274,14 @@ def delete(self, client: docker.DockerClient, force=False, **kwargs): class DockerContainerState(DeployState): + """State of docker container deployment""" + type: ClassVar = "docker_container" image: Optional[DockerImage] + """built image""" container_id: Optional[str] + """started container id""" def get_client(self): raise NotImplementedError @@ -289,25 +289,28 @@ def get_client(self): class _DockerBuildMixin(BaseModel): server: 
Server + """server to use""" args: DockerBuildArgs = DockerBuildArgs() + """additional docker arguments""" class DockerContainer(MlemDeployment, _DockerBuildMixin): - """:class:`.MlemDeployment` implementation for docker containers - - :param name: name of the container - :param port_mapping: port mapping in this container - :param params: other parameters for docker run cmd - :param container_id: internal docker id for this container""" + """MlemDeployment implementation for docker containers""" type: ClassVar = "docker_container" container_name: str + """Name to use for container""" image_name: Optional[str] = None + """Name to use for image""" port_mapping: Dict[int, int] = {} + """Expose ports""" params: Dict[str, str] = {} + """Additional params""" rm: bool = True + """Remove container on stop""" state: Optional[DockerContainerState] = None + """state""" @property def ensure_image_name(self): @@ -315,15 +318,14 @@ def ensure_image_name(self): class DockerEnv(MlemEnv[DockerContainer]): - """:class:`.MlemEnv` implementation for docker environment - - :param registry: default registry to push images to - :param daemon: :class:`.DockerDaemon` instance""" + """MlemEnv implementation for docker environment""" type: ClassVar = "docker" deploy_type: ClassVar = DockerContainer registry: DockerRegistry = DockerRegistry() + """default registry to push images to""" daemon: DockerDaemon = DockerDaemon(host="") + """Docker daemon parameters""" def delete_image(self, image: DockerImage, force: bool = False, **kwargs): with self.daemon.client() as client: @@ -448,8 +450,11 @@ def get_status( class DockerDirBuilder(MlemBuilder, _DockerBuildMixin): + """Create a directory with docker context to build docker image""" + type: ClassVar[str] = "docker_dir" target: str + """path to save result""" def build(self, obj: MlemModel): docker_dir = DockerModelDirectory( @@ -464,11 +469,17 @@ def build(self, obj: MlemModel): class DockerImageBuilder(MlemBuilder, _DockerBuildMixin): + 
"""Build docker image from model""" + type: ClassVar[str] = "docker" image: DockerImage + """Image parameters""" env: DockerEnv = DockerEnv() + """Where to build and push image. Defaults to local docker daemon""" force_overwrite: bool = False + """Ignore existing image with same name""" push: bool = True + """Push image to registry after it is built""" def build(self, obj: MlemModel) -> DockerImage: with tempfile.TemporaryDirectory(prefix="mlem_build_") as tempdir: diff --git a/mlem/contrib/docker/context.py b/mlem/contrib/docker/context.py index 5546aa51..faad7f13 100644 --- a/mlem/contrib/docker/context.py +++ b/mlem/contrib/docker/context.py @@ -13,6 +13,7 @@ from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem from pydantic import BaseModel +from yaml import safe_dump import mlem from mlem.config import MlemConfigBase, project_config @@ -26,6 +27,7 @@ REQUIREMENTS = "requirements.txt" MLEM_REQUIREMENTS = "mlem_requirements.txt" +SERVER = "server.yaml" TEMPLATE_FILE = "dockerfile.j2" MLEM_LOCAL_WHL = f"mlem-{mlem.__version__}-py3-none-any.whl" @@ -195,31 +197,34 @@ def get_mlem_requirements(): class DockerBuildArgs(BaseModel): - """ - Container for DockerBuild arguments + """Container for DockerBuild arguments""" - :param base_image: base image for the built image in form of a string or function from python version, - default: python:{python_version} - :param python_version: Python version to use, default: version of running interpreter - :param templates_dir: directory or list of directories for Dockerfile templates, default: ./docker_templates - - `pre_install.j2` - Dockerfile commands to run before pip - - `post_install.j2` - Dockerfile commands to run after pip - - `post_copy.j2` - Dockerfile commands to run after pip and MLEM distribution copy - :param run_cmd: command to run in container, default: sh run.sh - :param package_install_cmd: command to install packages. 
Default is apt-get, change it for other package manager - :param prebuild_hook: callable to call before build, accepts python version. Used for pre-building server images - :param mlem_whl: a path to mlem .whl file. If it is empty, mlem will be installed from pip TODO - :param platform: platform to build docker for, see https://docs.docker.com/desktop/multi-arch/ - """ + class Config: + fields = {"prebuild_hook": {"exclude": True}} base_image: Optional[Union[str, Callable[[str], str]]] = None + """base image for the built image in form of a string or function from python version, + default: python:{python_version}""" python_version: str = get_python_version() + """Python version to use + default: version of running interpreter""" templates_dir: List[str] = [] - run_cmd: Union[bool, str] = "sh run.sh" + """directory or list of directories for Dockerfile templates + - `pre_install.j2` - Dockerfile commands to run before pip + - `post_install.j2` - Dockerfile commands to run after pip + - `post_copy.j2` - Dockerfile commands to run after pip and MLEM distribution copy""" + run_cmd: Optional[str] = "sh run.sh" + """command to run in container""" package_install_cmd: str = "apt-get update && apt-get -y upgrade && apt-get install --no-install-recommends -y" + """command to install packages. Default is apt-get, change it for other package manager""" package_clean_cmd: str = "&& apt-get clean && rm -rf /var/lib/apt/lists/*" + """command to clean after package installation""" prebuild_hook: Optional[Callable[[str], Any]] = None + """callable to call before build, accepts python version. Used for pre-building server images""" + mlem_whl: Optional[str] = None + """a path to mlem .whl file. 
If it is empty, mlem will be installed from pip""" platform: Optional[str] = None + """platform to build docker for, see docs.docker.com/desktop/multi-arch/""" def get_base_image(self): if self.base_image is None: @@ -338,7 +343,10 @@ def write_dockerfile(self, requirements: Requirements): df.write(dockerfile) def write_configs(self): - pass + with self.fs.open( + posixpath.join(self.path, SERVER), "w", encoding="utf8" + ) as f: + safe_dump(self.server.dict(), f) def write_local_sources(self, requirements: Requirements): echo(EMOJI_PACK + "Adding sources...") @@ -363,7 +371,7 @@ def write_local_sources(self, requirements: Requirements): def write_run_file(self): with self.fs.open(posixpath.join(self.path, "run.sh"), "w") as sh: - sh.write(f"mlem serve {self.model_name} {self.server.type}") + sh.write(f"mlem serve -l {SERVER} -m {self.model_name}") def write_mlem_source(self): source = get_mlem_source() diff --git a/mlem/contrib/docker/dockerfile.j2 b/mlem/contrib/docker/dockerfile.j2 index 3abe1d12..b5720706 100644 --- a/mlem/contrib/docker/dockerfile.j2 +++ b/mlem/contrib/docker/dockerfile.j2 @@ -10,4 +10,4 @@ COPY . 
./ {% for name, value in env.items() %}ENV {{ name }}={{ value }} {% endfor %} {% include "post_copy.j2" ignore missing %} -{% if run_cmd is not false %}CMD {{ run_cmd }}{% endif %} +{% if run_cmd is not none %}CMD {{ run_cmd }}{% endif %} diff --git a/mlem/contrib/docker/utils.py b/mlem/contrib/docker/utils.py index 7f5ef13e..c0469772 100644 --- a/mlem/contrib/docker/utils.py +++ b/mlem/contrib/docker/utils.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from functools import wraps from threading import Lock -from typing import Any, Generator, Iterator, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union import docker import requests @@ -108,23 +108,35 @@ def create_docker_client( client.close() -def image_exists_at_dockerhub(tag): +def image_exists_at_dockerhub(tag, library=False): repo, tag = tag.split(":") + lib = "library/" if library else "" resp = requests.get( - f"https://registry.hub.docker.com/v1/repositories/{repo}/tags/{tag}" + f"https://registry.hub.docker.com/v2/repositories/{lib}{repo}/tags/{tag}" ) time.sleep(1) # rate limiting return resp.status_code == 200 -def repository_tags_at_dockerhub(repo): - resp = requests.get( - f"https://registry.hub.docker.com/v1/repositories/{repo}/tags" +def repository_tags_at_dockerhub( + repo, library=False, max_results: Optional[int] = 100 +): + lib = "library/" if library else "" + res: List[Dict] = [] + next_page = ( + f"https://registry.hub.docker.com/v2/repositories/{lib}{repo}/tags" ) - time.sleep(1) # rate limiting - if resp.status_code != 200: - return {} - return {tag["name"] for tag in resp.json()} + while next_page is not None and ( + max_results is None or len(res) <= max_results + ): + resp = requests.get(next_page, params={"page_size": 1000}) + if resp.status_code != 200: + return {} + res.extend(resp.json()["results"]) + next_page = resp.json()["next"] + time.sleep(0.1) # rate limiting + + return {tag["name"] for tag in res} def wrap_docker_error(f): 
diff --git a/mlem/contrib/dvc.py b/mlem/contrib/dvc.py index eab8ea92..6eb7c8d0 100644 --- a/mlem/contrib/dvc.py +++ b/mlem/contrib/dvc.py @@ -34,13 +34,14 @@ def find_dvc_repo_root(path: str): class DVCStorage(LocalStorage): - """For now this storage is user-managed dvc storage, which means user should - track corresponding files with dvc manually. - TODO: add support for pipeline-tracked files and for single files with .dvc - Also add possibility to automatically add and push every artifact""" + """User-managed dvc storage, which means user should + track corresponding files with dvc manually.""" + + # TODO: https://github.com//issues/47 type: ClassVar = "dvc" uri: str = "" + """base storage path""" def upload(self, local_path: str, target_path: str) -> "DVCArtifact": return DVCArtifact( @@ -64,8 +65,11 @@ def relative(self, fs: AbstractFileSystem, path: str) -> Storage: class DVCArtifact(LocalArtifact): + """Local artifact that can be also read from DVC cache""" + type: ClassVar = "dvc" uri: str + """local path to file""" def _download(self, target_path: str) -> LocalArtifact: if os.path.isdir(target_path): diff --git a/mlem/contrib/fastapi.py b/mlem/contrib/fastapi.py index 4d7dd587..e2523ea9 100644 --- a/mlem/contrib/fastapi.py +++ b/mlem/contrib/fastapi.py @@ -34,11 +34,15 @@ def _create_schema_route(app: FastAPI, interface: Interface): class FastAPIServer(Server, LibRequirementsMixin): + """Serves model with http""" + libraries: ClassVar[List[ModuleType]] = [uvicorn, fastapi] type: ClassVar[str] = "fastapi" host: str = "0.0.0.0" + """net interface to use""" port: int = 8080 + """port to use""" @classmethod def _create_handler( diff --git a/mlem/contrib/github.py b/mlem/contrib/github.py index b77caa00..dc8805b5 100644 --- a/mlem/contrib/github.py +++ b/mlem/contrib/github.py @@ -59,12 +59,12 @@ class GithubResolver(CloudGitResolver): type: ClassVar = "github" FS: ClassVar = GithubFileSystem - PROTOCOL = "github" - GITHUB_COM = "https://github.com" + 
PROTOCOL: ClassVar = "github" + GITHUB_COM: ClassVar = "https://github.com" - # TODO: support on-prem github (other hosts) - PREFIXES = [GITHUB_COM, PROTOCOL + "://"] - versioning_support = True + # TODO: https://github.com//issues/388 + PREFIXES: ClassVar = [GITHUB_COM, PROTOCOL + "://"] + versioning_support: ClassVar = True @classmethod def get_envs(cls): diff --git a/mlem/contrib/gitlabfs.py b/mlem/contrib/gitlabfs.py index 14899688..4dd2b1a9 100644 --- a/mlem/contrib/gitlabfs.py +++ b/mlem/contrib/gitlabfs.py @@ -157,14 +157,16 @@ def _mathch_path_with_ref(project_id, path): class GitlabResolver(CloudGitResolver): + """Resolve https://gitlab.com URIs""" + type: ClassVar = "gitlab" - FS = GitlabFileSystem - PROTOCOL = "gitlab" - GITLAB_COM = "https://gitlab.com" + FS: ClassVar = GitlabFileSystem + PROTOCOL: ClassVar = "gitlab" + GITLAB_COM: ClassVar = "https://gitlab.com" - # TODO: support on-prem gitlab (other hosts) - PREFIXES = [GITLAB_COM, PROTOCOL + "://"] - versioning_support = True + # TODO: https://github.com//issues/388 + PREFIXES: ClassVar = [GITLAB_COM, PROTOCOL + "://"] + versioning_support: ClassVar = True @classmethod def get_kwargs(cls, uri): diff --git a/mlem/contrib/heroku/build.py b/mlem/contrib/heroku/build.py index 32c1f494..2736cff2 100644 --- a/mlem/contrib/heroku/build.py +++ b/mlem/contrib/heroku/build.py @@ -14,9 +14,13 @@ class HerokuRemoteRegistry(RemoteRegistry): + """Heroku docker registry""" + type: ClassVar = "heroku" api_key: Optional[str] = None - host = DEFAULT_HEROKU_REGISTRY + """HEROKU_API_KEY""" + host: str = DEFAULT_HEROKU_REGISTRY + """Registry host""" def uri(self, image: str): return super().uri(image).split(":")[0] diff --git a/mlem/contrib/heroku/meta.py b/mlem/contrib/heroku/meta.py index a2aee45f..59c24263 100644 --- a/mlem/contrib/heroku/meta.py +++ b/mlem/contrib/heroku/meta.py @@ -28,15 +28,23 @@ class HerokuAppMeta(BaseModel): name: str + """App name""" web_url: str + """App web url""" meta_info: dict + 
"""additional metadata""" class HerokuState(DeployState): + """State of heroku deployment""" + type: ClassVar = "heroku" app: Optional[HerokuAppMeta] + """created heroku app""" image: Optional[DockerImage] + """built docker image""" release_state: Optional[Union[dict, list]] + """state of the release""" @property def ensured_app(self) -> HerokuAppMeta: @@ -51,18 +59,28 @@ def get_client(self) -> Client: class HerokuDeployment(MlemDeployment): + """Heroku App""" + type: ClassVar = "heroku" state: Optional[HerokuState] + """state""" app_name: str + """Heroku application name""" region: str = "us" + """heroku region""" stack: str = "container" + """stack to use""" team: Optional[str] = None + """heroku team""" class HerokuEnv(MlemEnv[HerokuDeployment]): + """Heroku Account""" + type: ClassVar = "heroku" deploy_type: ClassVar = HerokuDeployment api_key: Optional[str] = None + """HEROKU_API_KEY - advised to set via env variable or `heroku login`""" def deploy(self, meta: HerokuDeployment): from .utils import create_app, release_docker_app diff --git a/mlem/contrib/heroku/server.py b/mlem/contrib/heroku/server.py index c91cda25..f10e7164 100644 --- a/mlem/contrib/heroku/server.py +++ b/mlem/contrib/heroku/server.py @@ -9,7 +9,9 @@ class HerokuServer(FastAPIServer): - type: ClassVar = "heroku" + """Special FastAPI server to pickup port from env PORT""" + + type: ClassVar = "_heroku" def serve(self, interface: Interface): self.port = int(os.environ["PORT"]) diff --git a/mlem/contrib/lightgbm.py b/mlem/contrib/lightgbm.py index b45fad44..a215e617 100644 --- a/mlem/contrib/lightgbm.py +++ b/mlem/contrib/lightgbm.py @@ -42,6 +42,7 @@ class LightGBMDataType( type: ClassVar[str] = "lightgbm" valid_types: ClassVar = (lgb.Dataset,) inner: DataType + """Inner DataType""" def serialize(self, instance: Any) -> dict: self.check_type(instance, lgb.Dataset, SerializationError) @@ -77,6 +78,8 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: class 
LightGBMDataWriter(DataWriter): + """Wrapper writer for lightgbm.Dataset objects""" + type: ClassVar[str] = "lightgbm" def write( @@ -103,10 +106,14 @@ def write( class LightGBMDataReader(DataReader): + """Wrapper reader for lightgbm.Dataset objects""" + type: ClassVar[str] = "lightgbm" data_type: LightGBMDataType inner: DataReader + """inner reader""" label: List + """list of labels""" def read(self, artifacts: Artifacts) -> DataType: inner_data_type = self.inner.read(artifacts) @@ -128,7 +135,8 @@ class LightGBMModelIO(ModelIO): """ type: ClassVar[str] = "lightgbm_io" - model_file_name = "model.lgb" + model_file_name: str = "model.lgb" + """filename to use""" def dump(self, storage: Storage, path, model) -> Artifacts: with tempfile.TemporaryDirectory(prefix="mlem_lightgbm_dump") as f: @@ -161,6 +169,7 @@ class LightGBMModel(ModelType, ModelHook, IsInstanceHookMixin): type: ClassVar[str] = "lightgbm" valid_types: ClassVar = (lgb.Booster,) io: ModelIO = LightGBMModelIO() + """LightGBMModelIO""" @classmethod def process( diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 6dbdd697..fa966983 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -40,19 +40,12 @@ def np_type_from_string(string_repr) -> np.dtype: class NumpyNumberType( LibRequirementsMixin, DataType, DataSerializer, DataHook ): - """ - :class:`.DataType` implementation for `numpy.number` objects which - converts them to built-in Python numbers and vice versa. 
- - :param dtype: `numpy.number` data type as string - """ + """numpy.number DataType""" libraries: ClassVar[List[ModuleType]] = [np] type: ClassVar[str] = "number" dtype: str - - # def get_spec(self) -> ArgList: - # return [Field(None, python_type_from_np_string_repr(self.dtype), False)] + """`numpy.number` type name as string""" def deserialize(self, obj: dict) -> Any: return self.actual_type(obj) # pylint: disable=not-callable @@ -83,19 +76,15 @@ def get_model(self, prefix: str = "") -> Type: class NumpyNdarrayType( LibRequirementsMixin, DataType, DataHook, DataSerializer ): - """ - :class:`.DataType` implementation for `np.ndarray` objects - which converts them to built-in Python lists and vice versa. - - :param shape: shape of `numpy.ndarray` objects in data - :param dtype: data type of `numpy.ndarray` objects in data - """ + """DataType implementation for `np.ndarray`""" type: ClassVar[str] = "ndarray" libraries: ClassVar[List[ModuleType]] = [np] shape: Optional[Tuple[Optional[int], ...]] + """shape of `numpy.ndarray`""" dtype: str + """data type of elements""" @staticmethod def _abstract_shape(shape): @@ -179,6 +168,8 @@ def get_writer(self, project: str = None, filename: str = None, **kwargs): class NumpyNumberWriter(DataWriter): + """Write np.number objects""" + type: ClassVar[str] = "numpy_number" def write( @@ -190,8 +181,11 @@ def write( class NumpyNumberReader(DataReader): + """Read np.number objects""" + type: ClassVar[str] = "numpy_number" data_type: NumpyNumberType + """resulting data type""" def read(self, artifacts: Artifacts) -> DataType: if DataWriter.art_name not in artifacts: diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 601c9c32..80918822 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -114,16 +114,15 @@ class Config: class _PandasDataType( LibRequirementsMixin, DataType, DataHook, DataSerializer, ABC ): - """Intermidiate class for pandas DataType implementations - - :param columns: list of column 
names (including index) - :param dtypes: list of string representations of pandas dtypes of columns - :param index_cols: list of column names that are used as index""" + """Intermidiate class for pandas DataType implementations""" libraries: ClassVar = [pd] columns: List[str] + """Column names""" dtypes: List[str] + """Column types""" index_cols: List[str] + """Column names that should be in index""" @classmethod def process(cls, obj: Any, **kwargs) -> "_PandasDataType": @@ -560,6 +559,7 @@ def get_pandas_batch_formats(batch_size: int): class _PandasIO(BaseModel): format: str + """name of pandas-supported format""" @validator("format") def is_valid_format( # pylint: disable=no-self-argument @@ -669,6 +669,8 @@ def write( class PandasImport(ExtImportHook, LoadAndAnalyzeImportHook): + """Import files as pd.DataFrame""" + EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS) type: ClassVar = "pandas" force_type: ClassVar = MlemData diff --git a/mlem/contrib/pip/base.py b/mlem/contrib/pip/base.py index 7fcf2110..27bf2dd7 100644 --- a/mlem/contrib/pip/base.py +++ b/mlem/contrib/pip/base.py @@ -4,7 +4,7 @@ import posixpath import subprocess import tempfile -from typing import ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem @@ -26,19 +26,27 @@ class SetupTemplate(TemplateModel): TEMPLATE_DIR: ClassVar = os.path.dirname(__file__) package_name: str + """Name of python package""" python_version: Optional[str] = None + """Required python version""" short_description: str = "" + """short_description""" url: str = "" + """url""" email: str = "" + """author's email""" author: str = "" + """author's name""" version: str = "0.0.0" - additional_setup_kwargs: Dict = {} + """package version""" + additional_setup_kwargs: Dict[str, Any] = {} + """additional parameters for setup()""" @validator("python_version") def validate_python_version( # pylint: 
disable=no-self-argument cls, value # noqa: B902 ): - return f"=={value}" if value[0] in "0123456789" else value + return f"=={value}" if value and value[0] in "0123456789" else value class SourceTemplate(TemplateModel): @@ -46,6 +54,7 @@ class SourceTemplate(TemplateModel): TEMPLATE_DIR: ClassVar = os.path.dirname(__file__) methods: List[str] + """list of methods""" class PipMixin(SetupTemplate): @@ -86,8 +95,11 @@ def make_distr(self, obj: MlemModel, root: str, fs: AbstractFileSystem): class PipBuilder(MlemBuilder, PipMixin): + """Create a directory python package""" + type: ClassVar = "pip" target: str + """path to save result""" def build(self, obj: MlemModel): fs, root = get_fs(self.target) @@ -95,8 +107,11 @@ def build(self, obj: MlemModel): class WhlBuilder(MlemBuilder, PipMixin): + """Create a wheel with python package""" + type: ClassVar = "whl" target: str + """path to save result""" def build_whl(self, path, target, target_fs): target_fs.makedirs(target, exist_ok=True) diff --git a/mlem/contrib/rabbitmq.py b/mlem/contrib/rabbitmq.py index 14bea59f..90a6a6ff 100644 --- a/mlem/contrib/rabbitmq.py +++ b/mlem/contrib/rabbitmq.py @@ -24,9 +24,13 @@ class RabbitMQMixin(BaseModel): host: str + """Host of RMQ instance""" port: int + """Port of RMQ instance""" exchange: str = "" + """RMQ exchange to use""" queue_prefix: str = "" + """Queue prefix""" channel_cache: Optional[BlockingChannel] = None class Config: @@ -44,6 +48,8 @@ def channel(self): class RabbitMQServer(Server, RabbitMQMixin): + """RMQ server that consumes requests and produces model predictions from/to RMQ instance""" + type: ClassVar = "rmq" def _create_handler( @@ -96,8 +102,11 @@ def serve(self, interface: Interface): class RabbitMQClient(Client, RabbitMQMixin): + """Access models served with rmq server""" + type: ClassVar = "rmq" timeout: float = 0 + """Time to wait for response. 
0 means indefinite""" def _interface_factory(self) -> InterfaceDescriptor: res, _, payload = self.channel.basic_get( diff --git a/mlem/contrib/sklearn.py b/mlem/contrib/sklearn.py index e4f29c90..b81d8000 100644 --- a/mlem/contrib/sklearn.py +++ b/mlem/contrib/sklearn.py @@ -22,14 +22,14 @@ class SklearnModel(ModelType, ModelHook, IsInstanceHookMixin): - """ - :class:`mlem.core.model.ModelType implementation for `scikit-learn` models - """ + """ModelType implementation for `scikit-learn` models""" type: ClassVar[str] = "sklearn" - io: ModelIO = SimplePickleIO() valid_types: ClassVar = (RegressorMixin, ClassifierMixin) + io: ModelIO = SimplePickleIO() + """IO""" + @classmethod def process( cls, obj: Any, sample_data: Optional[Any] = None, **kwargs @@ -85,6 +85,8 @@ def get_requirements(self) -> Requirements: class SklearnPipelineType(SklearnModel): + """ModelType implementation for `scikit-learn` pipelines""" + valid_types: ClassVar = (Pipeline,) type: ClassVar = "sklearn_pipeline" diff --git a/mlem/contrib/tensorflow.py b/mlem/contrib/tensorflow.py index d8135040..29448024 100644 --- a/mlem/contrib/tensorflow.py +++ b/mlem/contrib/tensorflow.py @@ -39,17 +39,15 @@ class TFTensorDataType( DataType, DataSerializer, DataHook, IsInstanceHookMixin ): """ - :class:`.DataType` implementation for `tensorflow.Tensor` objects - which converts them to built-in Python lists and vice versa. - - :param shape: shape of `tensorflow.Tensor` objects in data - :param dtype: data type of `tensorflow.Tensor` objects in data + DataType implementation for `tensorflow.Tensor` """ type: ClassVar[str] = "tf_tensor" valid_types: ClassVar = (tf.Tensor,) shape: Tuple[Optional[int], ...] 
+ """shape of `tensorflow.Tensor` objects in data""" dtype: str + """data type of `tensorflow.Tensor` objects in data""" @property def tf_type(self): @@ -117,6 +115,8 @@ def process(cls, obj: tf.Tensor, **kwargs) -> DataType: class TFTensorWriter(DataWriter): + """Write tensorflow tensors to np format""" + type: ClassVar[str] = "tf_tensor" def write( @@ -128,6 +128,8 @@ def write( class TFTensorReader(DataReader): + """Read tensorflow tensors from np format""" + type: ClassVar[str] = "tf_tensor" def read(self, artifacts: Artifacts) -> DataType: @@ -157,11 +159,12 @@ def is_custom_net(model): class TFKerasModelIO(BufferModelIO): """ - :class:`.ModelIO` implementation for Tensorflow Keras models (:class:`tensorflow.keras.Model` objects) + IO for Tensorflow Keras models (:class:`tensorflow.keras.Model` objects) """ type: ClassVar[str] = "tf_keras" save_format: Optional[str] = None + """`tf` for custom net classes and `h5` otherwise""" def save_model(self, model: tf.keras.Model, path: str): if self.save_format is None: @@ -198,6 +201,7 @@ class TFKerasModel(ModelType, ModelHook, IsInstanceHookMixin): type: ClassVar[str] = "tf_keras" valid_types: ClassVar = (tf.keras.Model,) io: ModelIO = TFKerasModelIO() + """IO""" @classmethod def process( diff --git a/mlem/contrib/torch.py b/mlem/contrib/torch.py index bc9e489f..4f09dc63 100644 --- a/mlem/contrib/torch.py +++ b/mlem/contrib/torch.py @@ -30,18 +30,14 @@ def python_type_from_torch_string_repr(dtype: str): class TorchTensorDataType( DataType, DataSerializer, DataHook, IsInstanceHookMixin ): - """ - :class:`.DataType` implementation for `torch.Tensor` objects - which converts them to built-in Python lists and vice versa. - - :param shape: shape of `torch.Tensor` objects in data - :param dtype: data type of `torch.Tensor` objects in data - """ + """DataType implementation for `torch.Tensor`""" type: ClassVar[str] = "torch" valid_types: ClassVar = (torch.Tensor,) shape: Tuple[Optional[int], ...] 
+ """shape of `torch.Tensor` object""" dtype: str + """type name of `torch.Tensor` elements""" def _check_shape(self, tensor, exc_type): if tuple(tensor.shape)[1:] != self.shape[1:]: @@ -102,6 +98,8 @@ def process(cls, obj: torch.Tensor, **kwargs) -> DataType: class TorchTensorWriter(DataWriter): + """Write torch tensors""" + type: ClassVar[str] = "torch" def write( @@ -113,6 +111,8 @@ def write( class TorchTensorReader(DataReader): + """Read torch tensors""" + type: ClassVar[str] = "torch" def read(self, artifacts: Artifacts) -> DataType: @@ -131,12 +131,11 @@ def read_batch( class TorchModelIO(ModelIO): - """ - :class:`.ModelIO` implementation for PyTorch models - """ + """IO for PyTorch models""" type: ClassVar[str] = "torch_io" is_jit: bool = False + """Is model jit compiled""" def dump(self, storage: Storage, path, model) -> Artifacts: self.is_jit = isinstance(model, torch.jit.ScriptModule) @@ -162,6 +161,7 @@ class TorchModel(ModelType, ModelHook, IsInstanceHookMixin): type: ClassVar[str] = "torch" valid_types: ClassVar = (torch.nn.Module,) io: ModelIO = TorchModelIO() + """TorchModelIO""" @classmethod def process( @@ -194,6 +194,8 @@ def get_requirements(self) -> Requirements: class TorchModelImport(LoadAndAnalyzeImportHook): + """Import torch models saved with `torch.save`""" + type: ClassVar = "torch" force_type: ClassVar = MlemModel diff --git a/mlem/contrib/xgboost.py b/mlem/contrib/xgboost.py index c7db3fe8..2b9e7206 100644 --- a/mlem/contrib/xgboost.py +++ b/mlem/contrib/xgboost.py @@ -41,19 +41,18 @@ class DMatrixDataType( IsInstanceHookMixin, ): """ - :class:`~.DataType` implementation for xgboost.DMatrix type - - :param is_from_list: whether DMatrix can be constructed from list - :param feature_type_names: string representation of feature types - :param feature_names: list of feature names + DataType implementation for xgboost.DMatrix type """ type: ClassVar[str] = "xgboost_dmatrix" valid_types: ClassVar = (xgboost.DMatrix,) is_from_list: bool + 
"""whether DMatrix can be constructed from list""" feature_type_names: Optional[List[str]] + """string representation of feature types""" feature_names: Optional[List[str]] = None + """list of feature names""" @property def feature_types(self): @@ -118,7 +117,8 @@ class XGBoostModelIO(ModelIO): """ type: ClassVar[str] = "xgboost_io" - model_file_name = "model.xgb" + model_file_name: str = "model.xgb" + """filename to use""" def dump( self, storage: Storage, path, model: xgboost.Booster diff --git a/mlem/core/artifacts.py b/mlem/core/artifacts.py index 558da83a..1edc2faf 100644 --- a/mlem/core/artifacts.py +++ b/mlem/core/artifacts.py @@ -38,8 +38,11 @@ class Config: abs_name: ClassVar = "artifact" uri: str + """location""" size: int + """size in bytes""" hash: str + """md5 hash""" @overload def materialize( @@ -101,6 +104,7 @@ class FSSpecArtifact(Artifact): type: ClassVar = "fsspec" uri: str + """Path to file""" def _download(self, target_path: str) -> "LocalArtifact": fs, path = get_fs(self.uri) @@ -135,7 +139,9 @@ class PlaceholderArtifact(Artifact): """On dumping this artifact will be replaced with actual artifact that is relative to project root (if there is a project)""" + type: ClassVar = "_placeholder" location: Location + """location of artifact""" def relative(self, fs: AbstractFileSystem, path: str) -> "Artifact": raise NotImplementedError @@ -201,7 +207,9 @@ class Config: fs: Optional[AbstractFileSystem] = None base_path: str = "" uri: str + """Path to storage dir""" storage_options: Optional[Dict[str, str]] = {} + """Additional options for FS""" def upload(self, local_path: str, target_path: str) -> FSSpecArtifact: fs = self.get_fs() diff --git a/mlem/core/base.py b/mlem/core/base.py index 760a0d25..d052c4c2 100644 --- a/mlem/core/base.py +++ b/mlem/core/base.py @@ -1,6 +1,18 @@ import shlex +from collections import defaultdict from inspect import isabstract -from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, overload +from typing 
import ( + Any, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) from pydantic import BaseModel, parse_obj_as from typing_extensions import Literal @@ -42,6 +54,11 @@ def load_impl_ext( load_entrypoints, ) + if abs_name in MlemABC.abs_types: + abs_class = MlemABC.abs_types[abs_name] + if type_name in abs_class.__type_map__: + return abs_class.__type_map__[type_name] + if type_name is not None and "." in type_name: try: obj = import_string(type_name) @@ -130,30 +147,7 @@ def load_type(cls, type_name: str): raise UnknownImplementation(type_name, cls.abs_name) from e -def set_or_replace(obj: dict, key: str, value: Any, subkey: str = "type"): - if key in obj: - old_value = obj[key] - if ( - isinstance(old_value, str) - and isinstance(value, dict) - and subkey not in value - ): - value[subkey] = old_value - obj[key] = value - return - if isinstance(old_value, dict) and isinstance(value, str): - old_value[subkey] = value - return - obj[key] = value - - -def set_recursively(obj: dict, keys: List[str], value: Any): - if len(keys) == 1: - set_or_replace(obj, keys[0], value) - return - key, keys = keys[0], keys[1:] - set_or_replace(obj, key, {}) - set_recursively(obj[key], keys, value) +_not_set = object() def get_recursively(obj: dict, keys: List[str]): @@ -163,13 +157,13 @@ def get_recursively(obj: dict, keys: List[str]): return get_recursively(obj[key], keys) -def smart_split(string: str, char: str, maxsplit: int = None): +def smart_split(value: str, char: str, maxsplit: int = None): SPECIAL = "\0" if char != " ": - string = string.replace(" ", SPECIAL).replace(char, " ") + value = value.replace(" ", SPECIAL).replace(char, " ") res = [ s.replace(" ", char).replace(SPECIAL, " ") - for s in shlex.split(string, posix=True) + for s in shlex.split(value, posix=True) ] if maxsplit is None: return res @@ -227,12 +221,118 @@ def parse_links(model: Type["BaseModel"], str_conf: List[str]): return not_links, links +IntStr = Union[int, str] 
+Keys = Tuple[IntStr, ...] +KeyValue = Tuple[IntStr, Any] +Aggregates = Dict[Keys, List[KeyValue]] + + +class SmartSplitDict(dict): + def __init__(self, value=None, sep=".", type_field="type"): + self.type_field = type_field + self.sep = sep + super().__init__(value or ()) + + def update(self, __m: Dict[Any, Any], **kwargs) -> None: # type: ignore[override] + for k, v in __m.items(): + self[k] = v + for k, v in kwargs.items(): + self[k] = v + + def __setitem__(self, key, value): + if isinstance(key, str): + key = tuple(smart_split(key, self.sep)) + + for keys, val in self._disassemble(value, key): + super().__setitem__(keys, val) + + def _disassemble(self, value: Any, key_prefix): + if isinstance(value, list): + for i, v in enumerate(value): + yield from self._disassemble(v, key_prefix + (i,)) + return + if isinstance(value, dict): + for k, v in value.items(): + yield from self._disassemble(v, key_prefix + (k,)) + return + yield key_prefix, value + + def build(self) -> Dict[str, Any]: + prefix_values: Aggregates = self._aggregate_by_prefix() + while prefix_values: + if len(prefix_values) == 1 and () in prefix_values: + return self._merge_aggregates(prefix_values[()]) + max_len = max(len(k) for k in prefix_values) + to_aggregate: Dict[Keys, Any] = {} + postponed: Aggregates = defaultdict(list) + for prefix, values in prefix_values.items(): + if len(prefix) == max_len: + to_aggregate[prefix] = self._merge_aggregates(values) + continue + postponed[prefix] = values + aggregated: Aggregates = self._aggregate_by_prefix(to_aggregate) + for prefix in set(postponed).union(aggregated): + postponed[prefix].extend(aggregated.get(prefix, [])) + if postponed == prefix_values: + raise RuntimeError("infinite loop on smartdict builing") + prefix_values = postponed + # this can only be reached if loop was not entered + return {} + + def _merge_aggregates(self, values: List[KeyValue]) -> Any: + if all(isinstance(k, int) for k, _ in values): + return self._merge_as_list(values) + 
return self._merge_as_dict(values) + + def _merge_as_list(self, values: List[KeyValue]): + assert all(isinstance(k, int) for k, _ in values) + index_values = defaultdict(list) + for index, value in values: + index_values[index].append(value) + res = [_not_set] * (int(max(k for k, _ in values)) + 1) + for i, v in index_values.items(): + res[i] = self._merge_values(v) # type: ignore[index] + return res + + def _merge_as_dict(self, values: List[KeyValue]) -> Dict[Any, Any]: + key_values = defaultdict(list) + for key, value in values: + key_values[key].append(value) + return {k: self._merge_values(v) for k, v in key_values.items()} + + def _merge_values(self, values: List[Any]) -> Any: + if len(values) == 1: + return values[0] + merged = {} + for value in values: + if isinstance(value, dict): + merged.update(value) + elif isinstance(value, str): + merged[self.type_field] = value + else: + raise ValueError(f"Cannot merge {value.__class__} into dict") + return merged + + def _aggregate_by_prefix( + self, values: Dict[Keys, Any] = None + ) -> Aggregates: + values = values if values is not None else self + prefix_values: Aggregates = defaultdict(list) + + for keys, value in values.items(): + prefix, key = keys[:-1], keys[-1] + if isinstance(key, str) and key.isnumeric(): + key = int(key) + prefix_values[prefix].append((key, value)) + return prefix_values + + def parse_string_conf(conf: List[str]) -> Dict[str, Any]: - res: Dict[str, Any] = {} + res = SmartSplitDict() for c in conf: keys, value = smart_split(c, "=") - set_recursively(res, smart_split(keys, "."), value) - return res + res[keys] = value + return res.build() def build_model( @@ -242,21 +342,19 @@ def build_model( conf: Dict[str, Any] = None, **kwargs, ): - model_dict: Dict[str, Any] = {} - kwargs.update(conf or {}) - model_dict.update() - for key, c in kwargs.items(): - set_recursively(model_dict, smart_split(key, "."), c) + model_dict = SmartSplitDict() + model_dict.update(kwargs) + model_dict.update(conf or 
{}) for file in file_conf or []: keys, path = smart_split(make_posix(file), "=") with open(path, "r", encoding="utf8") as f: value = safe_load(f) - set_recursively(model_dict, smart_split(keys, "."), value) + model_dict[keys] = value for c in str_conf or []: keys, value = smart_split(c, "=", 1) if value == "None": value = None - set_recursively(model_dict, smart_split(keys, "."), value) - return parse_obj_as(model, model_dict) + model_dict[keys] = value + return parse_obj_as(model, model_dict.build()) diff --git a/mlem/core/data_type.py b/mlem/core/data_type.py index a96efc09..5d920996 100644 --- a/mlem/core/data_type.py +++ b/mlem/core/data_type.py @@ -133,6 +133,7 @@ class Config: type_root = True data_type: DataType + """resulting data type""" abs_name: ClassVar[str] = "data_reader" @abstractmethod @@ -172,6 +173,7 @@ class PrimitiveType(DataType, DataHook, DataSerializer): type: ClassVar[str] = "primitive" ptype: str + """Name of builtin type""" @classmethod def is_object_valid(cls, obj: Any) -> bool: @@ -205,6 +207,8 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: class PrimitiveWriter(DataWriter): + """Writer for primitive types""" + type: ClassVar[str] = "primitive" def write( @@ -216,6 +220,8 @@ def write( class PrimitiveReader(DataReader): + """Reader for primitive types""" + type: ClassVar[str] = "primitive" data_type: PrimitiveType @@ -247,7 +253,9 @@ class ArrayType(DataType, DataSerializer): type: ClassVar[str] = "array" dtype: DataType + """DataType of elements""" size: Optional[int] + """size of the list""" def get_requirements(self) -> Requirements: return self.dtype.get_requirements() @@ -272,6 +280,8 @@ def get_model(self, prefix: str = "") -> Type[BaseModel]: class ArrayWriter(DataWriter): + """Writer for lists with single element type""" + type: ClassVar[str] = "array" def write( @@ -298,9 +308,12 @@ def write( class ArrayReader(DataReader): + """Reader for lists with single element type""" + type: ClassVar[str] = "array" data_type: 
ArrayType readers: List[DataReader] + """inner readers""" def read(self, artifacts: Artifacts) -> DataType: artifacts = flatdict.FlatterDict(artifacts, delimiter="/") @@ -321,9 +334,12 @@ class _TupleLikeType(DataType, DataSerializer): DataType for tuple-like collections """ - items: List[DataType] + type: ClassVar = "_tuple_like" actual_type: ClassVar[type] + items: List[DataType] + """DataTypes of elements""" + def deserialize(self, obj): _check_type_and_size( obj, self.actual_type, len(self.items), DeserializationError @@ -377,6 +393,8 @@ def _check_type_and_size(obj, dtype, size, exc_type): class _TupleLikeWriter(DataWriter): + """Writer for tuple-like data""" + type: ClassVar[str] = "tuple_like" def write( @@ -404,9 +422,12 @@ def write( class _TupleLikeReader(DataReader): + """Reader for tuple-like data""" + type: ClassVar[str] = "tuple_like" data_type: _TupleLikeType readers: List[DataReader] + """inner readers""" def read(self, artifacts: Artifacts) -> DataType: artifacts = flatdict.FlatterDict(artifacts, delimiter="/") @@ -515,6 +536,7 @@ class DictType(DataType, DataSerializer): type: ClassVar[str] = "dict" item_types: Dict[Union[StrictStr, StrictInt], DataType] + """Mapping key -> nested data type""" @classmethod def process(cls, obj, **kwargs): @@ -570,6 +592,8 @@ def get_model(self, prefix="") -> Type[BaseModel]: class DictWriter(DataWriter): + """Writer for dicts""" + type: ClassVar[str] = "dict" def write( @@ -597,9 +621,12 @@ def write( class DictReader(DataReader): + """Reader for dicts""" + type: ClassVar[str] = "dict" data_type: DictType item_readers: Dict[Union[StrictStr, StrictInt], DataReader] + """nested readers""" def read(self, artifacts: Artifacts) -> DataType: artifacts = flatdict.FlatterDict(artifacts, delimiter="/") @@ -623,7 +650,9 @@ class DynamicDictType(DataType, DataSerializer): type: ClassVar[str] = "d_dict" key_type: PrimitiveType + """DataType for key (primitive)""" value_type: DataType + """DataType for value""" 
@validator("key_type") def is_valid_key_type( # pylint: disable=no-self-argument @@ -720,6 +749,8 @@ def get_model(self, prefix="") -> Type[BaseModel]: class DynamicDictWriter(DataWriter): + """Write dicts without fixed set of keys""" + type: ClassVar[str] = "d_dict" def write( @@ -739,6 +770,8 @@ def write( class DynamicDictReader(DataReader): + """Read dicts without fixed set of keys""" + type: ClassVar[str] = "d_dict" data_type: DynamicDictType diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 7b2a5aaf..c01a6f31 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -147,7 +147,7 @@ def __init__(self, section: str): super().__init__(f'Unknown config section "{section}"') -class ExtensionRequirementError(MlemError): +class ExtensionRequirementError(MlemError, ImportError): def __init__(self, ext: str, reqs: List[str], extra: Optional[str]): self.ext = ext self.reqs = reqs diff --git a/mlem/core/meta_io.py b/mlem/core/meta_io.py index 06da3905..35405fd7 100644 --- a/mlem/core/meta_io.py +++ b/mlem/core/meta_io.py @@ -75,6 +75,23 @@ def uri_repr(self): return posixpath.relpath(self.fullpath, "") return self.uri + @classmethod + def resolve( + cls, + path: str, + project: str = None, + rev: str = None, + fs: AbstractFileSystem = None, + find_project: bool = False, + ): + return UriResolver.resolve( + path=path, + project=project, + rev=rev, + fs=fs, + find_project=find_project, + ) + class UriResolver(MlemABC): """Base class for resolving location. 
Turns (path, project, rev, fs) tuple @@ -299,6 +316,7 @@ def pre_process( class FSSpecResolver(UriResolver): """Resolve different fsspec URIs""" + type: ClassVar = "fsspec" low_priority: ClassVar = True @classmethod @@ -338,7 +356,7 @@ def get_uri( def get_fs(uri: str) -> Tuple[AbstractFileSystem, str]: - location = UriResolver.resolve(path=uri, project=None, rev=None, fs=None) + location = Location.resolve(path=uri, project=None, rev=None, fs=None) return location.fs, location.fullpath @@ -353,7 +371,7 @@ def get_path_by_fs_path(fs: AbstractFileSystem, path: str): def get_uri(fs: AbstractFileSystem, path: str, repr: bool = False): - loc = UriResolver.resolve(path, None, None, fs=fs) + loc = Location.resolve(path, None, None, fs=fs) if repr: return loc.uri_repr return loc.uri diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index eae91588..d6db52d6 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -15,7 +15,7 @@ MlemProjectNotFound, WrongMetaType, ) -from mlem.core.meta_io import Location, UriResolver, get_meta_path +from mlem.core.meta_io import Location, get_meta_path from mlem.core.objects import MlemData, MlemModel, MlemObject, find_object from mlem.utils.path import make_posix @@ -164,7 +164,7 @@ def load_meta( Returns: MlemObject: Saved MlemObject """ - location = UriResolver.resolve( + location = Location.resolve( path=make_posix(path), project=make_posix(project), rev=rev, diff --git a/mlem/core/model.py b/mlem/core/model.py index 5952e690..70f745d0 100644 --- a/mlem/core/model.py +++ b/mlem/core/model.py @@ -102,10 +102,15 @@ class Argument(BaseModel): """Function argument descriptor""" name: str + """argument name""" type_: DataType + """argument data type""" required: bool = True + """is required""" default: Any = None + """default value""" kw_only: bool = False + """is keyword only""" @classmethod def from_argspec( @@ -177,10 +182,15 @@ class Signature(BaseModel, WithRequirements): """Function signature descriptor""" name: str 
+ """function name""" args: List[Argument] + """list of arguments""" returns: DataType + """returning data type""" varargs: Optional[str] = None + """name of var arg""" varkw: Optional[str] = None + """name of varkw arg""" @classmethod def from_method( @@ -230,9 +240,7 @@ def get_requirements(self): class ModelType(ABC, MlemABC, WithRequirements): - """ - Base class for model metadata. - """ + """Base class for model metadata.""" class Config: type_root = True @@ -243,7 +251,9 @@ class Config: model: Any = None io: ModelIO + """model IO""" methods: Dict[str, Signature] + """model method signatures""" def load(self, artifacts: Artifacts): self.model = self.io.load(artifacts) diff --git a/mlem/core/objects.py b/mlem/core/objects.py index 5609b60c..77e8f845 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -45,13 +45,7 @@ MlemProjectNotFound, WrongMetaType, ) -from mlem.core.meta_io import ( - MLEM_DIR, - MLEM_EXT, - Location, - UriResolver, - get_path_by_fs_path, -) +from mlem.core.meta_io import MLEM_DIR, MLEM_EXT, Location, get_path_by_fs_path from mlem.core.model import ModelAnalyzer, ModelType from mlem.core.requirements import Requirements from mlem.polydantic.lazy import lazy_field @@ -77,7 +71,9 @@ class Config: __abstract__: ClassVar[bool] = True object_type: ClassVar[str] location: Optional[Location] = None + """MlemObject location [transient]""" params: Dict[str, str] = {} + """Arbitrary map of additional parameters""" @property def loc(self) -> Location: @@ -126,7 +122,7 @@ def _get_location( """Create location from arguments""" if metafile_path: path = cls.get_metafile_path(path) - loc = UriResolver.resolve( + loc = Location.resolve( path, project, rev=None, fs=fs, find_project=True ) if loc.project is not None: @@ -359,9 +355,13 @@ class MlemLink(MlemObject): location""" path: str + """path to object""" project: Optional[str] = None + """project URI""" rev: Optional[str] = None + """revision to use""" link_type: str + """type of underlying 
object""" object_type: ClassVar = "link" @@ -406,7 +406,7 @@ def parse_link(self) -> Location: if self.project is None and self.rev is None: # is it possible to have rev without project? - location = UriResolver.resolve( + location = Location.resolve( path=self.path, project=None, rev=None, fs=None ) if ( @@ -424,7 +424,7 @@ def parse_link(self) -> Location: return find_meta_location(location) # link is absolute return find_meta_location( - UriResolver.resolve( + Location.resolve( path=self.path, project=self.project, rev=self.rev, fs=None ) ) @@ -448,7 +448,9 @@ class _WithArtifacts(ABC, MlemObject): __abstract__: ClassVar[bool] = True artifacts: Optional[Artifacts] = None + """dict with artifacts""" requirements: Requirements = Requirements.new() + """list of requirements""" @classmethod def get_metafile_path(cls, fullpath: str): @@ -592,6 +594,7 @@ class MlemModel(_WithArtifacts): object_type: ClassVar = "model" model_type_cache: Any model_type: ModelType + """framework-specific metadata""" model_type, model_type_raw, model_type_cache = lazy_field( ModelType, "model_type", "model_type_cache" ) @@ -644,8 +647,9 @@ class Config: exclude = {"data_type"} object_type: ClassVar = "data" - reader_cache: Optional[Dict] + reader_cache: Any reader: Optional[DataReader] + """How to read this data""" reader, reader_raw, reader_cache = lazy_field( DataReader, "reader", @@ -711,6 +715,7 @@ class Config: type_root = True type_field = "type" + type: ClassVar[str] object_type: ClassVar = "builder" abs_name: ClassVar[str] = "builder" @@ -728,6 +733,7 @@ class Config: abs_name: ClassVar[str] = "deploy_state" model_hash: Optional[str] = None + """hash of deployed model meta""" @abstractmethod def get_client(self): @@ -794,10 +800,15 @@ class Config: type: ClassVar[str] env_link: MlemLink + """Enironment to use""" env: Optional[MlemEnv] + """Enironment to use""" model_link: MlemLink + """Model to use""" model: Optional[MlemModel] + """Model to use""" state: Optional[DeployState] + 
"""state""" def get_env(self): if self.env is None: diff --git a/mlem/core/requirements.py b/mlem/core/requirements.py index 26e7d6ee..a0348e26 100644 --- a/mlem/core/requirements.py +++ b/mlem/core/requirements.py @@ -56,23 +56,24 @@ class Config: class PythonRequirement(Requirement, ABC): + type: ClassVar = "_python" module: str + """python module name""" class InstallableRequirement(PythonRequirement): """ - This class represents pip-installable python library - - :param module: name of python module - :param version: version of python package - :param package_name: Optional. pip package name for this module, if it is different from module name + pip-installable python library """ type: ClassVar[str] = "installable" module: str + """name of python module""" version: Optional[str] = None + """version of python package""" package_name: Optional[str] = None + """pip package name for this module, if it is different from module name""" @property def package(self): @@ -135,17 +136,16 @@ def from_str(cls, name): class CustomRequirement(PythonRequirement): """ - This class represents local python code that you need as a requirement for your code - - :param name: filename of this code - :param source64zip: zipped and base64-encoded source - :param is_package: whether this code should be in %name%/__init__.py + local python code that you need as a requirement for your code """ type: ClassVar[str] = "custom" name: str + """filename of this code""" source64zip: str + """zipped and base64-encoded source""" is_package: bool + """whether this code should be in %name%/__init__.py""" @staticmethod def from_module(mod: ModuleType) -> "CustomRequirement": @@ -264,11 +264,13 @@ def to_sources_dict(self) -> Dict[str, bytes]: class FileRequirement(CustomRequirement): - """Represents an additional file""" + """Additional file""" type: ClassVar[str] = "file" is_package: bool = False + """ignored""" module: str = "" + """ignored""" def to_sources_dict(self): """ @@ -287,10 +289,11 @@ 
def from_path(cls, path: str): class UnixPackageRequirement(Requirement): - """Represents a unix package that needs to be installed""" + """Unix package that needs to be installed""" type: ClassVar[str] = "unix" package_name: str + """name of the package""" T = TypeVar("T", bound=Requirement) @@ -299,11 +302,10 @@ class UnixPackageRequirement(Requirement): class Requirements(BaseModel): """ A collection of requirements - - :param requirements: list of :class:`Requirement` instances """ __root__: List[Requirement] = [] + """list of :class:`Requirement` instances""" @property def installable(self) -> List[InstallableRequirement]: @@ -522,7 +524,7 @@ def resolve_requirements(other: "AnyRequirements") -> Requirements: class WithRequirements: - """A mixing for objects that should provide their requirements""" + """A mixin for objects that should provide their requirements""" def get_requirements(self) -> Requirements: from mlem.utils.module import get_object_requirements diff --git a/mlem/runtime/client.py b/mlem/runtime/client.py index 8332cd7b..5e0bc9f0 100644 --- a/mlem/runtime/client.py +++ b/mlem/runtime/client.py @@ -91,9 +91,13 @@ def __call__(self, *args, **kwargs): class HTTPClient(Client): + """Access models served with http-based servers""" + type: ClassVar[str] = "http" host: str = "0.0.0.0" + """Server host""" port: Optional[int] = 8080 + """Server port""" @property def base_url(self): diff --git a/mlem/runtime/interface.py b/mlem/runtime/interface.py index e526b517..b1cc22de 100644 --- a/mlem/runtime/interface.py +++ b/mlem/runtime/interface.py @@ -20,7 +20,9 @@ class ExecutionError(MlemError): class InterfaceDescriptor(BaseModel): version: str = mlem.version.__version__ + """mlem version""" methods: Dict[str, Signature] = {} + """interface methods""" class Interface(ABC, MlemABC): @@ -137,6 +139,7 @@ class SimpleInterface(Interface): type: ClassVar[str] = "simple" methods: InterfaceDescriptor = InterfaceDescriptor() + """interface version and methods""" 
def __init__(self, **data: Any): methods = {} @@ -175,6 +178,7 @@ class Config: type: ClassVar[str] = "model" model_type: ModelType + """model metadata""" def load(self, uri: str): meta = load_meta(uri) diff --git a/mlem/utils/entrypoints.py b/mlem/utils/entrypoints.py index 4809cc64..1c5b2ca1 100644 --- a/mlem/utils/entrypoints.py +++ b/mlem/utils/entrypoints.py @@ -52,26 +52,48 @@ def load_entrypoints(domain: str = MLEM_ENTRY_POINT) -> Dict[str, Entrypoint]: def list_implementations( base_class: Union[str, Type[MlemABC]], - meta_subtype: Type["MlemObject"] = None, + meta_subtype: Union[str, Type["MlemObject"]] = None, + include_hidden: bool = True, ) -> List[str]: + from mlem.core.objects import MlemObject + if isinstance(base_class, type) and issubclass(base_class, MlemABC): abs_name = base_class.abs_name - if base_class == "meta" and meta_subtype is not None: - base_class = meta_subtype.object_type + + if (base_class in ("meta", MlemObject)) and meta_subtype is not None: + if isinstance(meta_subtype, str): + base_class = meta_subtype + else: + base_class = meta_subtype.object_type abs_name = "meta" + resolved_base_class: Optional[Type[MlemABC]] = None if isinstance(base_class, str): abs_name = base_class try: - base_class = MlemABC.abs_types[abs_name] + resolved_base_class = MlemABC.abs_types[abs_name] except KeyError: - base_class = load_impl_ext(abs_name, None) + try: + resolved_base_class = load_impl_ext(abs_name, None) + except ValueError: + pass + else: + resolved_base_class = base_class eps = { e.name for e in load_entrypoints().values() if e.abs_name == abs_name and e.name is not None } - eps.update(base_class.non_abstract_subtypes()) - return list(eps) + if resolved_base_class is not None: + eps.update(resolved_base_class.non_abstract_subtypes()) + return sorted(e for e in eps if include_hidden or not e.startswith("_")) + + +def list_abstractions( + include_hidden: bool = True, +) -> List[str]: + eps = {e.abs_name for e in load_entrypoints().values()} + 
eps.update(MlemABC.abs_types) + return [e for e in eps if include_hidden or not e.startswith("_")] IT = TypeVar("IT") @@ -123,7 +145,7 @@ def find_abc_implementations(root_module_name: str = MLEM_ENTRY_POINT): return { MLEM_ENTRY_POINT: [ f"{obj.abs_name}.{obj.__get_alias__()} = {name}" - if not obj.__is_root__ + if not obj.__is_root__ or hasattr(obj, obj.__type_field__()) else f"{obj.abs_name} = {name}" for obj, name in impls.items() if hasattr(obj, "abs_name") diff --git a/mlem/utils/templates.py b/mlem/utils/templates.py index 2ffe4caa..d86bd557 100644 --- a/mlem/utils/templates.py +++ b/mlem/utils/templates.py @@ -13,6 +13,7 @@ class TemplateModel(BaseModel): TEMPLATE_DIR: ClassVar[str] templates_dir: List[str] = [] + """list of directories to look for jinja templates""" def prepare_dict(self): return self.dict() diff --git a/setup.py b/setup.py index 79f347ec..6e940fe1 100644 --- a/setup.py +++ b/setup.py @@ -151,7 +151,7 @@ "env.docker = mlem.contrib.docker.base:DockerEnv", "docker_registry.docker_io = mlem.contrib.docker.base:DockerIORegistry", "builder.docker = mlem.contrib.docker.base:DockerImageBuilder", - "docker_registry = mlem.contrib.docker.base:DockerRegistry", + "docker_registry.local = mlem.contrib.docker.base:DockerRegistry", "docker_registry.remote = mlem.contrib.docker.base:RemoteRegistry", "artifact.dvc = mlem.contrib.dvc:DVCArtifact", "storage.dvc = mlem.contrib.dvc:DVCStorage", @@ -162,7 +162,7 @@ "deployment.heroku = mlem.contrib.heroku.meta:HerokuDeployment", "env.heroku = mlem.contrib.heroku.meta:HerokuEnv", "deploy_state.heroku = mlem.contrib.heroku.meta:HerokuState", - "server.heroku = mlem.contrib.heroku.server:HerokuServer", + "server._heroku = mlem.contrib.heroku.server:HerokuServer", "data_reader.lightgbm = mlem.contrib.lightgbm:LightGBMDataReader", "data_type.lightgbm = mlem.contrib.lightgbm:LightGBMDataType", "data_writer.lightgbm = mlem.contrib.lightgbm:LightGBMDataWriter", diff --git a/tests/cli/conftest.py 
b/tests/cli/conftest.py index 31e89f90..c0e39160 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -4,13 +4,20 @@ from mlem.cli import app +app.pretty_exceptions_short = False + class Runner: def __init__(self): self._runner = CliRunner() - def invoke(self, *args, **kwargs) -> Result: - return self._runner.invoke(app, *args, **kwargs) + def invoke(self, *args, raise_on_error: bool = False, **kwargs) -> Result: + result = self._runner.invoke(app, *args, **kwargs) + if raise_on_error and result.exit_code != 0: + if result.exit_code == 1: + raise result.exception + raise RuntimeError(result.output) + return result @pytest.fixture diff --git a/tests/cli/test_apply.py b/tests/cli/test_apply.py index 76af85f2..9a239e2a 100644 --- a/tests/cli/test_apply.py +++ b/tests/cli/test_apply.py @@ -218,14 +218,16 @@ def test_apply_remote(mlem_client, runner, data_path): [ "apply-remote", "http", + "-d", data_path, - "-c", - "host=''", - "-c", - "port=None", + "--host", + "", + "--port", + "None", "-o", path, ], + raise_on_error=True, ) assert result.exit_code == 0, (result.output, result.exception) predictions = load(path) diff --git a/tests/cli/test_build.py b/tests/cli/test_build.py index 0e66ef9a..01b12722 100644 --- a/tests/cli/test_build.py +++ b/tests/cli/test_build.py @@ -1,27 +1,98 @@ +import json import os.path from typing import ClassVar +from pydantic import parse_obj_as +from yaml import safe_dump + +from mlem.cli.build import create_build_command +from mlem.contrib.fastapi import FastAPIServer from mlem.core.objects import MlemBuilder, MlemModel +from mlem.runtime.server import Server from mlem.utils.path import make_posix from tests.cli.conftest import Runner class BuilderMock(MlemBuilder): + """mock""" + type: ClassVar = "mock" target: str + """target""" + server: Server + """server""" def build(self, obj: MlemModel): with open(self.target, "w", encoding="utf8") as f: - f.write(obj.loc.path) + f.write(obj.loc.path + "\n") + 
json.dump(self.server.dict(), f) + + +create_build_command(BuilderMock.type) def test_build(runner: Runner, model_meta_saved_single, tmp_path): path = os.path.join(tmp_path, "packed") result = runner.invoke( - f"build {make_posix(model_meta_saved_single.loc.uri)} -c target={make_posix(path)} mock" + f"build mock -m {make_posix(model_meta_saved_single.loc.uri)} --target {make_posix(path)} --server fastapi --server.port 1000" + ) + + assert result.exit_code == 0, (result.exception, result.output) + + with open(path, encoding="utf8") as f: + lines = f.read().splitlines() + assert len(lines) == 2 + path, serv = lines + assert path == model_meta_saved_single.loc.path + assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( + port=1000 + ) + + +def test_build_with_file_conf( + runner: Runner, model_meta_saved_single, tmp_path +): + path = os.path.join(tmp_path, "packed") + server_path = os.path.join(tmp_path, "server.yaml") + with open(server_path, "w", encoding="utf8") as f: + safe_dump(FastAPIServer(port=9999).dict(), f) + + result = runner.invoke( + f"build mock -m {make_posix(model_meta_saved_single.loc.uri)} --target {make_posix(path)} --file_conf server={make_posix(server_path)}" + ) + + assert result.exit_code == 0, (result.exception, result.output) + + with open(path, encoding="utf8") as f: + lines = f.read().splitlines() + assert len(lines) == 2 + path, serv = lines + assert path == model_meta_saved_single.loc.path + assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( + port=9999 + ) + + +def test_build_with_load(runner: Runner, model_meta_saved_single, tmp_path): + path = os.path.join(tmp_path, "packed") + load_path = os.path.join(tmp_path, "builder.yaml") + builder = BuilderMock( + server=FastAPIServer(port=9999), target=make_posix(path) + ) + with open(load_path, "w", encoding="utf8") as f: + safe_dump(builder.dict(), f) + + result = runner.invoke( + f"build -m {make_posix(model_meta_saved_single.loc.uri)} --load {make_posix(load_path)}" 
) assert result.exit_code == 0, (result.exception, result.output) with open(path, encoding="utf8") as f: - assert f.read().strip() == model_meta_saved_single.loc.path + lines = f.read().splitlines() + assert len(lines) == 2 + path, serv = lines + assert path == model_meta_saved_single.loc.path + assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( + port=9999 + ) diff --git a/tests/cli/test_declare.py b/tests/cli/test_declare.py index 6686bf94..1886eaa2 100644 --- a/tests/cli/test_declare.py +++ b/tests/cli/test_declare.py @@ -1,14 +1,478 @@ +from functools import lru_cache +from typing import Any, Dict, List, Optional + +import pytest +from pydantic import BaseModel + +from mlem.cli.declare import create_declare_mlem_object_subcommand, declare +from mlem.contrib.docker import DockerDirBuilder +from mlem.contrib.docker.context import DockerBuildArgs +from mlem.contrib.fastapi import FastAPIServer from mlem.contrib.heroku.meta import HerokuEnv +from mlem.contrib.pip.base import PipBuilder +from mlem.core.base import build_mlem_object from mlem.core.metadata import load_meta +from mlem.core.objects import MlemBuilder, MlemModel +from mlem.runtime.server import Server from mlem.utils.path import make_posix from tests.cli.conftest import Runner +builder_typer = [ + g.typer_instance + for g in declare.registered_groups + if g.typer_instance.info.name == "builder" +][0] +builder_typer.pretty_exceptions_short = False + +all_test_params = [] + + +class SimpleValue(BaseModel): + value: str + + +class ComplexValue(BaseModel): + field: str + field_list: List[str] = [] + field_dict: Dict[str, str] = {} + + +class ListValue(BaseModel): + f: List[str] = [] + + +class _MockBuilder(MlemBuilder): + """mock""" + + def build(self, obj: MlemModel): + pass + + def __init_subclass__(cls): + cls.__doc__ = "mock" + super().__init_subclass__() + def test_declare(runner: Runner, tmp_path): result = runner.invoke( - f"declare env heroku {make_posix(str(tmp_path))} -c api_key=aaa" 
+ f"declare env heroku {make_posix(str(tmp_path))} --api_key aaa" ) assert result.exit_code == 0, result.exception env = load_meta(str(tmp_path)) assert isinstance(env, HerokuEnv) assert env.api_key == "aaa" + + +@pytest.mark.parametrize( + "args, res", + [ + ("", []), + ( + "--args.templates_dir.0 kek --args.templates_dir.1 kek2", + ["kek", "kek2"], + ), + ], +) +def test_declare_list(runner: Runner, tmp_path, args, res): + result = runner.invoke( + f"declare builder docker_dir {make_posix(str(tmp_path))} --server fastapi --target lol " + + args, + raise_on_error=True, + ) + assert result.exit_code == 0, (result.exception, result.output) + builder = load_meta(str(tmp_path)) + assert isinstance(builder, DockerDirBuilder) + assert isinstance(builder.server, FastAPIServer) + assert builder.target == "lol" + assert isinstance(builder.args, DockerBuildArgs) + assert builder.args.templates_dir == res + + +@pytest.mark.parametrize( + "args, res", + [ + ("", {}), + ( + "--additional_setup_kwargs.key value --additional_setup_kwargs.key2 value2", + {"key": "value", "key2": "value2"}, + ), + ], +) +def test_declare_dict(runner: Runner, tmp_path, args, res): + result = runner.invoke( + f"declare builder pip {make_posix(str(tmp_path))} --package_name lol --target lol " + + args + ) + assert result.exit_code == 0, (result.exception, result.output) + builder = load_meta(str(tmp_path)) + assert isinstance(builder, PipBuilder) + assert builder.package_name == "lol" + assert builder.target == "lol" + assert builder.additional_setup_kwargs == res + + +class MockListComplexValue(_MockBuilder): + """mock""" + + field: List[ComplexValue] = [] + + +all_test_params.append( + pytest.param( + MockListComplexValue(), "", id=f"{MockListComplexValue.type}_empty" + ) +) +all_test_params.append( + pytest.param( + MockListComplexValue( + field=[ + ComplexValue( + field="a", + field_list=["a", "a"], + field_dict={"a": "a", "b": "b"}, + ), + ComplexValue( + field="a", + field_list=["a", "a"], + 
field_dict={"a": "a", "b": "b"}, + ), + ] + ), + "--field.0.field a --field.0.field_list.0 a --field.0.field_list.1 a --field.0.field_dict.a a --field.0.field_dict.b b " + "--field.1.field a --field.1.field_list.0 a --field.1.field_list.1 a --field.1.field_dict.a a --field.1.field_dict.b b", + id=f"{MockListComplexValue.type}_full", + ) +) + + +class MockListListValue(_MockBuilder): + """mock""" + + f: List[ListValue] = [] + + +all_test_params.append( + pytest.param(MockListListValue(), "", id="list_list_value_empty") +) +all_test_params.append( + pytest.param( + MockListListValue( + f=[ListValue(f=["a", "b"]), ListValue(f=["a", "b"])] + ), + "--f.0.f.0 a --f.0.f.1 b --f.1.f.0 a --f.1.f.1 b", + id="list_list_value_full", + ) +) + + +class MockModelListBuilder(_MockBuilder): + """mock""" + + field: List[SimpleValue] = [] + + +all_test_params.append( + pytest.param(MockModelListBuilder(), "", id="model_list_empty") +) +all_test_params.append( + pytest.param( + MockModelListBuilder( + field=[SimpleValue(value="kek"), SimpleValue(value="kek2")] + ), + "--field.0.value kek --field.1.value kek2", + id="model_list_full", + ) +) + + +class MockModelDictBuilder(_MockBuilder): + """mock""" + + field: Dict[str, SimpleValue] = {} + + +all_test_params.append( + pytest.param(MockModelDictBuilder(), "", id="model_dict_empty") +) +all_test_params.append( + pytest.param( + MockModelDictBuilder( + field={ + "k1": SimpleValue(value="kek"), + "k2": SimpleValue(value="kek2"), + } + ), + "--field.k1.value kek --field.k2.value kek2", + id="model_dict_empty", + ) +) + + +class MockFlatList(_MockBuilder): + """mock""" + + f: List[List[str]] = [] + + +all_test_params.append( + pytest.param(MockFlatList(f=[]), "", id="flat_list_empty") +) +all_test_params.append( + pytest.param( + MockFlatList(f=[["a", "a"], ["a", "a"]]), + "--f.0.0 a --f.0.1 a --f.1.0 a --f.1.1 a", + id="flat_list_full", + ) +) + + +class MockFlatListDict(_MockBuilder): + """mock""" + + f: List[Dict[str, str]] = [] + + 
+all_test_params.append( + pytest.param(MockFlatListDict(), "", id="flat_list_dict_empty") +) +all_test_params.append( + pytest.param( + MockFlatListDict(f=[{"k1": "a"}, {"k2": "b"}]), + "--f.0.k1 a --f.1.k2 b", + id="flat_list_dict_full", + ) +) + + +class MockFlatDictList(_MockBuilder): + """mock""" + + f: Dict[str, List[str]] = {} + + +all_test_params.append( + pytest.param(MockFlatDictList(), "", id="flat_dict_list_empty") +) +all_test_params.append( + pytest.param( + MockFlatDictList(f={"k1": ["a"], "k2": ["b"]}), + "--f.k1.0 a --f.k2.0 b", + id="flat_dict_list_full", + ) +) + + +class MockFlatDict(_MockBuilder): + """mock""" + + f: Dict[str, Dict[str, str]] = {} + + +all_test_params.append(pytest.param(MockFlatDict(), "", id="flat_dict_empty")) +all_test_params.append( + pytest.param( + MockFlatDict(f={"k1": {"k1": "a"}, "k2": {"k2": "b"}}), + "--f.k1.k1 a --f.k2.k2 b", + id="flat_dict_full", + ) +) + + +class MaskedField(_MockBuilder): + """mock""" + + field: ListValue + index: str + + +all_test_params.append( + pytest.param( + MaskedField(index="a", field=ListValue(f=["a"])), + "--.index a --field.f.0 a", + id="masked", + ) +) + + +class BooleanField(_MockBuilder): + field: bool + + +all_test_params.extend( + ( + pytest.param( + BooleanField(field=True), + "--field 1", + id="bool_true_1", + ), + pytest.param( + BooleanField(field=False), + "--field 0", + id="bool_false_0", + ), + pytest.param( + BooleanField(field=True), + "--field True", + id="bool_true", + ), + pytest.param( + BooleanField(field=False), + "--field False", + id="bool_false", + ), + ) +) + + +class AllowNoneField(_MockBuilder): + field: Optional[int] = 0 + + +all_test_params.extend( + ( + pytest.param( + AllowNoneField(field=10), "--field 10", id="allow_none_value" + ), + pytest.param( + AllowNoneField(field=None), "--field None", id="allow_none_none" + ), + pytest.param(AllowNoneField(), "", id="allow_none_default"), + ) +) + + +@lru_cache() +def _declare_builder_command(type_: str): + 
create_declare_mlem_object_subcommand( + builder_typer, + type_, + MlemBuilder.object_type, + MlemBuilder, + ) + + +@pytest.mark.parametrize("expected, args", all_test_params) +def test_declare_models( + runner: Runner, tmp_path, args: str, expected: MlemBuilder +): + _declare_builder_command(expected.__get_alias__()) + result = runner.invoke( + f"declare builder {expected.__get_alias__()} {make_posix(str(tmp_path))} " + + args, + raise_on_error=True, + ) + assert result.exit_code == 0, (result.exception, result.output) + builder = load_meta(str(tmp_path)) + assert isinstance(builder, type(expected)) + assert builder == expected + + +class RootValue(BaseModel): + __root__: List[str] = [] + + +class MockComplexBuilder(_MockBuilder): + """mock""" + + string: str + str_list: List[str] = [] + str_dict: Dict[str, str] = {} + str_list_dict: List[Dict[str, str]] = [] + str_dict_list: Dict[str, List[str]] = {} + value: ComplexValue + + value_list: List[ComplexValue] = [] + value_dict: Dict[str, ComplexValue] = {} + root_value: RootValue + root_list: List[RootValue] = [] + root_dict: Dict[str, RootValue] = {} + server: Server + server_list: List[Server] = [] + server_dict: Dict[str, Server] = {} + + +create_declare_mlem_object_subcommand( + builder_typer, + MockComplexBuilder.type, + MlemBuilder.object_type, + MlemBuilder, +) + + +def test_declare_all_together(runner: Runner, tmp_path): + args = [ + "string", + "str_list.0", + "str_list.1", + "str_dict.k1", + "str_dict.k2", + "str_list_dict.0.k1", + "str_list_dict.0.k2", + "str_list_dict.1.k1", + "str_list_dict.1.k2", + "str_dict_list.k1.0", + "str_dict_list.k1.1", + "str_dict_list.k2.0", + "str_dict_list.k2.1", + "value.field", + "value.field_list.0", + "value.field_list.1", + "value.field_dict.k1", + "value.field_dict.k2", + "value_list.0.field", + "value_list.0.field_list.0", + "value_list.0.field_list.1", + "value_list.0.field_dict.k1", + "value_list.0.field_dict.k2", + "value_list.1.field", + 
"value_list.1.field_list.0", + "value_list.1.field_list.1", + "value_list.1.field_dict.k1", + "value_list.1.field_dict.k2", + "value_dict.k1.field", + "value_dict.k1.field_list.0", + "value_dict.k1.field_list.1", + "value_dict.k1.field_dict.k1", + "value_dict.k1.field_dict.k2", + "value_dict.k2.field", + "value_dict.k2.field_list.0", + "value_dict.k2.field_list.1", + "value_dict.k2.field_dict.k1", + "value_dict.k2.field_dict.k2", + "root_value.0", + "root_value.1", + "root_list.0.0", + "root_list.0.1", + "root_list.1.0", + "root_list.1.1", + "root_dict.k1.0", + "root_dict.k1.1", + "root_dict.k2.0", + "root_dict.k2.1", + ] + server_args: Dict[str, Any] = { + "server": "fastapi", + "server.port": 0, + "server_list.0": "fastapi", + "server_list.0.port": 0, + "server_list.1": "fastapi", + "server_list.1.port": 0, + "server_dict.k1": "fastapi", + "server_dict.k1.port": 0, + "server_dict.k2": "fastapi", + "server_dict.k2.port": 0, + } + args_str = " ".join(f"--{k} lol" for k in args) + args_str += " " + " ".join(f"--{k} {v}" for k, v in server_args.items()) + result = runner.invoke( + f"declare builder {MockComplexBuilder.type} {make_posix(str(tmp_path))} {args_str}", + raise_on_error=True, + ) + assert result.exit_code == 0, (result.exception, result.output) + builder = load_meta(str(tmp_path)) + assert isinstance(builder, MockComplexBuilder) + assert builder == build_mlem_object( + MlemBuilder, + MockComplexBuilder.type, + str_conf=[f"{k}=lol" for k in args], + conf=server_args, + ) diff --git a/tests/cli/test_deployment.py b/tests/cli/test_deployment.py index 808feddf..a719e1e3 100644 --- a/tests/cli/test_deployment.py +++ b/tests/cli/test_deployment.py @@ -27,21 +27,30 @@ def mock_deploy_get_client(mocker, request_get_mock, request_post_mock): class DeployStateMock(DeployState): + """mock""" + def get_client(self) -> Client: pass class MlemDeploymentMock(MlemDeployment): + """mock""" + class Config: use_enum_values = True type: ClassVar = "mock" status: DeployStatus 
= DeployStatus.NOT_DEPLOYED + """status""" param: str = "" + """param""" state: DeployState = DeployStateMock() + """state""" class MlemEnvMock(MlemEnv): + """mock""" + type: ClassVar = "mock" deploy_type: ClassVar = MlemDeploymentMock @@ -93,7 +102,9 @@ def test_deploy_create_new( def test_deploy_create_existing(runner: Runner, mock_deploy_path): - result = runner.invoke(f"deploy run {mock_deploy_path}".split()) + result = runner.invoke( + f"deploy run {mock_deploy_path}".split(), raise_on_error=True + ) assert result.exit_code == 0, result.output meta = load_meta(mock_deploy_path) assert isinstance(meta, MlemDeploymentMock) diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py index 48c25aaa..b2ce0921 100644 --- a/tests/cli/test_main.py +++ b/tests/cli/test_main.py @@ -11,7 +11,7 @@ def iter_group(group: Group, prefix=()): yield prefix, group for name, c in group.commands.items(): if isinstance(c, Group): - yield from iter_group(c, prefix + (name,)) + yield from iter_group(c, prefix) else: yield prefix + (name,), c @@ -40,18 +40,22 @@ def test_commands_help(app_cli_cmd): for name, cli_cmd in app_cli_cmd: if cli_cmd.help is None: no_help.append(name) - assert len(no_help) == 0, f"{no_help} cli commnads do not have help!" + assert len(no_help) == 0, f"{no_help} cli commands do not have help!" def test_commands_args_help(app_cli_cmd): no_help = [] for name, cmd in app_cli_cmd: + dynamic_metavar = getattr(cmd, "dynamic_metavar", None) for arg in cmd.params: + if arg.name == dynamic_metavar: + continue if arg.help is None: no_help.append(f"{name}:{arg.name}") assert len(no_help) == 0, f"{no_help} cli commnad args do not have help!" +@pytest.mark.xfail # TODO do we need examples for everything? 
def test_commands_examples(app_cli_cmd): no_examples = [] for name, cmd in app_cli_cmd: diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py index 695c5a56..f3e34e2d 100644 --- a/tests/cli/test_serve.py +++ b/tests/cli/test_serve.py @@ -1,5 +1,6 @@ from typing import ClassVar +from mlem.cli.serve import create_serve_command from mlem.runtime import Interface from mlem.runtime.server import Server from mlem.ui import echo @@ -7,14 +8,20 @@ class MockServer(Server): + """mock""" + type: ClassVar = "mock" param: str = "wrong" + """param""" def serve(self, interface: Interface): echo(self.param) +create_serve_command(MockServer.type) + + def test_serve(runner: Runner, model_single_path): - result = runner.invoke(f"serve {model_single_path} mock -c param=aaa") + result = runner.invoke(f"serve mock -m {model_single_path} --param aaa") assert result.exit_code == 0, result.exception assert result.output.splitlines()[-1] == "aaa" diff --git a/tests/cli/test_types.py b/tests/cli/test_types.py new file mode 100644 index 00000000..4a70b291 --- /dev/null +++ b/tests/cli/test_types.py @@ -0,0 +1,85 @@ +from typing import Optional + +import pytest +from pydantic import BaseModel + +from mlem.cli.types import iterate_type_fields +from mlem.cli.utils import get_field_help +from mlem.core.base import MlemABC +from mlem.utils.entrypoints import list_implementations +from tests.cli.conftest import Runner + + +def test_types(runner: Runner): + result = runner.invoke("types") + assert result.exit_code == 0, (result.exception, result.output) + assert all(typename in result.output for typename in MlemABC.abs_types) + + +@pytest.mark.parametrize("abs_name", MlemABC.abs_types.keys()) +def test_types_abs_name(runner: Runner, abs_name): + result = runner.invoke(f"types {abs_name}") + assert result.exit_code == 0, result.exception + assert set(result.output.splitlines()) == set( + list_implementations(abs_name, include_hidden=False) + ) + + +@pytest.mark.parametrize( + 
"abs_name,subtype", + [ + (abs_name, subtype) + for abs_name, root_type in MlemABC.abs_types.items() + for subtype in list_implementations(root_type, include_hidden=False) + ], +) +def test_types_abs_name_subtype(runner: Runner, abs_name, subtype): + result = runner.invoke(f"types {abs_name} {subtype}") + assert result.exit_code == 0, result.exception + if not subtype.startswith("tests."): + assert "docstring missing" not in result.output + + +def test_iter_type_fields_subclass(): + class Parent(BaseModel): + parent: str + """parent""" + + class Child(Parent): + child: str + """child""" + excluded: Optional[str] = None + + class Config: + fields = {"excluded": {"exclude": True}} + + fields = list(iterate_type_fields(Child)) + + assert len(fields) == 2 + assert {get_field_help(Child, f.path) for f in fields} == { + "parent", + "child", + } + + +def test_iter_type_fields_subclass_multiinheritance(): + class Parent(BaseModel): + parent: str + """parent""" + + class Parent2(BaseModel): + parent2 = "" + """parent2""" + + class Child(Parent, Parent2): + child: str + """child""" + + fields = list(iterate_type_fields(Child)) + + assert len(fields) == 3 + assert {get_field_help(Child, f.path) for f in fields} == { + "parent", + "child", + "parent2", + } diff --git a/tests/contrib/test_bitbucket.py b/tests/contrib/test_bitbucket.py index 1f9bd916..5f018803 100644 --- a/tests/contrib/test_bitbucket.py +++ b/tests/contrib/test_bitbucket.py @@ -5,7 +5,7 @@ from mlem.contrib.bitbucketfs import BitBucketFileSystem from mlem.core.errors import RevisionNotFound -from mlem.core.meta_io import UriResolver, get_fs +from mlem.core.meta_io import Location, get_fs from mlem.core.metadata import load_meta from mlem.core.objects import MlemModel from tests.conftest import long @@ -71,7 +71,7 @@ def test_uri_resolver(uri): ["main", "branch", "tag", "3897d2ab"], ) def test_uri_resolver_rev(rev): - location = UriResolver.resolve(MLEM_TEST_REPO_URI, None, rev=rev, fs=None) + location = 
Location.resolve(MLEM_TEST_REPO_URI, None, rev=rev, fs=None) assert isinstance(location.fs, BitBucketFileSystem) assert location.fs.root == rev assert "README.md" in location.fs.ls("") @@ -80,7 +80,7 @@ def test_uri_resolver_rev(rev): @long def test_uri_resolver_wrong_rev(): with pytest.raises(RevisionNotFound): - UriResolver.resolve( + Location.resolve( MLEM_TEST_REPO_URI, None, rev="__not_exists__", fs=None ) diff --git a/tests/contrib/test_docker/test_context.py b/tests/contrib/test_docker/test_context.py index b708428a..0bb788c8 100644 --- a/tests/contrib/test_docker/test_context.py +++ b/tests/contrib/test_docker/test_context.py @@ -97,6 +97,12 @@ def test_dockerfile_generator_super_custom(): assert _generate_dockerfile(**kwargs) == dockerfile +def test_dockerfile_generator_no_cmd(): + kwargs = {"run_cmd": None} + with use_mlem_source("pip"): + assert "CMD" not in _generate_dockerfile(**kwargs) + + def test_use_wheel_installation(tmpdir): distr = tmpdir.mkdir("distr").join("somewhatwheel.txt") distr.write("wheel goes brrr") @@ -123,11 +129,11 @@ def test_docker_registry_io(): registry = DockerIORegistry() client = docker.DockerClient() - client.images.pull("hello-world:latest") + client.images.pull("library/hello-world:latest") assert registry.get_host() == "https://index.docker.io/v1/" - registry.push(client, "hello-world:latest") - image = DockerImage(name="hello-world") + registry.push(client, "library/hello-world:latest") + image = DockerImage(name="library/hello-world") assert registry.image_exists(client, image) diff --git a/tests/contrib/test_docker/test_utils.py b/tests/contrib/test_docker/test_utils.py index 41c2e0b4..d8846634 100644 --- a/tests/contrib/test_docker/test_utils.py +++ b/tests/contrib/test_docker/test_utils.py @@ -8,10 +8,12 @@ @docker_test def test_image_exists(): - assert image_exists_at_dockerhub(f"python:{get_python_version()}-slim") + assert image_exists_at_dockerhub( + f"python:{get_python_version()}-slim", library=True + ) assert 
image_exists_at_dockerhub("minio/minio:latest") - assert image_exists_at_dockerhub("postgres:alpine") - assert image_exists_at_dockerhub("registry:latest") + assert image_exists_at_dockerhub("postgres:alpine", library=True) + assert image_exists_at_dockerhub("registry:latest", library=True) @docker_test @@ -25,7 +27,7 @@ def test_image_not_exists(): @docker_test def test_repository_tags(): - tags = repository_tags_at_dockerhub("python") + tags = repository_tags_at_dockerhub("python", library=True) assert f"{get_python_version()}-slim" in tags assert get_python_version() in tags diff --git a/tests/contrib/test_gitlab.py b/tests/contrib/test_gitlab.py index a889d59a..ac9a409a 100644 --- a/tests/contrib/test_gitlab.py +++ b/tests/contrib/test_gitlab.py @@ -2,7 +2,7 @@ from mlem.contrib.gitlabfs import GitlabFileSystem from mlem.core.errors import RevisionNotFound -from mlem.core.meta_io import UriResolver, get_fs +from mlem.core.meta_io import Location, get_fs from mlem.core.metadata import load_meta from mlem.core.objects import MlemModel from tests.conftest import long @@ -46,7 +46,7 @@ def test_uri_resolver(uri): ["main", "branch", "tag", "3897d2ab"], ) def test_uri_resolver_rev(rev): - location = UriResolver.resolve(MLEM_TEST_REPO_URI, None, rev=rev, fs=None) + location = Location.resolve(MLEM_TEST_REPO_URI, None, rev=rev, fs=None) assert isinstance(location.fs, GitlabFileSystem) assert location.fs.root == rev assert "README.md" in location.fs.ls("") @@ -55,7 +55,7 @@ def test_uri_resolver_rev(rev): @long def test_uri_resolver_wrong_rev(): with pytest.raises(RevisionNotFound): - UriResolver.resolve( + Location.resolve( MLEM_TEST_REPO_URI, None, rev="__not_exists__", fs=None ) diff --git a/tests/core/test_base.py b/tests/core/test_base.py index d5624825..dccc75da 100644 --- a/tests/core/test_base.py +++ b/tests/core/test_base.py @@ -1,8 +1,16 @@ -from typing import ClassVar, Optional +from typing import ClassVar, List, Optional + +from pydantic import BaseModel 
from mlem.contrib.docker import DockerImageBuilder from mlem.contrib.fastapi import FastAPIServer -from mlem.core.base import MlemABC, build_mlem_object, parse_links, smart_split +from mlem.core.base import ( + MlemABC, + SmartSplitDict, + build_mlem_object, + parse_links, + smart_split, +) from mlem.core.objects import MlemBuilder, MlemLink, MlemModel, MlemObject from mlem.runtime.server import Server from tests.conftest import resource_path @@ -51,11 +59,12 @@ def test_build_with_replace(): res = build_mlem_object( MockMlemABC, "mock", - ["server=fastapi", "server.port=8081"], + ["server=fastapi", "server.port=8081", "server.host=localhost"], ) assert isinstance(res, MockMlemABC) assert isinstance(res.server, FastAPIServer) assert res.server.port == 8081 + assert res.server.host == "localhost" res = build_mlem_object( MockMlemABC, @@ -64,3 +73,157 @@ def test_build_with_replace(): ) assert isinstance(res, MockMlemABC) assert isinstance(res.server, FastAPIServer) + + res = build_mlem_object( + MockMlemABC, + "mock", + conf={ + "server": "fastapi", + "server.port": 8081, + "server.host": "localhost", + }, + ) + assert isinstance(res, MockMlemABC) + assert isinstance(res.server, FastAPIServer) + assert res.server.port == 8081 + assert res.server.host == "localhost" + + +def test_build_with_list(): + class MockMlemABCList(MlemABC): + abs_name: ClassVar = "mock_list" + values: List[str] + + res = build_mlem_object( + MockMlemABCList, + "mock_list", + ["values.0=a", "values.1=b"], + ) + assert isinstance(res, MockMlemABCList) + assert isinstance(res.values, list) + assert res.values == ["a", "b"] + + +def test_build_with_list_complex(): + class Value(BaseModel): + field: str + + class MockMlemABCListComplex(MlemABC): + abs_name: ClassVar = "mock_list_complex" + values: List[Value] + + res = build_mlem_object( + MockMlemABCListComplex, + "mock_list_complex", + ["values.0.field=a", "values.1.field=b"], + ) + assert isinstance(res, MockMlemABCListComplex) + assert 
isinstance(res.values, list) + assert res.values == [Value(field="a"), Value(field="b")] + + +def test_build_with_list_nested(): + class MockMlemABCListNested(MlemABC): + abs_name: ClassVar = "mock_list_complex" + values: List[List[str]] + + res = build_mlem_object( + MockMlemABCListNested, + MockMlemABCListNested.abs_name, + ["values.0.0=a", "values.0.1=b"], + ) + assert isinstance(res, MockMlemABCListNested) + assert isinstance(res.values, list) + assert res.values == [["a", "b"]] + + +def test_smart_split_dict(): + d = SmartSplitDict(sep=".") + d["a.b.c"] = 1 + d["a.b.d"] = 2 + d["a.e"] = 3 + d["a.f"] = 4 + d["g"] = 5 + + assert d.build() == {"g": 5, "a": {"f": 4, "e": 3, "b": {"d": 2, "c": 1}}} + + +def test_smart_split_dict_with_list(): + d = SmartSplitDict(sep=".") + d["a.0"] = 1 + d["a.1"] = 2 + d["b"] = 3 + + assert d.build() == {"a": [1, 2], "b": 3} + + +def test_smart_split_dict_with_nested(): + d = SmartSplitDict(sep=".") + d["ll.0.0"] = 1 + d["ll.0.1"] = 2 + d["ll.1.0"] = 3 + d["ll.1.1"] = 4 + d["ld.0.a"] = 5 + d["ld.0.b"] = 6 + d["ld.1.a"] = 7 + d["ld.1.b"] = 8 + d["dl.a.0"] = 9 + d["dl.a.1"] = 10 + d["dl.b.0"] = 11 + d["dl.b.1"] = 12 + d["dd.a.a"] = 13 + d["dd.a.b"] = 14 + d["dd.b.a"] = 15 + d["dd.b.b"] = 16 + + assert d.build() == { + "ll": [[1, 2], [3, 4]], + "ld": [{"a": 5, "b": 6}, {"a": 7, "b": 8}], + "dl": {"a": [9, 10], "b": [11, 12]}, + "dd": {"a": {"a": 13, "b": 14}, "b": {"a": 15, "b": 16}}, + } + + +def test_smart_split_dict_nested_list(): + d = SmartSplitDict() + d["r.k1.0"] = "lol" + d["r.k1.1"] = "lol" + d["r.k2.0"] = "lol" + d["r.k2.1"] = "lol" + + assert d.build() == {"r": {"k1": ["lol", "lol"], "k2": ["lol", "lol"]}} + + +def test_smart_split_dict_with_type(): + d = SmartSplitDict(sep=".") + d["server"] = "fastapi" + d["server.port"] = 8080 + assert d.build() == {"server": {"type": "fastapi", "port": 8080}} + + +def test_smart_split_dict_prebuilt(): + d = SmartSplitDict(sep=".") + d["a.b.c"] = 1 + d["a"] = {"b": {"d": 2}} + assert 
d.build() == {"a": {"b": {"c": 1, "d": 2}}} + + +def test_smart_split_dict_list_with_type(): + d = SmartSplitDict(sep=".") + d["server.0"] = "fastapi" + d["server.0.port"] = 8080 + assert d.build() == {"server": [{"type": "fastapi", "port": 8080}]} + + +def test_smart_split_dict_dict_with_type(): + d = SmartSplitDict(sep=".") + d["server.a"] = "fastapi" + d["server.a.port"] = 8080 + d["server.b"] = "fastapi" + d["server.b.port"] = 8080 + assert d.build() == { + "server": { + "a": {"type": "fastapi", "port": 8080}, + "b": {"type": "fastapi", "port": 8080}, + } + } diff --git a/tests/core/test_meta_io.py b/tests/core/test_meta_io.py index 0990f238..37fddf71 100644 --- a/tests/core/test_meta_io.py +++ b/tests/core/test_meta_io.py @@ -11,7 +11,7 @@ from mlem import LOCAL_CONFIG from mlem.core.errors import RevisionNotFound -from mlem.core.meta_io import UriResolver, get_fs, get_path_by_fs_path, read +from mlem.core.meta_io import Location, get_fs, get_path_by_fs_path, read from tests.conftest import ( MLEM_TEST_REPO, MLEM_TEST_REPO_NAME, @@ -84,7 +84,7 @@ def test_get_fs_github(uri, rev): @long def test_github_wrong_rev(): with pytest.raises(RevisionNotFound): - UriResolver.resolve( + Location.resolve( MLEM_TEST_REPO, project=None, rev="__not_exists__kek", fs=None ) diff --git a/tests/core/test_objects.py b/tests/core/test_objects.py index c24b35cd..2576d9eb 100644 --- a/tests/core/test_objects.py +++ b/tests/core/test_objects.py @@ -6,14 +6,14 @@ import pytest from fsspec.implementations.local import LocalFileSystem -from pydantic import ValidationError, parse_obj_as +from pydantic import parse_obj_as from sklearn.datasets import load_iris from mlem.core.artifacts import Artifacts, LocalArtifact, Storage from mlem.core.errors import MlemProjectNotFound, WrongRequirementsError from mlem.core.meta_io import MLEM_DIR, MLEM_EXT from mlem.core.metadata import load, load_meta -from mlem.core.model import ModelIO +from mlem.core.model import ModelIO, ModelType from 
mlem.core.objects import ( DeployState, MlemDeployment, @@ -370,14 +370,15 @@ def test_link_dump_in_mlem(model_path_mlem_project): def test_model_model_type_laziness(): payload = { - "model_type": {"type": "doesnotexist"}, + "model_type": {"type": "sklearn", "methods": {}}, "object_type": "model", "requirements": [], } model = parse_obj_as(MlemModel, payload) - assert model.model_type_raw == {"type": "doesnotexist"} - with pytest.raises(ValidationError): - print(model.model_type) + assert model.model_type_cache == {"type": "sklearn", "methods": {}} + assert isinstance(model.model_type_cache, dict) + assert isinstance(model.model_type, ModelType) + assert isinstance(model.model_type_cache, ModelType) def test_mlem_project_root(filled_mlem_project): @@ -426,11 +427,16 @@ def test_remove_old_artifacts(model, tmpdir, train): load(path).predict(train) +class MockModelType(ModelType): + io: ModelIO = MockModelIO(filename="") + + def test_checkenv(): model = MlemModel( requirements=Requirements.new( InstallableRequirement(module="pytest", version=pytest.__version__) - ) + ), + model_type=MockModelType(methods={}), ) model.checkenv() diff --git a/tests/utils/test_entrypoints.py b/tests/utils/test_entrypoints.py new file mode 100644 index 00000000..231dd72a --- /dev/null +++ b/tests/utils/test_entrypoints.py @@ -0,0 +1,39 @@ +from abc import abstractmethod + +from mlem.core.base import MlemABC +from mlem.core.objects import MlemEnv, MlemObject +from mlem.utils.entrypoints import list_implementations + + +class MockABC(MlemABC): + abs_name = "mock" + + class Config: + type_root = True + + @abstractmethod + def something(self): + pass + + +class MockImpl(MockABC): + type = "impl" + + def something(self): + pass + + +def test_list_implementations(): + assert list_implementations(MockABC) == ["impl"] + assert list_implementations("mock") == ["impl"] + + +def test_list_implementations_meta(): + assert "model" in list_implementations("meta") + assert "model" in 
list_implementations(MlemObject) + + assert "docker" in list_implementations("meta", MlemEnv) + assert "docker" in list_implementations(MlemObject, MlemEnv) + + assert "docker" in list_implementations("meta", "env") + assert "docker" in list_implementations(MlemObject, "env") From d33d092b0d631150b807244e29e9d9baf73b25ce Mon Sep 17 00:00:00 2001 From: Mikhail Sveshnikov Date: Wed, 14 Sep 2022 13:22:47 +0300 Subject: [PATCH 2/4] new state POC (#340) * new state POC * update docker and mock deployments * add locks * simplify deployment meta (by complexifying code) * fix tests * fix tests * fix tests * fix win tests * default env and server * fsspec manager as default * Sagemaker deployments (#366) * WIP * its alive (kinda) * it works but it's ugly * little less ugly * lil fix * fix lint * fix lint * fix tests * fix tests * fix windows bugs * fix tests * fix tests * fix for dirs deployment state * create MlemSource to choose how mlem is added to docker * test that all configs in entrypoints * better cli val error * better docker package install * finish merge * fix short tests * fix short tests * Update mlem/contrib/sagemaker/runtime.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/core/objects.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/core/objects.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/contrib/docker/base.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/contrib/docker/base.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/contrib/docker/base.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/contrib/heroku/meta.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Apply suggestions from code review Co-authored-by: Alexander Guschin <1aguschin@gmail.com> * Update mlem/contrib/sagemaker/build.py Co-authored-by: Alexander Guschin <1aguschin@gmail.com> Co-authored-by: Alexander Guschin <1aguschin@gmail.com> 
--- .pylintrc | 2 +- mlem/api/commands.py | 27 +- mlem/api/utils.py | 44 +- mlem/cli/declare.py | 1 + mlem/cli/deployment.py | 48 +- mlem/cli/types.py | 10 +- mlem/cli/utils.py | 13 +- mlem/config.py | 18 + mlem/constants.py | 2 + mlem/contrib/docker/base.py | 150 +++--- mlem/contrib/docker/copy.j2 | 1 + mlem/contrib/docker/dockerfile.j2 | 7 +- mlem/contrib/docker/install_req.j2 | 4 + mlem/contrib/heroku/meta.py | 83 ++- .../contrib/sagemaker}/__init__.py | 0 mlem/contrib/sagemaker/build.py | 135 +++++ mlem/contrib/sagemaker/copy.j2 | 0 mlem/contrib/sagemaker/env_setup.py | 93 ++++ mlem/contrib/sagemaker/meta.py | 484 ++++++++++++++++++ mlem/contrib/sagemaker/mlem_sagemaker.tf | 82 +++ mlem/contrib/sagemaker/post_copy.j2 | 3 + mlem/contrib/sagemaker/runtime.py | 68 +++ mlem/core/errors.py | 15 + mlem/core/meta_io.py | 18 +- mlem/core/objects.py | 447 ++++++++++++++-- mlem/core/requirements.py | 4 +- mlem/ext.py | 1 + mlem/polydantic/core.py | 14 +- mlem/ui.py | 1 + mlem/utils/fslock.py | 113 ++++ mlem/utils/templates.py | 2 +- setup.py | 12 +- tests/cli/test_deployment.py | 154 +++++- .../test_docker/resources/dockerfile.j2 | 3 + tests/contrib/test_docker/test_deploy.py | 53 +- tests/contrib/test_heroku.py | 13 +- tests/core/test_objects.py | 19 +- tests/core/test_requirements.py | 6 + tests/test_config.py | 5 + tests/test_ext.py | 34 +- tests/utils/test_fslock.py | 62 +++ 41 files changed, 2032 insertions(+), 219 deletions(-) create mode 100644 mlem/contrib/docker/copy.j2 create mode 100644 mlem/contrib/docker/install_req.j2 rename {tests/pack => mlem/contrib/sagemaker}/__init__.py (100%) create mode 100644 mlem/contrib/sagemaker/build.py create mode 100644 mlem/contrib/sagemaker/copy.j2 create mode 100644 mlem/contrib/sagemaker/env_setup.py create mode 100644 mlem/contrib/sagemaker/meta.py create mode 100644 mlem/contrib/sagemaker/mlem_sagemaker.tf create mode 100644 mlem/contrib/sagemaker/post_copy.j2 create mode 100644 mlem/contrib/sagemaker/runtime.py create 
mode 100644 mlem/utils/fslock.py create mode 100644 tests/contrib/test_docker/resources/dockerfile.j2 create mode 100644 tests/utils/test_fslock.py diff --git a/.pylintrc b/.pylintrc index 75fac659..af266a7f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -369,7 +369,7 @@ indent-string=' ' max-line-length=100 # Maximum number of lines in a module. -max-module-lines=1000 +max-module-lines=2000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. diff --git a/mlem/api/commands.py b/mlem/api/commands.py index 99e74ebe..9d230365 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -420,9 +420,11 @@ def deploy( fs: Optional[AbstractFileSystem] = None, external: bool = None, index: bool = None, + env_kwargs: Dict[str, Any] = None, **deploy_kwargs, ) -> MlemDeployment: deploy_path = None + update = False if isinstance(deploy_meta_or_path, str): deploy_path = deploy_meta_or_path try: @@ -432,13 +434,13 @@ def deploy( fs=fs, force_type=MlemDeployment, ) + update = True except MlemObjectNotFound: deploy_meta = None else: deploy_meta = deploy_meta_or_path - if model is not None: - deploy_meta.replace_model(get_model_meta(model)) + update = True if deploy_meta is None: if model is None or env is None: @@ -448,15 +450,24 @@ def deploy( if not deploy_path: raise MlemError("deploy_path cannot be empty") model_meta = get_model_meta(model) - env_meta = ensure_meta(MlemEnv, env) - deploy_meta = env_meta.deploy_type( - model=model_meta, - env=env_meta, - env_link=env_meta.make_link(), - model_link=model_meta.make_link(), + env_meta = ensure_meta(MlemEnv, env, allow_typename=True) + if isinstance(env_meta, type): + env = None + if env_kwargs: + env = env_meta(**env_kwargs) + deploy_type = env_meta.deploy_type + deploy_meta = deploy_type( + model_cache=model_meta, + model=model_meta.make_link(), + env=env, **deploy_kwargs, ) deploy_meta.dump(deploy_path, fs, project, index, external) + else: + if model is not None: + 
deploy_meta.replace_model(get_model_meta(model, load_value=False)) + if update: + pass # todo update from deploy_args and env_args # ensuring links are working deploy_meta.get_env() deploy_meta.get_model() diff --git a/mlem/api/utils.py b/mlem/api/utils.py index 902b8640..5d941b6b 100644 --- a/mlem/api/utils.py +++ b/mlem/api/utils.py @@ -1,8 +1,10 @@ import re -from typing import Any, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, Tuple, Type, TypeVar, Union, overload -from mlem.core.base import MlemABC, build_mlem_object -from mlem.core.errors import InvalidArgumentError +from typing_extensions import Literal + +from mlem.core.base import MlemABC, build_mlem_object, load_impl_ext +from mlem.core.errors import InvalidArgumentError, MlemObjectNotFound from mlem.core.metadata import load, load_meta from mlem.core.objects import MlemData, MlemModel, MlemObject @@ -45,9 +47,41 @@ def get_model_meta( MM = TypeVar("MM", bound=MlemObject) -def ensure_meta(as_class: Type[MM], obj_or_path: Union[str, MM]) -> MM: +@overload +def ensure_meta( + as_class: Type[MM], + obj_or_path: Union[str, MM], + allow_typename: bool = False, +) -> Union[MM, Type[MM]]: + pass + + +@overload +def ensure_meta( + as_class: Type[MM], + obj_or_path: Union[str, MM], + allow_typename: Literal[False] = False, +) -> MM: + pass + + +def ensure_meta( + as_class: Type[MM], + obj_or_path: Union[str, MM], + allow_typename: bool = False, +) -> Union[MM, Type[MM]]: if isinstance(obj_or_path, str): - return load_meta(obj_or_path, force_type=as_class) + try: + return load_meta(obj_or_path, force_type=as_class) + except MlemObjectNotFound: + if allow_typename: + impl = load_impl_ext( + as_class.abs_name, obj_or_path, raise_on_missing=False + ) + if impl is None or not issubclass(impl, as_class): + raise + return impl + raise if isinstance(obj_or_path, as_class): return obj_or_path raise ValueError(f"Cannot get {as_class} from '{obj_or_path}'") diff --git a/mlem/cli/declare.py 
b/mlem/cli/declare.py index ef47b819..9acc2952 100644 --- a/mlem/cli/declare.py +++ b/mlem/cli/declare.py @@ -149,6 +149,7 @@ def subtype_command( "requirement", "resolver", "storage", + "state", } for abs_name in list_abstractions(include_hidden=False): if abs_name in {"builder", "env", "deployment"}: diff --git a/mlem/cli/deployment.py b/mlem/cli/deployment.py index 3547cab2..724a2011 100644 --- a/mlem/cli/deployment.py +++ b/mlem/cli/deployment.py @@ -23,7 +23,7 @@ from mlem.core.data_type import DataAnalyzer from mlem.core.errors import DeploymentError from mlem.core.metadata import load_meta -from mlem.core.objects import MlemDeployment +from mlem.core.objects import DeployState, DeployStatus, MlemDeployment from mlem.ui import echo, no_echo, set_echo deployment = Typer( @@ -64,6 +64,9 @@ def deploy_run( """ from mlem.api.commands import deploy + conf = conf or [] + env_conf = [c[len("env.") :] for c in conf if c.startswith("env.")] + conf = [c for c in conf if not c.startswith("env.")] deploy( path, model, @@ -71,6 +74,7 @@ def deploy_run( project, external=external, index=index, + env_kwargs=parse_string_conf(env_conf), **parse_string_conf(conf or []), ) @@ -107,6 +111,40 @@ def deploy_status( echo(status) +@mlem_command("wait", parent=deployment) +def deploy_wait( + path: str = Argument(..., help="Path to deployment meta"), + project: Optional[str] = option_project, + statuses: List[DeployStatus] = Option( + [DeployStatus.RUNNING], + "-s", + "--status", + help="statuses to wait for", + ), + intermediate: List[DeployStatus] = Option( + None, "-i", "--intermediate", help="Possible intermediate statuses" + ), + poll_timeout: float = Option( + 1.0, "-p", "--poll-timeout", help="Timeout between attempts" + ), + times: int = Option( + 0, "-t", "--times", help="Number of attempts. 
0 -> indefinite" + ), +): + """Wait for status of deployed service + + Examples: + $ mlem deployment wait service_name + """ + with no_echo(): + deploy_meta = load_meta( + path, project=project, force_type=MlemDeployment + ) + deploy_meta.wait_for_status( + statuses, poll_timeout, times, allowed_intermediate=intermediate + ) + + @mlem_command("apply", parent=deployment) def deploy_apply( path: str = Argument(..., help="Path to deployment meta"), @@ -133,11 +171,15 @@ def deploy_apply( deploy_meta = load_meta( path, project=project, rev=rev, force_type=MlemDeployment ) - if deploy_meta.state is None: + state: DeployState = deploy_meta.get_state() + if ( + state == deploy_meta.state_type() + and not deploy_meta.state_type.allow_default + ): raise DeploymentError( f"{deploy_meta.type} deployment has no state. Either {deploy_meta.type} is not deployed yet or has been un-deployed again." ) - client = deploy_meta.state.get_client() + client = deploy_meta.get_client(state) result = run_apply_remote( client, diff --git a/mlem/cli/types.py b/mlem/cli/types.py index af1f094e..52bba21c 100644 --- a/mlem/cli/types.py +++ b/mlem/cli/types.py @@ -12,7 +12,11 @@ from mlem.utils.entrypoints import list_abstractions, list_implementations -def _add_examples(generator: Iterator[CliTypeField], parent_help=None): +def _add_examples( + generator: Iterator[CliTypeField], + root_cls: Type[BaseModel], + parent_help=None, +): for field in generator: field.help = parent_help or field.help yield field @@ -28,13 +32,15 @@ def _add_examples(generator: Iterator[CliTypeField], parent_help=None): required=False, allow_none=False, default=None, + root_cls=root_cls, ), + root_cls=root_cls, parent_help=f"Element of {field.path}", ) def type_fields_with_collection_examples(cls): - yield from _add_examples(iterate_type_fields(cls)) + yield from _add_examples(iterate_type_fields(cls), root_cls=cls) def explain_type(cls: Type[BaseModel]): diff --git a/mlem/cli/utils.py b/mlem/cli/utils.py index
f06d28e6..d92b0044 100644 --- a/mlem/cli/utils.py +++ b/mlem/cli/utils.py @@ -207,6 +207,7 @@ def parse_type_field( required: bool, allow_none: bool, default: Any, + root_cls: Type[BaseModel], ) -> Iterator[CliTypeField]: """Recursively creates CliTypeFields from field description""" if is_list or is_mapping: @@ -246,7 +247,7 @@ def parse_type_field( return if isinstance(type_, type) and issubclass(type_, BaseModel): # BaseModel (including MlemABC non-root classes): reqursively get nested - yield from iterate_type_fields(type_, path, not required) + yield from iterate_type_fields(type_, path, not required, root_cls) return # probably primitive field yield CliTypeField( @@ -263,9 +264,16 @@ def parse_type_field( def iterate_type_fields( - cls: Type[BaseModel], path: str = "", force_not_req: bool = False + cls: Type[BaseModel], + path: str = "", + force_not_req: bool = False, + root_cls: Type[BaseModel] = None, ) -> Iterator[CliTypeField]: """Recursively get CliTypeFields from BaseModel""" + if cls is root_cls: + # avoid infinite recursion + return + root_cls = root_cls or cls field: ModelField for name, field in sorted( cls.__fields__.items(), key=lambda x: not x[1].required @@ -317,6 +325,7 @@ def iterate_type_fields( required=not force_not_req and bool(field.required), allow_none=field.allow_none, default=field.default, + root_cls=root_cls, ) diff --git a/mlem/config.py b/mlem/config.py index 51afee87..ee2fd797 100644 --- a/mlem/config.py +++ b/mlem/config.py @@ -119,6 +119,8 @@ class Config: INDEX: Dict = {} EXTERNAL: bool = False EMOJIS: bool = True + STATE: Dict = {} + SERVER: Dict = {} @property def storage(self): @@ -145,6 +147,22 @@ def additional_extensions(self) -> List[str]: "," ) + @property + def state(self): + if not self.STATE: + return None + from mlem.core.objects import StateManager + + return parse_obj_as(StateManager, self.STATE) + + @property + def server(self): + from mlem.runtime.server import Server + + if not self.SERVER: + return 
parse_obj_as(Server, {"type": "fastapi"}) + return parse_obj_as(Server, self.SERVER) + LOCAL_CONFIG = MlemConfig() diff --git a/mlem/constants.py b/mlem/constants.py index dfe1af43..1f09eb95 100644 --- a/mlem/constants.py +++ b/mlem/constants.py @@ -1,4 +1,6 @@ MLEM_DIR = ".mlem" +MLEM_STATE_DIR = ".mlem.state" +MLEM_STATE_EXT = ".state" PREDICT_METHOD_NAME = "predict" PREDICT_PROBA_METHOD_NAME = "predict_proba" diff --git a/mlem/contrib/docker/base.py b/mlem/contrib/docker/base.py index e4b9d997..66b6ac8c 100644 --- a/mlem/contrib/docker/base.py +++ b/mlem/contrib/docker/base.py @@ -3,6 +3,7 @@ import logging import os import tempfile +import time from time import sleep from typing import ClassVar, Dict, Generator, Iterator, Optional @@ -12,6 +13,7 @@ from docker.errors import NotFound from pydantic import BaseModel +from mlem.config import project_config from mlem.contrib.docker.context import DockerBuildArgs, DockerModelDirectory from mlem.contrib.docker.utils import ( build_image_with_logs, @@ -183,10 +185,10 @@ def push(self, client, tag): if "error" in status: error_msg = status["error"] raise DeploymentError(f"Cannot push docker image: {error_msg}") - echo(EMOJI_OK + f"Pushed image {tag} to {self.host}") + echo(EMOJI_OK + f"Pushed image {tag} to {self.get_host()}") def uri(self, image: str): - return f"{self.host}/{image}" + return f"{self.get_host()}/{image}" def _get_digest(self, name, tag): r = requests.head( @@ -279,27 +281,31 @@ class DockerContainerState(DeployState): type: ClassVar = "docker_container" image: Optional[DockerImage] - """built image""" + """Built image""" + container_name: Optional[str] + """Name of container""" container_id: Optional[str] - """started container id""" - - def get_client(self): - raise NotImplementedError + """Started container id""" class _DockerBuildMixin(BaseModel): - server: Server - """server to use""" + server: Optional[Server] = None + """Server to use""" args: DockerBuildArgs = DockerBuildArgs() - """additional 
docker arguments""" + """Additional docker arguments""" + + +def generate_docker_container_name(): + return f"mlem-deploy-{int(time.time())}" class DockerContainer(MlemDeployment, _DockerBuildMixin): """MlemDeployment implementation for docker containers""" type: ClassVar = "docker_container" + state_type: ClassVar = DockerContainerState - container_name: str + container_name: Optional[str] = None """Name to use for container""" image_name: Optional[str] = None """Name to use for image""" @@ -309,13 +315,14 @@ class DockerContainer(MlemDeployment, _DockerBuildMixin): """Additional params""" rm: bool = True """Remove container on stop""" - state: Optional[DockerContainerState] = None - """state""" @property def ensure_image_name(self): return self.image_name or self.container_name + def _get_client(self, state: DockerContainerState): + raise NotImplementedError + class DockerEnv(MlemEnv[DockerContainer]): """MlemEnv implementation for docker environment""" @@ -335,29 +342,36 @@ def image_exists(self, image: DockerImage): with self.daemon.client() as client: return image.exists(client) - def run_container(self, meta: DockerContainer): - if meta.state is None or meta.state.image is None: + def run_container( + self, + meta: DockerContainer, + state: Optional[DockerContainerState] = None, + ): + state = state or meta.get_state() + if state.image is None: raise DeploymentError( f"Image {meta.ensure_image_name} is not built" ) with self.daemon.client() as client: - meta.state.image.registry.login(client) + state.image.registry.login(client) try: # always detach from container and just stream logs if detach=False + name = meta.container_name or generate_docker_container_name() container = client.containers.run( - meta.state.image.uri, - name=meta.container_name, + state.image.uri, + name=name, auto_remove=meta.rm, ports=meta.port_mapping, detach=True, **meta.params, ) - meta.state.container_id = container.id - meta.update() + state.container_id = container.id + 
state.container_name = name + meta.update_state(state) sleep(0.5) - if not container_is_running(client, meta.container_name): + if not container_is_running(client, name): if not meta.rm: for log in self.logs(meta, stdout=False, stderr=True): raise DeploymentError( @@ -377,73 +391,79 @@ def run_container(self, meta: DockerContainer): def logs( self, meta: DockerContainer, **kwargs ) -> Generator[str, None, None]: - if meta.state is None or meta.state.container_id is None: + state = meta.get_state() + if state.container_id is None: raise DeploymentError( f"Container {meta.container_name} is not deployed" ) with self.daemon.client() as client: - container = client.containers.get(meta.state.container_id) + container = client.containers.get(state.container_id) yield from container_logs(container, **kwargs) def deploy(self, meta: DockerContainer): self.check_type(meta) - - if meta.state is None: - meta.state = DockerContainerState() - - meta.update() - redeploy = False - if meta.state.image is None or meta.model_changed(): - from .helpers import build_model_image - - image_name = meta.image_name or meta.container_name - echo(EMOJI_BUILD + f"Creating docker image {image_name}") - with set_offset(2): - meta.state.image = build_model_image( - meta.get_model(), - image_name, - meta.server, - self, - force_overwrite=True, - **meta.args.dict(), + with meta.lock_state(): + state = meta.get_state() + if state.image is None or meta.model_changed(): + from .helpers import build_model_image + + image_name = ( + meta.image_name + or meta.container_name + or generate_docker_container_name() ) - meta.update_model_hash() - meta.update() - redeploy = True - if meta.state.container_id is None or redeploy: - self.run_container(meta) - - echo(EMOJI_OK + f"Container {meta.container_name} is up") + echo(EMOJI_BUILD + f"Creating docker image {image_name}") + with set_offset(2): + state.image = build_model_image( + meta.get_model(), + image_name, + meta.server + or project_config( + 
meta.loc.project if meta.is_saved else None + ).server, + self, + force_overwrite=True, + **meta.args.dict(), + ) + meta.update_model_hash(state=state) + meta.update_state(state) + redeploy = True + if state.container_id is None or redeploy: + self.run_container(meta, state) + + echo(EMOJI_OK + f"Container {state.container_name} is up") def remove(self, meta: DockerContainer): self.check_type(meta) - if meta.state is None or meta.state.container_id is None: - raise DeploymentError( - f"Container {meta.container_name} is not deployed" - ) + with meta.lock_state(): + state = meta.get_state() + if state.container_id is None: + raise DeploymentError( + f"Container {meta.container_name} is not deployed" + ) - with self.daemon.client() as client: - try: - container = client.containers.get(meta.state.container_id) - container.stop() - container.remove() - except docker.errors.NotFound: - pass - meta.state.container_id = None - meta.update() + with self.daemon.client() as client: + try: + container = client.containers.get(state.container_id) + container.stop() + container.remove() + except docker.errors.NotFound: + pass + state.container_id = None + meta.update_state(state) def get_status( self, meta: DockerContainer, raise_on_error=True ) -> DeployStatus: self.check_type(meta) - - if meta.state is None or meta.state.container_id is None: + state = meta.get_state() + if state.container_id is None: return DeployStatus.NOT_DEPLOYED with self.daemon.client() as client: try: - status = container_status(client, meta.state.container_id) + status = container_status(client, state.container_id) return CONTAINER_STATUS_MAPPING[status] except NotFound: return DeployStatus.UNKNOWN diff --git a/mlem/contrib/docker/copy.j2 b/mlem/contrib/docker/copy.j2 new file mode 100644 index 00000000..916bbf2c --- /dev/null +++ b/mlem/contrib/docker/copy.j2 @@ -0,0 +1 @@ +COPY . 
./ diff --git a/mlem/contrib/docker/dockerfile.j2 b/mlem/contrib/docker/dockerfile.j2 index b5720706..a9c62383 100644 --- a/mlem/contrib/docker/dockerfile.j2 +++ b/mlem/contrib/docker/dockerfile.j2 @@ -1,12 +1,9 @@ FROM {{ base_image }} WORKDIR /app {% include "pre_install.j2" ignore missing %} -{% if packages %}RUN {{ package_install_cmd }} {{ packages|join(" ") }} {{ package_clean_cmd }}{% endif %} -COPY requirements.txt . -RUN pip install -r requirements.txt -{{ mlem_install }} +{% include "install_req.j2" %} {% include "post_install.j2" ignore missing %} -COPY . ./ +{% include "copy.j2" %} {% for name, value in env.items() %}ENV {{ name }}={{ value }} {% endfor %} {% include "post_copy.j2" ignore missing %} diff --git a/mlem/contrib/docker/install_req.j2 b/mlem/contrib/docker/install_req.j2 new file mode 100644 index 00000000..64f22d04 --- /dev/null +++ b/mlem/contrib/docker/install_req.j2 @@ -0,0 +1,4 @@ +{% if packages %}RUN {{ package_install_cmd }} {{ packages|join(" ") }} {{ package_clean_cmd }}{% endif %} +COPY requirements.txt . 
+RUN pip install -r requirements.txt +{{ mlem_install }} diff --git a/mlem/contrib/heroku/meta.py b/mlem/contrib/heroku/meta.py index 59c24263..91e3d488 100644 --- a/mlem/contrib/heroku/meta.py +++ b/mlem/contrib/heroku/meta.py @@ -52,18 +52,13 @@ def ensured_app(self) -> HerokuAppMeta: raise ValueError("App is not created yet") return self.app - def get_client(self) -> Client: - return HTTPClient( - host=urlparse(self.ensured_app.web_url).netloc, port=80 - ) - class HerokuDeployment(MlemDeployment): """Heroku App""" type: ClassVar = "heroku" - state: Optional[HerokuState] - """state""" + state_type: ClassVar = HerokuState + app_name: str """Heroku application name""" region: str = "us" @@ -71,7 +66,12 @@ class HerokuDeployment(MlemDeployment): stack: str = "container" """stack to use""" team: Optional[str] = None - """heroku team""" + """Heroku team""" + + def _get_client(self, state: HerokuState) -> Client: + return HTTPClient( + host=urlparse(state.ensured_app.web_url).netloc, port=80 + ) class HerokuEnv(MlemEnv[HerokuDeployment]): @@ -85,47 +85,43 @@ class HerokuEnv(MlemEnv[HerokuDeployment]): def deploy(self, meta: HerokuDeployment): from .utils import create_app, release_docker_app - if meta.state is None: - meta.state = HerokuState() - - meta.update() self.check_type(meta) + with meta.lock_state(): + state: HerokuState = meta.get_state() + if state.app is None: + state.app = create_app(meta, api_key=self.api_key) + meta.update_state(state) + + redeploy = False + if state.image is None or meta.model_changed(): + state.image = build_heroku_docker( + meta.get_model(), state.app.name, api_key=self.api_key + ) + meta.update_model_hash(state=state) + redeploy = True + if state.release_state is None or redeploy: + state.release_state = release_docker_app( + state.app.name, + state.image.image_id, + api_key=self.api_key, + ) + meta.update_state(state) - if meta.state.app is None: - meta.state.app = create_app(meta, api_key=self.api_key) - meta.update() - - redeploy 
= False - if meta.state.image is None or meta.model_changed(): - meta.state.image = build_heroku_docker( - meta.get_model(), meta.state.app.name, api_key=self.api_key + echo( + EMOJI_OK + + f"Service {meta.app_name} is up. You can check it out at {state.app.web_url}" ) - meta.update_model_hash() - meta.update() - redeploy = True - if meta.state.release_state is None or redeploy: - meta.state.release_state = release_docker_app( - meta.state.app.name, - meta.state.image.image_id, - api_key=self.api_key, - ) - meta.update() - - echo( - EMOJI_OK - + f"Service {meta.app_name} is up. You can check it out at {meta.state.app.web_url}" - ) def remove(self, meta: HerokuDeployment): from .utils import delete_app self.check_type(meta) - if meta.state is None: - return + with meta.lock_state(): + state: HerokuState = meta.get_state() - delete_app(meta.state.ensured_app.name, self.api_key) - meta.state = None - meta.update() + if state.app is not None: + delete_app(state.ensured_app.name, self.api_key) + meta.purge_state() def get_status( self, meta: "HerokuDeployment", raise_on_error=True @@ -133,14 +129,15 @@ def get_status( from .utils import list_dynos self.check_type(meta) - if meta.state is None or meta.state.app is None: + state: HerokuState = meta.get_state() + if state.app is None: return DeployStatus.NOT_DEPLOYED - dynos = list_dynos(meta.state.ensured_app.name, "web", self.api_key) + dynos = list_dynos(state.ensured_app.name, "web", self.api_key) if not dynos: if raise_on_error: raise DeploymentError( f"No heroku web dynos found, check your dashboard " - f"at https://dashboard.heroku.com/apps/{meta.state.ensured_app.name}" + f"at https://dashboard.heroku.com/apps/{state.ensured_app.name}" ) return DeployStatus.NOT_DEPLOYED return HEROKU_STATE_MAPPING[dynos[0]["state"]] diff --git a/tests/pack/__init__.py b/mlem/contrib/sagemaker/__init__.py similarity index 100% rename from tests/pack/__init__.py rename to mlem/contrib/sagemaker/__init__.py diff --git 
a/mlem/contrib/sagemaker/build.py b/mlem/contrib/sagemaker/build.py new file mode 100644 index 00000000..6fc8cb54 --- /dev/null +++ b/mlem/contrib/sagemaker/build.py @@ -0,0 +1,135 @@ +import base64 +import os +from typing import ClassVar, Optional + +import boto3 +import sagemaker +from pydantic import BaseModel + +from ...core.objects import MlemModel +from ...ui import EMOJI_BUILD, EMOJI_KEY, echo, set_offset +from ..docker.base import DockerEnv, DockerImage, RemoteRegistry +from ..docker.helpers import build_model_image +from .runtime import SageMakerServer + +IMAGE_NAME = "mlem-sagemaker-runner" + + +class AWSVars(BaseModel): + """AWS Configuration""" + + profile: str + """AWS Profile""" + bucket: str + """S3 Bucket""" + region: str + """AWS Region""" + account: str + """AWS Account name""" + role_name: str + """AWS Role name""" + + @property + def role(self): + return f"arn:aws:iam::{self.account}:role/{self.role_name}" + + def get_sagemaker_session(self): + return sagemaker.Session( + self.get_session(), default_bucket=self.bucket + ) + + def get_session(self): + return boto3.Session( + profile_name=self.profile, region_name=self.region + ) + + +def ecr_repo_check(region, repository, session: boto3.Session): + client = session.client("ecr", region_name=region) + + repos = client.describe_repositories()["repositories"] + + if repository not in {r["repositoryName"] for r in repos}: + echo(EMOJI_BUILD + f"Creating ECR repository {repository}") + client.create_repository(repositoryName=repository) + + +class ECRegistry(RemoteRegistry): + """ECR registry""" + + class Config: + exclude = {"aws_vars"} + + type: ClassVar = "ecr" + account: str + """AWS Account""" + region: str + """AWS Region""" + + aws_vars: Optional[AWSVars] = None + """AWS Configuration cache""" + + def login(self, client): + auth_data = self.ecr_client.get_authorization_token() + token = auth_data["authorizationData"][0]["authorizationToken"] + user, token = 
base64.b64decode(token).decode("utf8").split(":") + self._login(self.get_host(), client, user, token) + echo( + EMOJI_KEY + + f"Logged in to remote registry at host {self.get_host()}" + ) + + def get_host(self) -> Optional[str]: + return f"{self.account}.dkr.ecr.{self.region}.amazonaws.com" + + def image_exists(self, client, image: DockerImage): + images = self.ecr_client.list_images(repositoryName=image.name)[ + "imageIds" + ] + return len(images) > 0 + + def delete_image(self, client, image: DockerImage, force=False, **kwargs): + return self.ecr_client.batch_delete_image( + repositoryName=image.name, + imageIds=[{"imageTag": image.tag}], + ) + + def with_aws_vars(self, aws_vars): + self.aws_vars = aws_vars + return self + + @property + def ecr_client(self): + return ( + self.aws_vars.get_session().client("ecr") + if self.aws_vars + else boto3.client("ecr", region_name=self.region) + ) + + +def build_sagemaker_docker( + meta: MlemModel, + method: str, + account: str, + region: str, + image_name: str, + repository: str, + aws_vars: AWSVars, +): + docker_env = DockerEnv( + registry=ECRegistry(account=account, region=region).with_aws_vars( + aws_vars + ) + ) + ecr_repo_check(region, repository, aws_vars.get_session()) + echo(EMOJI_BUILD + "Creating docker image for sagemaker") + with set_offset(2): + return build_model_image( + meta, + name=repository, + tag=image_name, + server=SageMakerServer(method=method), + env=docker_env, + force_overwrite=True, + templates_dir=[os.path.dirname(__file__)], + ) diff --git a/mlem/contrib/sagemaker/copy.j2 b/mlem/contrib/sagemaker/copy.j2 new file mode 100644 index 00000000..e69de29b diff --git a/mlem/contrib/sagemaker/env_setup.py b/mlem/contrib/sagemaker/env_setup.py new file mode 100644 index 00000000..1b10258b --- /dev/null +++ b/mlem/contrib/sagemaker/env_setup.py @@ -0,0 +1,93 @@ +import os +import shutil +import subprocess + +from mlem.ui import echo + +MLEM_TF = "mlem_sagemaker.tf" + + +def _tf_command(tf_dir, command, 
*flags, **args): + args = " ".join(f"-var='{k}={v}'" for k, v in args.items()) + return " ".join( + [ + "terraform", + f"-chdir={tf_dir}", + command, + *flags, + args, + ] + ) + + +def _tf_get_var(tf_dir, varname): + return ( + subprocess.check_output( + _tf_command(tf_dir, "output", varname), shell=True + ) + .decode("utf8") + .strip() + .strip('"') + ) + + +def sagemaker_terraform( + user_name: str = "mlem", + role_name: str = "mlem", + region_name: str = "us-east-1", + profile: str = "default", + plan: bool = False, + work_dir: str = ".", + export_secret: str = None, +): + if not os.path.exists(work_dir): + os.makedirs(work_dir, exist_ok=True) + + shutil.copy( + os.path.join(os.path.dirname(__file__), MLEM_TF), + os.path.join(work_dir, MLEM_TF), + ) + subprocess.check_output(_tf_command(work_dir, "init"), shell=True) + + flags = ["-auto-approve"] if not plan else [] + + echo( + subprocess.check_output( + _tf_command( + work_dir, + "plan" if plan else "apply", + *flags, + role_name=role_name, + user_name=user_name, + region_name=region_name, + profile=profile, + ), + shell=True, + ) + ) + + if not plan and export_secret: + if os.path.exists(export_secret): + print( + f"Creds already present at {export_secret}, please backup and remove them" + ) + return + key_id = _tf_get_var(work_dir, "access_key_id") + access_secret = _tf_get_var(work_dir, "secret_access_key") + region = _tf_get_var(work_dir, "region_name") + profile = _tf_get_var(work_dir, "aws_user") + print(profile, region) + if export_secret.endswith(".csv"): + secrets = f"""User Name,Access key ID,Secret access key +{profile},{key_id},{access_secret}""" + print( + f"Import new profile:\naws configure import --csv file://{export_secret}\naws configure set region {region} --profile {profile}" + ) + else: + secrets = f"""export AWS_ACCESS_KEY_ID={key_id} +export AWS_SECRET_ACCESS_KEY={access_secret} +export AWS_REGION={region} +""" + print(f"Source envs:\nsource {export_secret}") + with open(export_secret, 
"w", encoding="utf8") as f: + f.write(secrets) diff --git a/mlem/contrib/sagemaker/meta.py b/mlem/contrib/sagemaker/meta.py new file mode 100644 index 00000000..385cb1bd --- /dev/null +++ b/mlem/contrib/sagemaker/meta.py @@ -0,0 +1,484 @@ +import os +import posixpath +import tarfile +import tempfile +from typing import ClassVar, Optional, Tuple + +import boto3 +import sagemaker +from pydantic import validator +from sagemaker.deserializers import JSONDeserializer +from sagemaker.serializers import JSONSerializer + +from mlem.config import MlemConfigBase, project_config +from mlem.contrib.docker.base import DockerDaemon, DockerImage +from mlem.contrib.sagemaker.build import ( + AWSVars, + ECRegistry, + build_sagemaker_docker, +) +from mlem.core.errors import WrongMethodError +from mlem.core.model import Signature +from mlem.core.objects import ( + DeployState, + DeployStatus, + MlemDeployment, + MlemEnv, + MlemModel, +) +from mlem.runtime.client import Client +from mlem.runtime.interface import InterfaceDescriptor +from mlem.ui import EMOJI_BUILD, EMOJI_UPLOAD, echo + +MODEL_TAR_FILENAME = "model.tar.gz" +DEFAULT_ECR_REPOSITORY = "mlem" + + +class AWSConfig(MlemConfigBase): + ROLE: Optional[str] + PROFILE: Optional[str] + + class Config: + section = "aws" + env_prefix = "AWS_" + + +def generate_model_file_name(deploy_id): + return f"mlem-model-{deploy_id}" + + +def generate_image_name(deploy_id): + return f"mlem-sagemaker-image-{deploy_id}" + + +class SagemakerClient(Client): + """Client to make SageMaker requests""" + + type: ClassVar = "sagemaker" + + endpoint_name: str + """Name of SageMaker Endpoint""" + aws_vars: AWSVars + """AWS Configuration""" + signature: Signature + """Signature of deployed method""" + + def _interface_factory(self) -> InterfaceDescriptor: + return InterfaceDescriptor(methods={"predict": self.signature}) + + def get_predictor(self): + sess = self.aws_vars.get_sagemaker_session() + predictor = sagemaker.Predictor( + 
endpoint_name=self.endpoint_name, + sagemaker_session=sess, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + return predictor + + def _call_method(self, name, args): + return self.get_predictor().predict(args) + + +class SagemakerDeployState(DeployState): + """State of SageMaker deployment""" + + type: ClassVar = "sagemaker" + + image: Optional[DockerImage] = None + """Built image""" + image_tag: Optional[str] = None + """Built image tag""" + model_location: Optional[str] = None + """Location of uploaded model""" + endpoint_name: Optional[str] = None + """Name of SageMaker endpoint""" + endpoint_model_hash: Optional[str] = None + """Hash of deployed model""" + method_signature: Optional[Signature] = None + """Signature of deployed method""" + region: Optional[str] = None + """AWS Region""" + previous: Optional["SagemakerDeployState"] = None + """Previous state""" + + @property + def image_uri(self): + if self.image is None: + if self.image_tag is None: + raise ValueError( + "Cannot get image_uri: image not built or not specified prebuilt image uri" + ) + return self.image_tag + return self.image.uri + + def get_predictor(self, session: sagemaker.Session): + predictor = sagemaker.Predictor( + endpoint_name=self.endpoint_name, + sagemaker_session=session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + return predictor + + +class SagemakerDeployment(MlemDeployment): + """SageMaker Deployment""" + + type: ClassVar = "sagemaker" + state_type: ClassVar = SagemakerDeployState + + method: str = "predict" + """Model method to be deployed""" + image_tag: Optional[str] = None + """Name of the docker image to use""" + use_prebuilt: bool = False + """Use pre-built docker image. 
If True, image_name should be set""" + model_arch_location: Optional[str] = None + """Path on s3 to store model archive (excluding bucket)""" + model_name: Optional[str] + """Name for SageMaker Model""" + endpoint_name: Optional[str] = None + """Name for SageMaker Endpoint""" + initial_instance_count: int = 1 + """Initial instance count for Endpoint""" + instance_type: str = "ml.t2.medium" + """Instance type for Endpoint""" + accelerator_type: Optional[str] = None + "The size of the Elastic Inference (EI) instance to use" + + @validator("use_prebuilt") + def ensure_image_name( # pylint: disable=no-self-argument + cls, value, values # noqa: B902 + ): + if value and "image_name" not in values: + raise ValueError( + "image_name should be set if use_prebuilt is true" + ) + return value + + def _get_client(self, state: "SagemakerDeployState"): + return SagemakerClient( + endpoint_name=state.endpoint_name, + aws_vars=self.get_env().get_session_and_aws_vars( + region=state.region + )[1], + signature=state.method_signature, + ) + + +ENDPOINT_STATUS_MAPPING = { + "Creating": DeployStatus.STARTING, + "Failed": DeployStatus.CRASHED, + "InService": DeployStatus.RUNNING, + "OutOfService": DeployStatus.STOPPED, + "Updating": DeployStatus.STARTING, + "SystemUpdating": DeployStatus.STARTING, + "RollingBack": DeployStatus.STARTING, + "Deleting": DeployStatus.STOPPED, +} + + +class SagemakerEnv(MlemEnv): + """SageMaker environment""" + + type: ClassVar = "sagemaker" + deploy_type: ClassVar = SagemakerDeployment + + role: Optional[str] = None + """Default role""" + account: Optional[str] = None + """Default account""" + region: Optional[str] = None + """Default region""" + bucket: Optional[str] = None + """Default bucket""" + profile: Optional[str] = None + """Default profile""" + ecr_repository: Optional[str] = None + """Default ECR repository""" + + @property + def role_name(self): + return f"arn:aws:iam::{self.account}:role/{self.role}" + + @staticmethod + def 
def _update_model(
    self,
    meta: SagemakerDeployment,
    state: SagemakerDeployState,
    aws_vars: AWSVars,
    session: sagemaker.Session,
):
    """Switch the already running endpoint to a newly created model.

    Creates a SageMaker model from the current image and model archive,
    updates the endpoint to use it (waiting for completion), then removes
    what the previous deployment left behind: the old SageMaker model, the
    old image and model archive recorded in ``state.previous``, and the old
    endpoint config. Finally records the deployed model hash in ``state``.

    Args:
        meta: deployment whose state is being updated.
        state: current deployment state; ``model_location`` must be set.
        aws_vars: resolved AWS settings (role is read from here).
        session: SageMaker session used for all API calls.
    """
    assert state.model_location is not None  # TODO
    sm_model = sagemaker.Model(
        image_uri=state.image_uri,
        model_data=posixpath.join(
            state.model_location, MODEL_TAR_FILENAME
        ),
        name=meta.model_name,
        role=aws_vars.role,
        sagemaker_session=session,
    )
    sm_model.create(
        instance_type=meta.instance_type,
        accelerator_type=meta.accelerator_type,
    )
    # Remember what the endpoint is using now so it can be deleted once the
    # endpoint has switched over.
    prev_endpoint_conf = session.sagemaker_client.describe_endpoint(
        EndpointName=state.endpoint_name
    )["EndpointConfigName"]
    prev_model_name = session.sagemaker_client.describe_endpoint_config(
        EndpointConfigName=prev_endpoint_conf
    )["ProductionVariants"][0]["ModelName"]

    predictor = state.get_predictor(session)
    predictor.update_endpoint(
        model_name=sm_model.name,
        initial_instance_count=meta.initial_instance_count,
        instance_type=meta.instance_type,
        accelerator_type=meta.accelerator_type,
        wait=True,
    )
    session.sagemaker_client.delete_model(ModelName=prev_model_name)
    prev = state.previous
    if prev is not None:
        if prev.image is not None:
            self._delete_image(meta, prev, aws_vars)
            # _delete_image persists the object it was handed, i.e. the
            # *previous* snapshot. Re-save the real state right away so a
            # failure in the remaining cleanup cannot leave the stored
            # state without endpoint/model information.
            meta.update_state(state)
        if prev.model_location is not None:
            self._delete_model_file(session, prev.model_location)
            prev.model_location = None
    session.sagemaker_client.delete_endpoint_config(
        EndpointConfigName=prev_endpoint_conf
    )
    state.endpoint_model_hash = state.model_hash
    meta.update_state(state)
def _upload_model(
    self,
    meta: SagemakerDeployment,
    state: SagemakerDeployState,
    aws_vars: AWSVars,
    session: sagemaker.Session,
):
    """Upload the model archive and record its location in ``state``.

    The location that was stored before the upload (if any) is moved to
    ``state.previous`` first. After the upload the model hash is refreshed
    and the state is persisted.
    """
    assert state.previous is not None  # TODO
    echo(
        EMOJI_UPLOAD
        + f"Uploading model distribution to {aws_vars.bucket}..."
    )
    old_location = state.model_location
    if old_location is not None:
        state.previous.model_location = old_location

    model = meta.get_model()
    # Fall back to a generated name only when no explicit location is set.
    arch_location = meta.model_arch_location
    if not arch_location:
        arch_location = generate_model_file_name(
            meta.get_model().meta_hash()
        )
    uploaded_to = self._create_and_upload_model_arch(
        session,
        model,
        aws_vars.bucket,
        arch_location,
    )
    state.model_location = uploaded_to
    meta.update_model_hash(state=state)
    meta.update_state(state)
def get_status(
    self, meta: SagemakerDeployment, raise_on_error=True
) -> "DeployStatus":
    """Return the deployment status derived from the SageMaker endpoint.

    Args:
        meta: deployment to inspect.
        raise_on_error: accepted for interface compatibility; it is not
            consulted by this implementation.

    Returns:
        ``DeployStatus.NOT_DEPLOYED`` when the state holds no endpoint
        name, otherwise the endpoint status translated through
        ``ENDPOINT_STATUS_MAPPING`` (``DeployStatus.UNKNOWN`` for statuses
        missing from the mapping).
    """
    with meta.lock_state():
        state: SagemakerDeployState = meta.get_state()
        # Nothing has been deployed yet: there is no endpoint to describe,
        # and passing ``None`` as EndpointName would only produce an API
        # error. Checked before a session is created to avoid needless
        # AWS calls.
        if state.endpoint_name is None:
            return DeployStatus.NOT_DEPLOYED
        session = self.get_session(state.region)

        endpoint = session.sagemaker_client.describe_endpoint(
            EndpointName=state.endpoint_name
        )
        status = endpoint["EndpointStatus"]
        return ENDPOINT_STATUS_MAPPING.get(status, DeployStatus.UNKNOWN)
name if needed + region = region or boto_session.region_name + config = project_config(project="", section=AWSConfig) + role = role or config.ROLE or sagemaker.get_execution_role(sess) + account = account or boto_session.client("sts").get_caller_identity().get( + "Account" + ) + return sess, AWSVars( + bucket=bucket, + region=region, + account=account, + role_name=role, + profile=profile or config.PROFILE, + ) diff --git a/mlem/contrib/sagemaker/mlem_sagemaker.tf b/mlem/contrib/sagemaker/mlem_sagemaker.tf new file mode 100644 index 00000000..ffbb5a5d --- /dev/null +++ b/mlem/contrib/sagemaker/mlem_sagemaker.tf @@ -0,0 +1,82 @@ +variable "profile" { + description = "AWS Profile to use for API calls" + type = string + default = "default" +} + +variable "role_name" { + description = "AWS role name" + type = string + default = "mlem" +} + +variable "user_name" { + description = "AWS user name" + type = string + default = "mlem" +} + +variable "region_name" { + description = "AWS region name" + type = string + default = "us-east-1" +} + +provider "aws" { + region = var.region_name + profile = var.profile +} + +resource "aws_iam_user" "aws_user" { + name = var.user_name +} + +resource "aws_iam_access_key" "aws_user" { + user = aws_iam_user.aws_user.name +} + +resource "aws_iam_user_policy_attachment" "sagemaker_policy" { + user = aws_iam_user.aws_user.name + policy_arn = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_iam_user_policy_attachment" "ecr_policy" { + user = aws_iam_user.aws_user.name + policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess" +} + +resource "aws_iam_role" "aws_role" { + name = var.role_name + description = "MLEM SageMaker Role" + assume_role_policy = < /usr/local/bin/serve && chmod +x /usr/local/bin/serve +ENTRYPOINT ["bash", "-c"] diff --git a/mlem/contrib/sagemaker/runtime.py b/mlem/contrib/sagemaker/runtime.py new file mode 100644 index 00000000..a7c67171 --- /dev/null +++ 
class SageMakerServer(FastAPIServer):
    """Server to use inside SageMaker containers"""

    type: ClassVar = "_sagemaker"
    libraries: ClassVar[List[ModuleType]] = [
        uvicorn,
        fastapi,
        sagemaker,
        boto3,
    ]
    method: str = local_config.METHOD
    """Method to expose"""
    port: int = local_config.PORT
    """Port to use"""
    host: str = local_config.HOST
    """Host to use"""

    def app_init(self, interface: Interface):
        # Start from the regular FastAPI app and add the two extra routes:
        # POST /invocations (runs the configured method) and GET /ping.
        application = super().app_init(interface)

        signature = interface.get_method_signature(self.method)
        executor = interface.get_method_executor(self.method)
        invocations_handler, invocations_response = self._create_handler(
            "invocations", signature, executor
        )
        application.add_api_route(
            "/invocations",
            invocations_handler,
            methods=["POST"],
            response_model=invocations_response,
        )
        application.add_api_route("/ping", ping, methods=["GET"])
        return application

    def get_env_vars(self) -> Dict[str, str]:
        # Single variable carrying the name of the exposed method.
        env_vars = {"SAGEMAKER_METHOD": self.method}
        return env_vars
def update_path(self, path):
    """Replace ``self.path`` with ``path`` and keep ``self.uri`` in sync.

    ``self.uri`` must end with the current ``self.path``; that suffix is
    swapped for the new path. When the current path is absolute and the
    new one is relative, the new path is resolved against the directory of
    the current path.

    Args:
        path: new path value.

    Raises:
        ValueError: ``self.uri`` does not end with ``self.path``.
    """
    if not self.uri.endswith(self.path):
        raise ValueError("cannot automatically update uri")
    if os.path.isabs(self.path) and not os.path.isabs(path):
        path = posixpath.join(posixpath.dirname(self.path), path)
    # Compute the prefix length explicitly: ``uri[:-len(self.path)]`` turns
    # into ``uri[:0]`` (an empty string) when the current path is empty,
    # which would throw away the whole URI prefix.
    prefix = self.uri[: len(self.uri) - len(self.path)]
    self.uri = prefix + path
    self.path = path
-1,7 +1,9 @@ """ Base classes for meta objects in MLEM """ +import contextlib import hashlib +import itertools import os import posixpath import time @@ -9,8 +11,10 @@ from enum import Enum from functools import partial from typing import ( + TYPE_CHECKING, Any, ClassVar, + ContextManager, Dict, Generic, Iterable, @@ -23,13 +27,15 @@ overload, ) +import fsspec from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem from pydantic import ValidationError, parse_obj_as, validator -from typing_extensions import Literal +from typing_extensions import Literal, TypeAlias from yaml import safe_dump, safe_load from mlem.config import project_config +from mlem.constants import MLEM_STATE_DIR, MLEM_STATE_EXT from mlem.core.artifacts import ( Artifacts, FSSpecStorage, @@ -40,9 +46,12 @@ from mlem.core.data_type import DataReader, DataType from mlem.core.errors import ( DeploymentError, + MlemError, MlemObjectNotFound, MlemObjectNotSavedError, MlemProjectNotFound, + WrongABCType, + WrongMetaSubType, WrongMetaType, ) from mlem.core.meta_io import MLEM_DIR, MLEM_EXT, Location, get_path_by_fs_path @@ -50,9 +59,19 @@ from mlem.core.requirements import Requirements from mlem.polydantic.lazy import lazy_field from mlem.ui import EMOJI_LINK, EMOJI_LOAD, EMOJI_SAVE, echo, no_echo +from mlem.utils.fslock import FSLock from mlem.utils.path import make_posix from mlem.utils.root import find_project_root +if TYPE_CHECKING: + from pydantic.typing import ( + AbstractSetIntStr, + MappingIntStrAny, + TupleGenerator, + ) + + from mlem.runtime.client import Client + T = TypeVar("T", bound="MlemObject") @@ -350,10 +369,16 @@ def meta_hash(self): return hashlib.md5(safe_dump(self.dict()).encode("utf8")).hexdigest() +TL = TypeVar("TL", bound="MlemLink") + + class MlemLink(MlemObject): """Link is a special MlemObject that represents a MlemObject in a different location""" + object_type: ClassVar = "link" + __link_type_map__: ClassVar[Dict[str, 
@classmethod
def typed_link(
    cls: Type["MlemLink"], type_: Union[str, Type[MlemObject]]
) -> Type["MlemLink"]:
    """Create a link class bound to a single object type.

    Args:
        type_: object type name, or a ``MlemObject`` subclass whose
            ``object_type`` is used.

    Returns:
        A new ``TypedLink`` subclass whose ``link_type`` is fixed to the
        given type and whose serialized form omits the ``type``,
        ``object_type`` and ``link_type`` fields.
    """
    type_name = type_ if isinstance(type_, str) else type_.object_type
    # Fields that are implied by the class itself and therefore never
    # emitted when the link is iterated/serialized.
    hidden_fields = ("type", "object_type", "link_type")

    class TypedMlemLink(TypedLink):
        object_type: ClassVar = f"link_{type_name}"
        _link_type: ClassVar = type_name
        link_type = type_name

        def _iter(
            self,
            to_dict: bool = False,
            by_alias: bool = False,
            include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
            exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
        ) -> "TupleGenerator":
            # Build a fresh container instead of updating the argument in
            # place, so the caller's ``exclude`` object is left untouched.
            # Any non-dict collection (set, frozenset, ...) is treated as
            # a plain set of field names.
            if isinstance(exclude, dict):
                exclude = {
                    **exclude,
                    **{name: True for name in hidden_fields},
                }
            else:
                exclude = set(exclude or ()) | set(hidden_fields)
            return super()._iter(
                to_dict,
                by_alias,
                include,
                exclude,
                exclude_unset,
                exclude_defaults,
                exclude_none,
            )

    TypedMlemLink.__doc__ = f"""Link to {type_name} MLEM object"""
    return TypedMlemLink
class StateManager(MlemABC):
    # Base interface for deployment state storage: subclasses provide the
    # actual load / save / purge / lock operations.
    abs_name: ClassVar = "state"
    type: ClassVar[str]

    class Config:
        type_root = True
        default_type = "fsspec"

    @abstractmethod
    def _get_state(
        self, deployment: "MlemDeployment"
    ) -> Optional[DeployState]:
        # Load the stored state for ``deployment``; implemented by subclasses.
        pass

    def get_state(
        self, deployment: "MlemDeployment", state_type: Type[ST]
    ) -> Optional[ST]:
        # Return the stored state after checking it is of ``state_type``;
        # a missing state (None) is passed through unchanged.
        loaded = self._get_state(deployment)
        if loaded is None or isinstance(loaded, state_type):
            return loaded
        raise DeploymentError(
            f"State for {deployment.name} is {loaded.type}, but should be {state_type.type}"
        )

    @abstractmethod
    def update_state(self, deployment: "MlemDeployment", state: DeployState):
        # Persist ``state`` for ``deployment``; implemented by subclasses.
        pass

    @abstractmethod
    def purge_state(self, deployment: "MlemDeployment"):
        # Remove any stored state of ``deployment``; implemented by subclasses.
        pass

    @abstractmethod
    def lock(self, deployment: "MlemDeployment") -> ContextManager:
        # Default used via ``super().lock(...)``: a context manager that
        # performs no locking at all.
        no_op_lock = _no_lock()
        return no_op_lock
class FSSpecStateManager(StateManager):
    """StateManager that stores state as yaml file in fsspec-supported filesystem"""

    type: ClassVar = "fsspec"

    class Config:
        exclude = {"fs", "path"}
        arbitrary_types_allowed = True

    uri: str
    """URI of directory to store state files"""
    storage_options: Dict = {}
    """Additional options"""
    locking: bool = True
    """Enable state locking"""
    lock_timeout: float = 10 * 60
    """Lock timeout"""

    fs: Optional[AbstractFileSystem] = None
    """Filesystem cache"""
    path: str = ""
    """Path inside filesystem cache"""

    def get_fs(self) -> AbstractFileSystem:
        # Resolve the filesystem and the in-filesystem path from ``uri``
        # once; later calls reuse the cached values.
        if self.fs is not None:
            return self.fs
        self.fs, _, (self.path,) = fsspec.get_fs_token_paths(
            self.uri, storage_options=self.storage_options
        )
        return self.fs

    def _get_path(self, deployment: "MlemDeployment"):
        # ``get_fs`` is called for its side effect of filling ``self.path``.
        self.get_fs()
        state_path = posixpath.join(
            self.path, MLEM_STATE_DIR, deployment.name
        )
        return state_path

    def _get_state(
        self, deployment: "MlemDeployment"
    ) -> Optional[DeployState]:
        # A missing state file simply means "no state yet".
        try:
            with self.get_fs().open(self._get_path(deployment)) as state_file:
                raw_state = safe_load(state_file)
                return parse_obj_as(DeployState, raw_state)
        except FileNotFoundError:
            return None

    def update_state(self, deployment: "MlemDeployment", state: DeployState):
        state_path = self._get_path(deployment)
        filesystem = self.get_fs()
        parent_dir = posixpath.dirname(state_path)
        filesystem.makedirs(parent_dir, exist_ok=True)
        with filesystem.open(state_path, "w") as state_file:
            safe_dump(state.dict(), state_file)

    def purge_state(self, deployment: "MlemDeployment"):
        state_path = self._get_path(deployment)
        filesystem = self.get_fs()
        if not filesystem.exists(state_path):
            return
        filesystem.delete(state_path)

    def lock(self, deployment: "MlemDeployment"):
        # Locking disabled: fall back to the base implementation.
        if not self.locking:
            return super().lock(deployment)
        lock_dir, lock_name = posixpath.split(self._get_path(deployment))
        return FSLock(
            self.get_fs(),
            lock_dir,
            lock_name,
            timeout=self.lock_timeout,
        )
def get_env(self) -> ET:
    """Return the environment of this deployment, loading it on first use.

    The result is cached in ``env_cache``. ``env`` may be a path string, a
    ``MlemEnv`` instance, a ``MlemLink`` or ``None``; for ``None`` a
    default ``env_type()`` instance is created.

    Raises:
        MlemError: ``env`` is ``None`` and ``env_type`` cannot be created
            without arguments.
        ValueError: ``env`` has an unsupported type.
        WrongMetaSubType: the resolved env is not an ``env_type`` instance.
    """
    if self.env_cache is None:
        env = self.env
        if isinstance(env, str):
            # Relative paths are resolved in the project/revision of the
            # deployment itself; absolute paths carry neither.
            is_relative = not os.path.isabs(env)
            env_link = MlemLink(
                path=env,
                project=self.loc.project if is_relative else None,
                rev=self.loc.rev if is_relative else None,
                link_type=MlemEnv.object_type,
            )
            self.env_cache = env_link.load_link(force_type=MlemEnv)
        elif isinstance(env, MlemEnv):
            self.env_cache = env
        elif isinstance(env, MlemLink):
            self.env_cache = env.load_link(force_type=MlemEnv)
        elif env is None:
            try:
                self.env_cache = self.env_type()
            except ValidationError as e:
                raise MlemError(
                    f"{self.env_type} env does not have default value, please set `env` field"
                ) from e
        else:
            raise ValueError(
                "env should be one of [str, MlemLink, MlemEnv]"
            )
    resolved = self.env_cache
    if not isinstance(resolved, self.env_type):
        raise WrongMetaSubType(resolved, self.env_type)
    return resolved
def update_model_hash(
    self,
    model: Optional[MlemModel] = None,
    state: Optional[ST] = None,
    update_state: bool = True,
):
    """Store the model's meta hash in the deployment state.

    Args:
        model: model to hash; when falsy, ``self.get_model()`` is used.
        state: state to modify; when falsy, ``self.get_state()`` is used.
        update_state: persist the modified state through
            ``self.update_state`` when true.
    """
    if not model:
        model = self.get_model()
    if not state:
        state = self.get_state()
    state.model_hash = model.meta_hash()
    if not update_state:
        return
    self.update_state(state)
diff --git a/mlem/ext.py b/mlem/ext.py index 1aecf256..31150828 100644 --- a/mlem/ext.py +++ b/mlem/ext.py @@ -108,6 +108,7 @@ class ExtensionLoader: Extension("mlem.contrib.github", [], True), Extension("mlem.contrib.gitlabfs", [], True), Extension("mlem.contrib.bitbucketfs", [], True), + Extension("mlem.contrib.sagemaker", ["sagemaker", "boto3"], False), ) _loaded_extensions: Dict[Extension, ModuleType] = {} diff --git a/mlem/polydantic/core.py b/mlem/polydantic/core.py index 6f2b8910..c5b86ec9 100644 --- a/mlem/polydantic/core.py +++ b/mlem/polydantic/core.py @@ -71,6 +71,8 @@ def validate(cls, value): return super().validate(value) if isinstance(value, str): value = {cls.__config__.type_field: value} + if not isinstance(value, dict): + raise ValueError(f"{value} is neither dict nor {cls}") value = value.copy() type_name = value.pop( cls.__config__.type_field, cls.__config__.default_type @@ -108,15 +110,21 @@ def _iter( exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) + exclude = exclude or set() if self.__is_root__: alias = self.__get_alias__(self.__config__.type_field) - if not exclude_defaults or alias != self.__config__.default_type: + if ( + not exclude_defaults or alias != self.__config__.default_type + ) and self.__config__.type_field not in exclude: yield self.__config__.type_field, alias for parent in self.__iter_parents__(include_top=False): alias = parent.__get_alias__() - if not exclude_defaults or alias != parent.__config__.default_type: - yield parent.__type_field__(), alias + parent_type_field = parent.__type_field__() + if ( + not exclude_defaults or alias != parent.__config__.default_type + ) and parent_type_field not in exclude: + yield parent_type_field, alias def __iter__(self): """Add alias field""" diff --git a/mlem/ui.py b/mlem/ui.py index e66aa010..a42fb100 100644 --- a/mlem/ui.py +++ b/mlem/ui.py @@ -100,3 +100,4 @@ def bold(text): EMOJI_BUILD = emoji("🛠") EMOJI_UPLOAD = emoji("🔼") EMOJI_STOP = emoji("🔻") +EMOJI_KEY = 
class FSLock:
    """Advisory lock built from marker files on an fsspec filesystem.

    Every contender creates a file named
    ``<name>.<timestamp>.<salt>.lock`` inside ``dirpath``. The lock is held
    by the contender whose ``(timestamp, salt)`` pair is the smallest among
    all marker files present; the others poll until their own pair becomes
    the smallest or the timeout expires.

    Args:
        fs: filesystem on which marker files are created.
        dirpath: directory holding the marker files.
        name: lock name (prefix of the marker file names).
        timeout: seconds to wait for the lock; ``None`` waits forever.
        retry_timeout: seconds to sleep between checks.
        salt: fixed tie-breaker value; generated randomly when omitted.
    """

    def __init__(
        self,
        fs: AbstractFileSystem,
        dirpath: str,
        name: str,
        timeout: float = None,
        retry_timeout: float = 0.1,
        *,
        salt=None,
    ):
        self.fs = fs
        self.dirpath = make_posix(str(dirpath))
        self.name = name
        self.timeout = timeout
        self.retry_timeout = retry_timeout
        self._salt = salt
        self._timestamp = None

    @property
    def salt(self):
        """Tie-breaker for equal timestamps, generated lazily."""
        if self._salt is None:
            self._salt = random.randint(10**3, 10**4)
        return self._salt

    @property
    def timestamp(self):
        """Nanosecond timestamp of this lock attempt, generated lazily."""
        if self._timestamp is None:
            self._timestamp = time.time_ns()
        return self._timestamp

    @property
    def lock_filename(self):
        """File name of this contender's marker file."""
        return f"{self.name}.{self.timestamp}.{self.salt}.{LOCK_EXT}"

    @property
    def lock_path(self):
        """Full path of this contender's marker file."""
        return posixpath.join(self.dirpath, self.lock_filename)

    def _list_locks(self) -> List[Tuple[int, int]]:
        """Return ``(timestamp, salt)`` of every marker file for this name."""
        # The whole remainder after the lock name has to look like
        # ``.<digits>.<digits>.lock``; partial matches are ignored.
        pattern = re.compile(rf"\.(\d+)\.(\d+)\.{re.escape(LOCK_EXT)}")
        found = []
        for entry in self.fs.listdir(self.dirpath, detail=False):
            filename = posixpath.basename(make_posix(entry))
            if not filename.startswith(self.name):
                continue
            matched = pattern.fullmatch(filename[len(self.name) :])
            if matched is not None:
                found.append((int(matched.group(1)), int(matched.group(2))))
        return found

    def _double_check(self):
        """Tell whether this contender currently owns the lock."""
        locks = self._list_locks()
        if not locks:
            return False
        return min(locks) == (self._timestamp, self._salt)

    def _write_lockfile(self):
        """Create this contender's marker file."""
        # The directory may not exist yet when the lock is taken before
        # anything was ever written next to it.
        if self.dirpath:
            self.fs.makedirs(self.dirpath, exist_ok=True)
        self.fs.touch(self.lock_path)

    def _clear(self):
        """Forget timestamp and salt so the next attempt gets fresh ones."""
        self._timestamp = None
        self._salt = None

    def _delete_lockfile(self):
        """Remove this contender's marker file; a missing file is fine."""
        try:
            self.fs.delete(self.lock_path)
        except FileNotFoundError:
            pass

    def __enter__(self):
        """Acquire the lock, waiting up to ``timeout`` seconds.

        Returns:
            The lock itself, so ``with FSLock(...) as lock`` is usable.

        Raises:
            LockTimeoutError: the lock was not obtained in time.
        """
        # Monotonic clock: the wait must not be affected by wall-clock jumps.
        start = time.monotonic()

        self._write_lockfile()
        try:
            time.sleep(self.retry_timeout)

            while not self._double_check():
                if (
                    self.timeout is not None
                    and time.monotonic() - start > self.timeout
                ):
                    raise LockTimeoutError(
                        f"Lock aquiring timeouted after {self.timeout}"
                    )
                time.sleep(self.retry_timeout)
        except BaseException:
            # Whatever interrupted the wait (timeout, filesystem error,
            # Ctrl-C), do not leave a marker file behind: a stale marker
            # with an old timestamp would block every later contender.
            self._delete_lockfile()
            self._clear()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the lock by deleting the marker file."""
        self._delete_lockfile()
        self._clear()
mlem.contrib.rabbitmq:RabbitMQServer", + "docker_registry.ecr = mlem.contrib.sagemaker.build:ECRegistry", + "client.sagemaker = mlem.contrib.sagemaker.meta:SagemakerClient", + "deploy_state.sagemaker = mlem.contrib.sagemaker.meta:SagemakerDeployState", + "deployment.sagemaker = mlem.contrib.sagemaker.meta:SagemakerDeployment", + "env.sagemaker = mlem.contrib.sagemaker.meta:SagemakerEnv", + "server._sagemaker = mlem.contrib.sagemaker.runtime:SageMakerServer", "model_type.sklearn = mlem.contrib.sklearn:SklearnModel", "model_type.sklearn_pipeline = mlem.contrib.sklearn:SklearnPipelineType", "model_type.tf_keras = mlem.contrib.tensorflow:TFKerasModel", @@ -197,6 +202,7 @@ "data_writer.tf_tensor = mlem.contrib.tensorflow:TFTensorWriter", "model_type.torch = mlem.contrib.torch:TorchModel", "model_io.torch_io = mlem.contrib.torch:TorchModelIO", + "import.torch = mlem.contrib.torch:TorchModelImport", "data_type.torch = mlem.contrib.torch:TorchTensorDataType", "data_reader.torch = mlem.contrib.torch:TorchTensorReader", "data_writer.torch = mlem.contrib.torch:TorchTensorWriter", @@ -210,6 +216,8 @@ "docker = mlem.contrib.docker.context:DockerConfig", "heroku = mlem.contrib.heroku.config:HerokuConfig", "pandas = mlem.contrib.pandas:PandasConfig", + "aws = mlem.contrib.sagemaker.meta:AWSConfig", + "sagemaker = mlem.contrib.sagemaker.runtime:SageMakerServerConfig", ], }, zip_safe=False, diff --git a/tests/cli/test_deployment.py b/tests/cli/test_deployment.py index a719e1e3..587041e3 100644 --- a/tests/cli/test_deployment.py +++ b/tests/cli/test_deployment.py @@ -3,6 +3,7 @@ import pytest from numpy import ndarray +from yaml import safe_load from mlem.api import load from mlem.core.meta_io import MLEM_EXT @@ -15,22 +16,14 @@ MlemLink, ) from mlem.runtime.client import Client, HTTPClient +from mlem.utils.path import make_posix from tests.cli.conftest import Runner -@pytest.fixture -def mock_deploy_get_client(mocker, request_get_mock, request_post_mock): - return mocker.patch( - 
"tests.cli.test_deployment.DeployStateMock.get_client", - return_value=HTTPClient(host="", port=None), - ) - - class DeployStateMock(DeployState): """mock""" - def get_client(self) -> Client: - pass + allow_default: ClassVar = True class MlemDeploymentMock(MlemDeployment): @@ -40,12 +33,15 @@ class Config: use_enum_values = True type: ClassVar = "mock" + state_type: ClassVar = DeployStateMock + status: DeployStatus = DeployStatus.NOT_DEPLOYED """status""" param: str = "" """param""" - state: DeployState = DeployStateMock() - """state""" + + def _get_client(self, state) -> Client: + return HTTPClient(host="", port=None) class MlemEnvMock(MlemEnv): @@ -80,12 +76,137 @@ def mock_deploy_path(tmp_path, mock_env_path, model_meta_saved_single): path = os.path.join(tmp_path, "deployname") MlemDeploymentMock( param="bbb", - model_link=model_meta_saved_single.make_link(), - env_link=MlemLink(path=mock_env_path, link_type="env"), + model=model_meta_saved_single.make_link(), + model_cache=model_meta_saved_single, + env=mock_env_path, ).dump(path) return path +def test_deploy_meta_str_model(mlem_project, model_meta, mock_env_path): + model_meta.dump("model", project=mlem_project) + + deployment = MlemDeploymentMock(model="model", env=mock_env_path) + deployment.dump("deployment", project=mlem_project) + + with deployment.loc.open("r") as f: + data = safe_load(f) + assert data == { + "model": "model", + "object_type": "deployment", + "type": "mock", + "env": make_posix(mock_env_path), + } + + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment + ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_deploy_meta_link_str_model(mlem_project, model_meta, mock_env_path): + model_meta.dump("model", project=mlem_project) + + deployment = MlemDeploymentMock( + model=MlemLink(path="model", link_type="model"), + env=MlemLink(path=mock_env_path, 
link_type="env"), + ) + deployment.dump("deployment", project=mlem_project) + + with deployment.loc.open("r") as f: + data = safe_load(f) + assert data == { + "model": "model", + "object_type": "deployment", + "type": "mock", + "env": make_posix(mock_env_path), + } + + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment + ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_deploy_meta_link_model(mlem_project, model_meta, mock_env_path): + model_meta.dump("model", project=mlem_project) + load_meta(mock_env_path).clone("project_env", project=mlem_project) + + deployment = MlemDeploymentMock( + model=MlemLink(path="model", project=mlem_project, link_type="model"), + env=MlemLink( + path="project_env", project=mlem_project, link_type="env" + ), + ) + deployment.dump("deployment", project=mlem_project) + + with deployment.loc.open("r") as f: + data = safe_load(f) + assert data == { + "model": {"path": "model", "project": make_posix(mlem_project)}, + "object_type": "deployment", + "type": "mock", + "env": { + "path": "project_env", + "project": make_posix(mlem_project), + }, + } + + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment + ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_deploy_meta_link_model_no_project(tmpdir, model_meta, mock_env_path): + model_path = os.path.join(tmpdir, "model") + model_meta.dump(model_path) + + deployment = MlemDeploymentMock( + model=MlemLink(path="model", link_type="model"), + env=MlemLink(path=mock_env_path, link_type="env"), + ) + deployment_path = os.path.join(tmpdir, "deployment") + deployment.dump(deployment_path) + + with deployment.loc.open("r") as f: + data = safe_load(f) + assert data == { + "model": "model", + "object_type": "deployment", + 
"type": "mock", + "env": make_posix(mock_env_path), + } + + deployment2 = load_meta(deployment_path, force_type=MlemDeployment) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_read_relative_model_from_remote_deploy_meta(): + """TODO + path = "s3://..." + model.dump(path / "model"); + deployment = MlemDeploymentMock( + model=model, + env=MlemLink( + path=mock_env_path, link_type="env" + ), + ) + deployment.dump(path / deployment) + + deployment2 = load_meta(...) + deployment2.get_model() + """ + + def test_deploy_create_new( runner: Runner, model_meta_saved_single, mock_env_path, tmp_path ): @@ -130,14 +251,15 @@ def test_deploy_apply( runner: Runner, mock_deploy_path, data_path, - mock_deploy_get_client, tmp_path, + request_get_mock, + request_post_mock, ): path = os.path.join(tmp_path, "output") result = runner.invoke( f"deploy apply {mock_deploy_path} {data_path} -o {path}".split() ) - assert result.exit_code == 0, result.output + assert result.exit_code == 0, (result.output, result.exception) meta = load_meta(mock_deploy_path) assert isinstance(meta, MlemDeploymentMock) assert meta.status == DeployStatus.NOT_DEPLOYED diff --git a/tests/contrib/test_docker/resources/dockerfile.j2 b/tests/contrib/test_docker/resources/dockerfile.j2 new file mode 100644 index 00000000..e6a7521f --- /dev/null +++ b/tests/contrib/test_docker/resources/dockerfile.j2 @@ -0,0 +1,3 @@ +FROM alpine + +CMD sleep infinity diff --git a/tests/contrib/test_docker/test_deploy.py b/tests/contrib/test_docker/test_deploy.py index 1ebc37fb..d36a3bf1 100644 --- a/tests/contrib/test_docker/test_deploy.py +++ b/tests/contrib/test_docker/test_deploy.py @@ -6,15 +6,18 @@ import pytest from requests.exceptions import HTTPError +from mlem.api import deploy from mlem.contrib.docker.base import ( DockerContainer, DockerContainerState, DockerEnv, DockerImage, ) +from mlem.contrib.docker.context import 
DockerBuildArgs from mlem.contrib.fastapi import FastAPIServer from mlem.core.errors import DeploymentError from mlem.core.objects import DeployStatus +from tests.conftest import resource_path from tests.contrib.test_docker.conftest import docker_test IMAGE_NAME = "mike0sv/ebaklya" @@ -26,7 +29,15 @@ @pytest.fixture(scope="session") -def _test_images(tmpdir_factory, dockerenv_local, dockerenv_remote): +def _test_images(dockerenv_local): + with dockerenv_local.daemon.client() as client: + client.images.pull(IMAGE_NAME, "latest") + + +@pytest.fixture(scope="session") +def _test_images_remote( + tmpdir_factory, dockerenv_local, dockerenv_remote, _test_images +): with dockerenv_local.daemon.client() as client: tag_name = f"{dockerenv_remote.registry.get_host()}/{REPOSITORY_NAME}/{IMAGE_NAME}" client.images.pull(IMAGE_NAME, "latest").tag(tag_name) @@ -57,7 +68,7 @@ def test_run_default_registry( @docker_test def test_run_remote_registry( - dockerenv_remote, _test_images, model_meta_saved_single + dockerenv_remote, _test_images_remote, model_meta_saved_single ): _check_runner(IMAGE_NAME, dockerenv_remote, model_meta_saved_single) @@ -76,7 +87,7 @@ def test_run_local_image_name_that_will_never_exist( @docker_test def test_run_local_fail_inside_container( - dockerenv_remote, _test_images, model_meta_saved_single + dockerenv_remote, _test_images_remote, model_meta_saved_single ): with pytest.raises(DeploymentError): _check_runner( @@ -86,19 +97,47 @@ def test_run_local_fail_inside_container( ) +@docker_test +def test_deploy_full( + tmp_path_factory, dockerenv_local, model_meta_saved_single +): + meta_path = tmp_path_factory.mktemp("deploy-meta") + meta = deploy( + str(meta_path), + model_meta_saved_single, + dockerenv_local, + args=DockerBuildArgs(templates_dir=[resource_path(__file__)]), + server="fastapi", + container_name="test_full_deploy", + ) + + meta.wait_for_status( + DeployStatus.RUNNING, + allowed_intermediate=[ + DeployStatus.NOT_DEPLOYED, + 
DeployStatus.STARTING, + ], + times=50, + ) + assert meta.get_status() == DeployStatus.RUNNING + + def _check_runner(img, env: DockerEnv, model): with tempfile.TemporaryDirectory() as tmpdir: instance = DockerContainer( container_name=CONTAINER_NAME, port_mapping={80: 8008}, - state=DockerContainerState(image=DockerImage(name=img)), server=FastAPIServer(), - model_link=model.make_link(), - env_link=env.make_link(), + model=model.make_link(), + env=env, rm=False, ) - instance.update_model_hash(model) instance.dump(os.path.join(tmpdir, "deploy")) + instance.update_state( + DockerContainerState( + image=DockerImage(name=img), model_hash=model.meta_hash() + ) + ) assert env.get_status(instance) == DeployStatus.NOT_DEPLOYED env.deploy(instance) diff --git a/tests/contrib/test_heroku.py b/tests/contrib/test_heroku.py index 8558b0fc..aff9ad57 100644 --- a/tests/contrib/test_heroku.py +++ b/tests/contrib/test_heroku.py @@ -90,8 +90,8 @@ def test_create_app(heroku_app_name, heroku_env, model): name = heroku_app_name("create-app") heroku_deploy = HerokuDeployment( app_name=name, - env_link=heroku_env.make_link(), - model_link=model.make_link(), + env=heroku_env, + model=model.make_link(), team=HEROKU_TEAM, ) create_app(heroku_deploy) @@ -120,7 +120,8 @@ def test_state_ensured_app(): def _check_heroku_deployment(meta): assert isinstance(meta, HerokuDeployment) - assert heroku_api_request("GET", f"/apps/{meta.state.ensured_app.name}") + state = meta.get_state() + assert heroku_api_request("GET", f"/apps/{state.ensured_app.name}") meta.wait_for_status( DeployStatus.RUNNING, allowed_intermediate=[ @@ -132,7 +133,7 @@ def _check_heroku_deployment(meta): assert meta.get_status() == DeployStatus.RUNNING time.sleep(10) docs_page = requests.post( - meta.state.ensured_app.web_url + "predict", + state.ensured_app.web_url + "predict", json={ "data": { "values": [ @@ -159,7 +160,7 @@ def is_not_crash(err, *args): # pylint: disable=unused-argument return not needs_another_try 
-@flaky(rerun_filter=is_not_crash, max_runs=2) +@flaky(rerun_filter=is_not_crash, max_runs=1) @heroku @long @heroku_matrix @@ -186,7 +187,7 @@ def test_env_deploy_full( if CLEAR_APPS: meta.remove() - assert meta.state is None + assert meta.get_state() == HerokuState() meta.wait_for_status( DeployStatus.NOT_DEPLOYED, allowed_intermediate=DeployStatus.RUNNING, diff --git a/tests/core/test_objects.py b/tests/core/test_objects.py index 2576d9eb..e96f6176 100644 --- a/tests/core/test_objects.py +++ b/tests/core/test_objects.py @@ -20,6 +20,7 @@ MlemLink, MlemModel, MlemObject, + ModelLink, ) from mlem.core.requirements import InstallableRequirement, Requirements from tests.conftest import ( @@ -43,16 +44,17 @@ def get_status(self): def destroy(self): pass - def get_client(self): + +class MyMlemDeployment(MlemDeployment): + def _get_client(self, state): pass @pytest.fixture() def meta(): - return MlemDeployment( - env_link=MlemLink(path="", link_type="env"), - model_link=MlemLink(path="", link_type="model"), - state=MyDeployState(), + return MyMlemDeployment( + env="", + model=MlemLink(path="", link_type="model"), ) @@ -337,6 +339,13 @@ def test_double_link_load(filled_mlem_project): assert isinstance(model, MlemModel) +def test_typed_link(): + link = ModelLink(path="aaa") + assert link.dict() == {"path": "aaa"} + + assert parse_obj_as(ModelLink, {"path": "aaa"}) == link + + @long @need_test_repo_auth def test_load_link_from_rev(): diff --git a/tests/core/test_requirements.py b/tests/core/test_requirements.py index 15c5a094..e88c1b4a 100644 --- a/tests/core/test_requirements.py +++ b/tests/core/test_requirements.py @@ -143,6 +143,12 @@ def test_req_collection_main(tmpdir, postfix): } +def test_consistent_resolve_order(): + reqs = ["a", "b", "c"] + for _ in range(10): + assert resolve_requirements(reqs).modules == reqs + + # Copyright 2019 Zyfra # Copyright 2021 Iterative # diff --git a/tests/test_config.py b/tests/test_config.py index 12d386b6..4a6c4e83 100644 --- 
a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,7 @@ from mlem.config import CONFIG_FILE_NAME, MlemConfig, project_config from mlem.constants import MLEM_DIR +from mlem.contrib.fastapi import FastAPIServer from mlem.core.artifacts import FSSpecStorage, LocalStorage from mlem.core.meta_io import get_fs from tests.conftest import long @@ -28,3 +29,7 @@ def test_loading_remote(s3_tmp_path, s3_storage_fs): with fs.open(path, "w") as f: f.write("core:\n ADDITIONAL_EXTENSIONS: ext1\n") assert project_config(path, fs=fs).additional_extensions == ["ext1"] + + +def test_default_server(): + assert project_config("").server == FastAPIServer() diff --git a/tests/test_ext.py b/tests/test_ext.py index 690f68ef..11968a4c 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -1,3 +1,6 @@ +import re +from pathlib import Path + from mlem import ExtensionLoader from mlem.config import MlemConfig, MlemConfigBase from mlem.utils.entrypoints import ( @@ -24,6 +27,19 @@ def test_find_implementations(): assert not i.startswith("None") +def _write_entrypoints(impls_sorted, section: str): + setup_path = Path(__file__).parent.parent / "setup.py" + with open(setup_path, encoding="utf8") as f: + setup_py = f.read() + impls_string = ",\n".join(f' "{i}"' for i in impls_sorted) + new_entrypoints = f'"{section}": [\n{impls_string},\n ]' + setup_py = re.subn(rf'"{section}": \[\n[^]]*]', new_entrypoints, setup_py)[ + 0 + ] + with open(setup_path, "w", encoding="utf8") as f: + f.write(setup_py) + + def test_all_impls_in_entrypoints(): # if this test fails, add new entrypoints (take the result of find_implementations()) to setup.py and # reinstall your dev copy of mlem to re-populate them @@ -33,15 +49,27 @@ def test_all_impls_in_entrypoints(): impls_sorted = sorted( impls, key=lambda x: tuple(x.split(" = ")[1].split(":")) ) - assert exts == set(impls), str(impls_sorted) + impls_set = set(impls) + if exts != impls_set: + _write_entrypoints(impls_sorted, "mlem.contrib") + assert ( + 
exts == impls_set + ), "New enrtypoints written to setup.py, please reinstall" def test_all_configs_in_entrypoints(): impls = find_implementations(MlemConfigBase) impls[MlemConfig] = f"{MlemConfig.__module__}:{MlemConfig.__name__}" - assert { + impls_sorted = sorted( + {f"{i.__config__.section} = {k}" for i, k in impls.items()}, + key=lambda x: tuple(x.split(" = ")[1].split(":")), + ) + exts = { e.entry for e in load_entrypoints(MLEM_CONFIG_ENTRY_POINT).values() - } == {f"{i.__config__.section} = {k}" for i, k in impls.items()} + } + if exts != set(impls_sorted): + _write_entrypoints(impls_sorted, "mlem.config") + assert exts == impls_sorted def test_all_ext_has_pip_extra(): diff --git a/tests/utils/test_fslock.py b/tests/utils/test_fslock.py new file mode 100644 index 00000000..3f93bec9 --- /dev/null +++ b/tests/utils/test_fslock.py @@ -0,0 +1,62 @@ +import os +import time +from threading import Thread + +from fsspec.implementations.local import LocalFileSystem + +from mlem.utils.fslock import LOCK_EXT, FSLock +from mlem.utils.path import make_posix + +NAME = "testlock" + + +# pylint: disable=protected-access +def test_fslock(tmpdir): + fs = LocalFileSystem() + lock = FSLock(fs, tmpdir, NAME) + + with lock: + assert lock._timestamp is not None + assert lock._salt is not None + lock_path = make_posix( + os.path.join( + tmpdir, f"{NAME}.{lock._timestamp}.{lock._salt}.{LOCK_EXT}" + ) + ) + assert lock.lock_path == lock_path + assert fs.exists(lock_path) + + assert lock._timestamp is None + assert lock._salt is None + assert not fs.exists(lock_path) + + +def _work(dirname, num): + time.sleep(0.3 + num / 5) + with FSLock(LocalFileSystem(), dirname, NAME, salt=num): + path = os.path.join(dirname, NAME) + if os.path.exists(path): + with open(path, "r+", encoding="utf8") as f: + data = f.read() + else: + data = "" + time.sleep(0.05) + with open(path, "w", encoding="utf8") as f: + f.write(data + f"{num}\n") + + +def test_fslock_concurrent(tmpdir): + start = 0 + end = 10 + 
threads = [ + Thread(target=_work, args=(tmpdir, n)) for n in range(start, end) + ] + for t in threads: + t.start() + for t in threads: + t.join() + with open(os.path.join(tmpdir, NAME), encoding="utf8") as f: + data = f.read() + + assert data.splitlines() == [str(i) for i in range(start, end)] + assert os.listdir(tmpdir) == [NAME] From 33031e04c5959929bd4f28b32c7c1ee36b9c879c Mon Sep 17 00:00:00 2001 From: Madhur Tandon <20173739+madhur-tandon@users.noreply.github.com> Date: Thu, 15 Sep 2022 13:13:33 +0530 Subject: [PATCH 3/4] add support for deployment to K8s (#374) * fix tests * Sagemaker deployments (#366) * WIP * its alive (kinda) * it works but it's ugly * little less ugly * lil fix * fix lint * fix lint * fix tests * fix tests * fix windows bugs * fix tests * fix tests * test that all configs in entrypoints * fix short tests * wip kubernetes support * use APIs to deploy and get status, deletion still pending * remove get client from state * fix param * fix jinja template * working remove and status * fix client * small fixes * attempt to add tests * setup github actions for k8s tests * fix linter * use predict method of client * allow registry to be configurable by cli * change calculation of host and port according to service type * re-enable k8s test as new workflow * fix daemon access in tests * make linter happy * fix fixtures * suggested fixes and refactor * make namespace as a separate field and use enums * use watcher to figure out when resources are deleted * check minikube status before loading kubeconfig in fixture * minor suggestions * use enums for comparisons as well * create abstract class for services for host and port info * raise error when service of type clusterIP * fix build and use tag as model hash * fix echo message * hot swapping of docker image deployed * remove unnecessary f-string * skip swapping when same hash is tried to be deployed again * suggested improvements * fix lint * fix pylint * suggested improvements * fix pylint * 
update entrypoints * add docstrings for K8sYamlBuildArgs * add docstrings for k8s service type classes * capitalize docstrings for fields * remove service type enum * Remove new workflow for K8s * remove duplicate methods * remove version from iterative-telemetry Co-authored-by: mike0sv --- .github/workflows/check-test-release.yml | 3 + mlem/contrib/kubernetes/__init__.py | 0 mlem/contrib/kubernetes/base.py | 219 ++++++++++++++++++ mlem/contrib/kubernetes/build.py | 30 +++ mlem/contrib/kubernetes/context.py | 55 +++++ mlem/contrib/kubernetes/resources.yaml.j2 | 47 ++++ mlem/contrib/kubernetes/service.py | 115 +++++++++ mlem/contrib/kubernetes/utils.py | 80 +++++++ mlem/core/errors.py | 4 + setup.cfg | 1 + setup.py | 8 + tests/conftest.py | 4 + tests/contrib/test_docker/test_context.py | 5 +- tests/contrib/test_kubernetes/__init__.py | 0 tests/contrib/test_kubernetes/conftest.py | 46 ++++ tests/contrib/test_kubernetes/test_base.py | 131 +++++++++++ tests/contrib/test_kubernetes/test_context.py | 150 ++++++++++++ tests/contrib/test_kubernetes/utils.py | 34 +++ 18 files changed, 928 insertions(+), 4 deletions(-) create mode 100644 mlem/contrib/kubernetes/__init__.py create mode 100644 mlem/contrib/kubernetes/base.py create mode 100644 mlem/contrib/kubernetes/build.py create mode 100644 mlem/contrib/kubernetes/context.py create mode 100644 mlem/contrib/kubernetes/resources.yaml.j2 create mode 100644 mlem/contrib/kubernetes/service.py create mode 100644 mlem/contrib/kubernetes/utils.py create mode 100644 tests/contrib/test_kubernetes/__init__.py create mode 100644 tests/contrib/test_kubernetes/conftest.py create mode 100644 tests/contrib/test_kubernetes/test_base.py create mode 100644 tests/contrib/test_kubernetes/test_context.py create mode 100644 tests/contrib/test_kubernetes/utils.py diff --git a/.github/workflows/check-test-release.yml b/.github/workflows/check-test-release.yml index 0baf5058..ff8bb36e 100644 --- a/.github/workflows/check-test-release.yml +++ 
b/.github/workflows/check-test-release.yml @@ -92,6 +92,9 @@ jobs: pip install pre-commit .[tests] - run: pre-commit run pylint -a -v --show-diff-on-failure if: matrix.python != '3.7' + - name: Start minikube + if: matrix.os == 'ubuntu-latest' && matrix.python == '3.9' + uses: medyagh/setup-minikube@master - name: Run tests timeout-minutes: 40 run: pytest diff --git a/mlem/contrib/kubernetes/__init__.py b/mlem/contrib/kubernetes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mlem/contrib/kubernetes/base.py b/mlem/contrib/kubernetes/base.py new file mode 100644 index 00000000..af5c3279 --- /dev/null +++ b/mlem/contrib/kubernetes/base.py @@ -0,0 +1,219 @@ +import os +from typing import ClassVar, List, Optional + +from kubernetes import client, config + +from mlem.config import project_config +from mlem.core.errors import DeploymentError, EndpointNotFound, MlemError +from mlem.core.objects import ( + DeployState, + DeployStatus, + MlemBuilder, + MlemDeployment, + MlemEnv, + MlemModel, +) +from mlem.runtime.client import Client, HTTPClient +from mlem.runtime.server import Server +from mlem.ui import EMOJI_OK, echo + +from ..docker.base import ( + DockerDaemon, + DockerImage, + DockerRegistry, + generate_docker_container_name, +) +from .build import build_k8s_docker +from .context import K8sYamlBuildArgs, K8sYamlGenerator +from .utils import create_k8s_resources, namespace_deleted, pod_is_running + +POD_STATE_MAPPING = { + "Pending": DeployStatus.STARTING, + "Running": DeployStatus.RUNNING, + "Succeeded": DeployStatus.STOPPED, + "Failed": DeployStatus.CRASHED, + "Unknown": DeployStatus.UNKNOWN, +} + + +class K8sDeploymentState(DeployState): + """DeployState implementation for Kubernetes deployments""" + + type: ClassVar = "kubernetes" + + image: Optional[DockerImage] = None + """Docker Image being used for Deployment""" + deployment_name: Optional[str] = None + """Name of Deployment""" + + +class K8sDeployment(MlemDeployment, K8sYamlBuildArgs): + 
"""MlemDeployment implementation for Kubernetes deployments""" + + type: ClassVar = "kubernetes" + state_type: ClassVar = K8sDeploymentState + """Type of state for Kubernetes deployments""" + + server: Optional[Server] = None + """Type of Server to use, with options such as FastAPI, RabbitMQ etc.""" + registry: Optional[DockerRegistry] = DockerRegistry() + """Docker registry""" + daemon: Optional[DockerDaemon] = DockerDaemon(host="") + """Docker daemon""" + kube_config_file_path: Optional[str] = None + """Path for kube config file of the cluster""" + templates_dir: List[str] = [] + """List of dirs where templates reside""" + + def load_kube_config(self): + config.load_kube_config( + config_file=self.kube_config_file_path + or os.getenv("KUBECONFIG", default="~/.kube/config") + ) + + def _get_client(self, state: K8sDeploymentState) -> Client: + host, port = None, None + self.load_kube_config() + service = client.CoreV1Api().list_namespaced_service(self.namespace) + try: + host, port = self.service_type.get_host_and_port( + service, self.namespace + ) + except MlemError as e: + raise EndpointNotFound( + "Couldn't determine host and port from the service deployed" + ) from e + if host is not None and port is not None: + return HTTPClient(host=host, port=port) + raise MlemError( + f"host and port determined are not valid, received host as {host} and port as {port}" + ) + + +class K8sEnv(MlemEnv[K8sDeployment]): + """MlemEnv implementation for Kubernetes Environments""" + + type: ClassVar = "kubernetes" + deploy_type: ClassVar = K8sDeployment + """Type of deployment being used for the Kubernetes environment""" + + registry: Optional[DockerRegistry] = None + """Docker registry""" + templates_dir: List[str] = [] + """List of dirs where templates reside""" + + def get_registry(self, meta: K8sDeployment): + registry = meta.registry or self.registry + if not registry: + raise MlemError( + "registry to be used by Docker is not set or supplied" + ) + return registry + + def 
get_image_name(self, meta: K8sDeployment): + return meta.image_name or generate_docker_container_name() + + def get_server(self, meta: K8sDeployment): + return ( + meta.server + or project_config( + meta.loc.project if meta.is_saved else None + ).server + ) + + def deploy(self, meta: K8sDeployment): + self.check_type(meta) + redeploy = False + with meta.lock_state(): + meta.load_kube_config() + state: K8sDeploymentState = meta.get_state() + if state.image is None or meta.model_changed(): + image_name = self.get_image_name(meta) + state.image = build_k8s_docker( + meta=meta.get_model(), + image_name=image_name, + registry=self.get_registry(meta), + daemon=meta.daemon, + server=self.get_server(meta), + ) + meta.update_model_hash(state=state) + redeploy = True + + if ( + state.deployment_name is None or redeploy + ) and state.image is not None: + generator = K8sYamlGenerator( + namespace=meta.namespace, + image_name=state.image.name, + image_uri=state.image.uri, + image_pull_policy=meta.image_pull_policy, + port=meta.port, + service_type=meta.service_type, + templates_dir=meta.templates_dir or self.templates_dir, + ) + create_k8s_resources(generator) + + if pod_is_running(namespace=meta.namespace): + deployments_list = ( + client.AppsV1Api().list_namespaced_deployment( + namespace=meta.namespace + ) + ) + + if len(deployments_list.items) == 0: + raise DeploymentError( + f"Deployment {image_name} couldn't be found in {meta.namespace} namespace" + ) + dpl_name = deployments_list.items[0].metadata.name + state.deployment_name = dpl_name + meta.update_state(state) + + echo( + EMOJI_OK + + f"Deployment {state.deployment_name} is up in {meta.namespace} namespace" + ) + else: + raise DeploymentError( + f"Deployment {image_name} couldn't be set-up on the Kubernetes cluster" + ) + + def remove(self, meta: K8sDeployment): + self.check_type(meta) + with meta.lock_state(): + meta.load_kube_config() + state: K8sDeploymentState = meta.get_state() + if state.deployment_name is not 
None: + client.CoreV1Api().delete_namespace(name=meta.namespace) + if namespace_deleted(meta.namespace): + echo( + EMOJI_OK + + f"Deployment {state.deployment_name} and the corresponding service are removed from {meta.namespace} namespace" + ) + state.deployment_name = None + meta.update_state(state) + + def get_status( + self, meta: K8sDeployment, raise_on_error=True + ) -> DeployStatus: + self.check_type(meta) + meta.load_kube_config() + state: K8sDeploymentState = meta.get_state() + if state.deployment_name is None: + return DeployStatus.NOT_DEPLOYED + + pods_list = client.CoreV1Api().list_namespaced_pod(meta.namespace) + + return POD_STATE_MAPPING[pods_list.items[0].status.phase] + + +class K8sYamlBuilder(MlemBuilder, K8sYamlGenerator): + """MlemBuilder implementation for building Kubernetes manifests/yamls""" + + type: ClassVar = "kubernetes" + + target: str + """Target path for the manifest/yaml""" + + def build(self, obj: MlemModel): + self.write(self.target) + echo(EMOJI_OK + f"{self.target} generated for {obj.basename}") diff --git a/mlem/contrib/kubernetes/build.py b/mlem/contrib/kubernetes/build.py new file mode 100644 index 00000000..b5a98c26 --- /dev/null +++ b/mlem/contrib/kubernetes/build.py @@ -0,0 +1,30 @@ +from typing import Optional + +from mlem.core.objects import MlemModel +from mlem.runtime.server import Server +from mlem.ui import EMOJI_BUILD, echo, set_offset + +from ..docker.base import DockerDaemon, DockerEnv, DockerRegistry +from ..docker.helpers import build_model_image + + +def build_k8s_docker( + meta: MlemModel, + image_name: str, + registry: Optional[DockerRegistry], + daemon: Optional[DockerDaemon], + server: Server, + platform: Optional[str] = "linux/amd64", + # runners usually do not support arm64 images built on Mac M1 devices +): + echo(EMOJI_BUILD + f"Creating docker image {image_name}") + with set_offset(2): + return build_model_image( + meta, + image_name, + server, + DockerEnv(registry=registry, daemon=daemon), + 
tag=meta.meta_hash(), + force_overwrite=True, + platform=platform, + ) diff --git a/mlem/contrib/kubernetes/context.py b/mlem/contrib/kubernetes/context.py new file mode 100644 index 00000000..c6649ced --- /dev/null +++ b/mlem/contrib/kubernetes/context.py @@ -0,0 +1,55 @@ +import logging +import os +from enum import Enum +from typing import ClassVar + +from pydantic import BaseModel + +from mlem.contrib.kubernetes.service import NodePortService, ServiceType +from mlem.utils.templates import TemplateModel + +logger = logging.getLogger(__name__) + + +class ImagePullPolicy(str, Enum): + always = "Always" + never = "Never" + if_not_present = "IfNotPresent" + + +class K8sYamlBuildArgs(BaseModel): + """Class encapsulating parameters for Kubernetes manifests/yamls""" + + class Config: + use_enum_values = True + + namespace: str = "mlem" + """Namespace to create kubernetes resources such as pods, service in""" + image_name: str = "ml" + """Name of the docker image to be deployed""" + image_uri: str = "ml:latest" + """URI of the docker image to be deployed""" + image_pull_policy: ImagePullPolicy = ImagePullPolicy.always + """Image pull policy for the docker image to be deployed""" + port: int = 8080 + """Port where the service should be available""" + service_type: ServiceType = NodePortService() + """Type of service by which endpoints of the model are exposed""" + + +class K8sYamlGenerator(K8sYamlBuildArgs, TemplateModel): + TEMPLATE_FILE: ClassVar = "resources.yaml.j2" + TEMPLATE_DIR: ClassVar = os.path.dirname(__file__) + + def prepare_dict(self): + logger.debug( + 'Generating Resource Yaml via templates from "%s"...', + self.templates_dir, + ) + + logger.debug('Docker image is based on "%s".', self.image_uri) + + k8s_yaml_args = self.dict() + k8s_yaml_args["service_type"] = self.service_type.get_string() + k8s_yaml_args.pop("templates_dir") + return k8s_yaml_args diff --git a/mlem/contrib/kubernetes/resources.yaml.j2 b/mlem/contrib/kubernetes/resources.yaml.j2 new file 
mode 100644 index 00000000..5cbe9b7b --- /dev/null +++ b/mlem/contrib/kubernetes/resources.yaml.j2 @@ -0,0 +1,47 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: {{ namespace }} + labels: + name: {{ namespace }} + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ image_name }} + namespace: {{ namespace }} +spec: + selector: + matchLabels: + app: {{ image_name }} + template: + metadata: + labels: + app: {{ image_name }} + spec: + containers: + - name: {{ image_name }} + image: {{ image_uri }} + imagePullPolicy: {{ image_pull_policy }} + ports: + - containerPort: {{ port }} + +--- + +apiVersion: v1 +kind: Service +metadata: + name: {{ image_name }} + namespace: {{ namespace }} + labels: + run: {{ image_name }} +spec: + ports: + - port: {{ port }} + protocol: TCP + targetPort: {{ port }} + selector: + app: {{ image_name }} + type: {{ service_type }} diff --git a/mlem/contrib/kubernetes/service.py b/mlem/contrib/kubernetes/service.py new file mode 100644 index 00000000..9f347abd --- /dev/null +++ b/mlem/contrib/kubernetes/service.py @@ -0,0 +1,115 @@ +from abc import abstractmethod +from typing import ClassVar, Optional, Tuple + +from kubernetes import client + +from mlem.core.base import MlemABC +from mlem.core.errors import EndpointNotFound, MlemError + + +def find_index(nodes_list, node_name): + for i, each_node in enumerate(nodes_list): + if each_node.metadata.name == node_name: + return i + return -1 + + +class ServiceType(MlemABC): + """Service Type for services inside a Kubernetes Cluster""" + + abs_name: ClassVar = "k8s_service_type" + + class Config: + type_root = True + + @abstractmethod + def get_string(self): + raise NotImplementedError + + @abstractmethod + def get_host_and_port( + self, service, namespace="mlem" # pylint: disable=unused-argument + ) -> Tuple[Optional[str], Optional[int]]: + """Returns host and port for the service in Kubernetes""" + raise NotImplementedError + + +class NodePortService(ServiceType): + """NodePort 
Service implementation for service inside a Kubernetes Cluster""" + + type: ClassVar = "nodeport" + + def get_string(self): + return "NodePort" + + def get_host_and_port(self, service, namespace="mlem"): + try: + port = service.items[0].spec.ports[0].node_port + except (IndexError, AttributeError) as e: + raise MlemError( + "Couldn't determine node port of the deployed service" + ) from e + try: + node_name = ( + client.CoreV1Api() + .list_namespaced_pod(namespace) + .items[0] + .spec.node_name + ) + except (IndexError, AttributeError) as e: + raise MlemError( + "Couldn't determine name of the node where the pod is deployed" + ) from e + node_list = client.CoreV1Api().list_node().items + node_index = find_index(node_list, node_name) + if node_index == -1: + raise MlemError( + f"Couldn't find the node where pods in namespace {namespace} exists" + ) + address_dict = node_list[node_index].status.addresses + for each_address in address_dict: + if each_address.type == "ExternalIP": + host = each_address.address + return host, port + raise EndpointNotFound( + f"Node {node_name} doesn't have an externally reachable IP address" + ) + + +class LoadBalancerService(ServiceType): + """LoadBalancer Service implementation for service inside a Kubernetes Cluster""" + + type: ClassVar = "loadbalancer" + + def get_string(self): + return "LoadBalancer" + + def get_host_and_port(self, service, namespace="mlem"): + try: + port = service.items[0].spec.ports[0].port + except (IndexError, AttributeError) as e: + raise MlemError( + "Couldn't determine port of the deployed service" + ) from e + try: + ingress = service.items[0].status.load_balancer.ingress[0] + host = ingress.hostname or ingress.ip + except (IndexError, AttributeError) as e: + raise MlemError( + "Couldn't determine IP address of the deployed service" + ) from e + return host, port + + +class ClusterIPService(ServiceType): + """ClusterIP Service implementation for service inside a Kubernetes Cluster""" + + type: ClassVar = 
"clusterip" + + def get_string(self): + return "ClusterIP" + + def get_host_and_port(self, service, namespace="mlem"): + raise MlemError( + "Cannot expose service of type ClusterIP outside the Kubernetes Cluster" + ) diff --git a/mlem/contrib/kubernetes/utils.py b/mlem/contrib/kubernetes/utils.py new file mode 100644 index 00000000..ae11fbe8 --- /dev/null +++ b/mlem/contrib/kubernetes/utils.py @@ -0,0 +1,80 @@ +import json +import os +import tempfile + +from kubernetes import client, utils, watch + +from .context import K8sYamlGenerator + + +def create_k8s_resources(generator: K8sYamlGenerator): + k8s_client = client.ApiClient() + with tempfile.TemporaryDirectory(prefix="mlem_k8s_yaml_build_") as tempdir: + filename = os.path.join(tempdir, "resource.yaml") + generator.write(filename) + try: + utils.create_from_yaml(k8s_client, filename, verbose=True) + except utils.FailToCreateError as e: + failures = e.api_exceptions + for each_failure in failures: + error_info = json.loads(each_failure.body) + if error_info["reason"] != "AlreadyExists": + raise e + if error_info["details"]["kind"] == "deployments": + existing_image_uri = ( + client.CoreV1Api() + .list_namespaced_pod(generator.namespace) + .items[0] + .spec.containers[0] + .image + ) + if existing_image_uri != generator.image_uri: + api_instance = client.AppsV1Api() + body = { + "spec": { + "template": { + "spec": { + "containers": [ + { + "name": generator.image_name, + "image": generator.image_uri, + } + ] + } + } + } + } + api_instance.patch_namespaced_deployment( + generator.image_name, + generator.namespace, + body, + pretty=True, + ) + + +def pod_is_running(namespace, timeout=60) -> bool: + w = watch.Watch() + for event in w.stream( + func=client.CoreV1Api().list_namespaced_pod, + namespace=namespace, + timeout_seconds=timeout, + ): + if event["object"].status.phase == "Running": + w.stop() + return True + return False + + +def namespace_deleted(namespace, timeout=60) -> bool: + w = watch.Watch() + for event 
in w.stream( + func=client.CoreV1Api().list_namespace, + timeout_seconds=timeout, + ): + if ( + namespace == event["object"].metadata.name + and event["type"] == "DELETED" + ): + w.stop() + return True + return False diff --git a/mlem/core/errors.py b/mlem/core/errors.py index d63b22cc..d8ff7f50 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -39,6 +39,10 @@ class LocationNotFound(MlemError): """Thrown if MLEM could not resolve location""" +class EndpointNotFound(MlemError): + """Thrown if MLEM could not resolve endpoint""" + + class RevisionNotFound(LocationNotFound): _message = "Revision '{rev}' wasn't found in path={path}, fs={fs}" diff --git a/setup.cfg b/setup.cfg index 2ed8d795..0d10ac3b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ log_level = debug markers = long: Marks long-running tests docker: Marks tests that needs Docker + kubernetes: Marks tests that needs Kubernetes testpaths = tests addopts = -rav --durations=0 --cov=mlem --cov-report=term-missing --cov-report=xml diff --git a/setup.py b/setup.py index dd456f68..3209bc1b 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ "rmq": ["pika"], "docker": ["docker"], "heroku": ["docker", "fastapi", "uvicorn"], + "kubernetes": ["docker", "kubernetes"], "dvc": ["dvc~=2.0"], } @@ -163,6 +164,13 @@ "env.heroku = mlem.contrib.heroku.meta:HerokuEnv", "deploy_state.heroku = mlem.contrib.heroku.meta:HerokuState", "server._heroku = mlem.contrib.heroku.server:HerokuServer", + "deployment.kubernetes = mlem.contrib.kubernetes.base:K8sDeployment", + "deploy_state.kubernetes = mlem.contrib.kubernetes.base:K8sDeploymentState", + "env.kubernetes = mlem.contrib.kubernetes.base:K8sEnv", + "builder.kubernetes = mlem.contrib.kubernetes.base:K8sYamlBuilder", + "k8s_service_type.clusterip = mlem.contrib.kubernetes.service:ClusterIPService", + "k8s_service_type.loadbalancer = mlem.contrib.kubernetes.service:LoadBalancerService", + "k8s_service_type.nodeport = 
mlem.contrib.kubernetes.service:NodePortService", "data_reader.lightgbm = mlem.contrib.lightgbm:LightGBMDataReader", "data_type.lightgbm = mlem.contrib.lightgbm:LightGBMDataType", "data_writer.lightgbm = mlem.contrib.lightgbm:LightGBMDataWriter", diff --git a/tests/conftest.py b/tests/conftest.py index 1aa44793..d18bb539 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,10 @@ MLEM_S3_TEST_BUCKET = "mlem-tests" +def _cut_empty_lines(string): + return "\n".join(line for line in string.splitlines() if line) + + def _check_github_test_repo_ssh_auth(): try: git.cmd.Git().ls_remote(MLEM_TEST_REPO) diff --git a/tests/contrib/test_docker/test_context.py b/tests/contrib/test_docker/test_context.py index 0bb788c8..d8e378f8 100644 --- a/tests/contrib/test_docker/test_context.py +++ b/tests/contrib/test_docker/test_context.py @@ -11,6 +11,7 @@ use_mlem_source, ) from mlem.core.requirements import UnixPackageRequirement +from tests.conftest import _cut_empty_lines from tests.contrib.test_docker.conftest import docker_test REGISTRY_PORT = 5000 @@ -112,10 +113,6 @@ def test_use_wheel_installation(tmpdir): assert f"RUN pip install {MLEM_LOCAL_WHL}" in dockerfile -def _cut_empty_lines(string): - return "\n".join(line for line in string.splitlines() if line) - - def _generate_dockerfile(unix_packages=None, **kwargs): return _cut_empty_lines( DockerfileGenerator(**kwargs).generate( diff --git a/tests/contrib/test_kubernetes/__init__.py b/tests/contrib/test_kubernetes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/contrib/test_kubernetes/conftest.py b/tests/contrib/test_kubernetes/conftest.py new file mode 100644 index 00000000..9df2824e --- /dev/null +++ b/tests/contrib/test_kubernetes/conftest.py @@ -0,0 +1,46 @@ +import os + +import pytest +from kubernetes import client, config + +from tests.conftest import long + +from .utils import Command + + +def is_minikube_running() -> bool: + try: + cmd = Command("minikube status") + returncode 
= cmd.run(timeout=3, shell=True) + if returncode == 0: + config.load_kube_config( + config_file=os.getenv("KUBECONFIG", default="~/.kube/config") + ) + client.CoreV1Api().list_namespaced_pod("default") + return True + return False + except (config.config_exception.ConfigException, ConnectionRefusedError): + return False + + +def has_k8s(): + if os.environ.get("SKIP_K8S_TESTS", None) == "true": + return False + current_os = os.environ.get("GITHUB_MATRIX_OS") + current_python = os.environ.get("GITHUB_MATRIX_PYTHON") + if ( + current_os is not None + and current_os != "ubuntu-latest" + or current_python is not None + and current_python != "3.9" + ): + return False + return is_minikube_running() + + +def k8s_test(f): + mark = pytest.mark.kubernetes + skip = pytest.mark.skipif( + not has_k8s(), reason="kubernetes is unavailable or skipped" + ) + return long(mark(skip(f))) diff --git a/tests/contrib/test_kubernetes/test_base.py b/tests/contrib/test_kubernetes/test_base.py new file mode 100644 index 00000000..269cd0e4 --- /dev/null +++ b/tests/contrib/test_kubernetes/test_base.py @@ -0,0 +1,131 @@ +import os +import re +import subprocess +import tempfile + +import numpy as np +import pytest +from kubernetes import config +from sklearn.datasets import load_iris +from sklearn.tree import DecisionTreeClassifier + +from mlem.api import save +from mlem.config import project_config +from mlem.contrib.docker.base import DockerDaemon, DockerRegistry +from mlem.contrib.kubernetes.base import ( + K8sDeployment, + K8sDeploymentState, + K8sEnv, +) +from mlem.contrib.kubernetes.build import build_k8s_docker +from mlem.contrib.kubernetes.context import ImagePullPolicy +from mlem.contrib.kubernetes.service import LoadBalancerService +from mlem.core.objects import DeployStatus +from tests.contrib.test_kubernetes.conftest import k8s_test +from tests.contrib.test_kubernetes.utils import Command + + +@pytest.fixture(scope="session") +def minikube_env_variables(): + old_environ = 
dict(os.environ) + output = subprocess.check_output( + ["minikube", "-p", "minikube", "docker-env"] + ) + export_re = re.compile('export ([A-Z_]+)="(.*)"\\n') + export_pairs = export_re.findall(output.decode("UTF-8")) + for k, v in export_pairs: + os.environ[k] = v + + yield + + os.environ.clear() + os.environ.update(old_environ) + + +@pytest.fixture +def load_kube_config(): + config.load_kube_config(os.getenv("KUBECONFIG", default="~/.kube/config")) + + +@pytest.fixture(scope="session") +def model_meta(tmp_path_factory): + path = os.path.join(tmp_path_factory.getbasetemp(), "saved-model-single") + train, target = load_iris(return_X_y=True) + model = DecisionTreeClassifier().fit(train, target) + return save(model, path, sample_data=train) + + +@pytest.fixture(scope="session") +def k8s_deployment(minikube_env_variables, model_meta): + return K8sDeployment( + name="ml", + model=model_meta.make_link(), + image_pull_policy=ImagePullPolicy.never, + service_type=LoadBalancerService(), + daemon=DockerDaemon(host=os.getenv("DOCKER_HOST", default="")), + ) + + +@pytest.fixture(scope="session") +def docker_image(k8s_deployment): + tmpdir = tempfile.mkdtemp() + k8s_deployment.dump(os.path.join(tmpdir, "deploy")) + return build_k8s_docker( + k8s_deployment.get_model(), + k8s_deployment.image_name, + DockerRegistry(), + DockerDaemon(host=os.getenv("DOCKER_HOST", default="")), + k8s_deployment.server or project_config(None).server, + platform=None, + ) + + +@pytest.fixture +def k8s_deployment_state(docker_image, model_meta): + return K8sDeploymentState( + image=docker_image, + model_hash=model_meta.meta_hash(), + ) + + +@pytest.fixture +def k8s_env(): + return K8sEnv() + + +@k8s_test +@pytest.mark.usefixtures("load_kube_config") +def test_deploy( + k8s_deployment, + k8s_deployment_state, + k8s_env, +): + k8s_deployment.update_state(k8s_deployment_state) + assert k8s_env.get_status(k8s_deployment) == DeployStatus.NOT_DEPLOYED + k8s_env.deploy(k8s_deployment) + 
k8s_deployment.wait_for_status( + DeployStatus.RUNNING, + allowed_intermediate=[DeployStatus.STARTING], + timeout=10, + times=5, + ) + assert k8s_env.get_status(k8s_deployment) == DeployStatus.RUNNING + k8s_env.remove(k8s_deployment) + assert k8s_env.get_status(k8s_deployment) == DeployStatus.NOT_DEPLOYED + + +@k8s_test +@pytest.mark.usefixtures("load_kube_config") +def test_deployed_service( + k8s_deployment, + k8s_deployment_state, + k8s_env, +): + k8s_deployment.update_state(k8s_deployment_state) + k8s_env.deploy(k8s_deployment) + cmd = Command("minikube tunnel") + cmd.run(timeout=20, shell=True) + client = k8s_deployment.get_client() + train, _ = load_iris(return_X_y=True) + response = client.predict(data=train) + assert np.array_equal(response, np.array([0] * 50 + [1] * 50 + [2] * 50)) diff --git a/tests/contrib/test_kubernetes/test_context.py b/tests/contrib/test_kubernetes/test_context.py new file mode 100644 index 00000000..076a3323 --- /dev/null +++ b/tests/contrib/test_kubernetes/test_context.py @@ -0,0 +1,150 @@ +import pytest + +from mlem.contrib.kubernetes.context import ( + ImagePullPolicy, + K8sYamlBuildArgs, + K8sYamlGenerator, +) +from mlem.contrib.kubernetes.service import LoadBalancerService +from tests.conftest import _cut_empty_lines + + +@pytest.fixture +def k8s_default_manifest(): + return _cut_empty_lines( + """apiVersion: v1 +kind: Namespace +metadata: + name: mlem + labels: + name: mlem + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ml + namespace: mlem +spec: + selector: + matchLabels: + app: ml + template: + metadata: + labels: + app: ml + spec: + containers: + - name: ml + image: ml:latest + imagePullPolicy: Always + ports: + - containerPort: 8080 + +--- + +apiVersion: v1 +kind: Service +metadata: + name: ml + namespace: mlem + labels: + run: ml +spec: + ports: + - port: 8080 + protocol: TCP + targetPort: 8080 + selector: + app: ml + type: NodePort +""" + ) + + +@pytest.fixture +def k8s_manifest(): + return 
_cut_empty_lines( + """apiVersion: v1 +kind: Namespace +metadata: + name: hello + labels: + name: hello + +--- + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: test + namespace: hello +spec: + selector: + matchLabels: + app: test + template: + metadata: + labels: + app: test + spec: + containers: + - name: test + image: test:latest + imagePullPolicy: Never + ports: + - containerPort: 8080 + +--- + +apiVersion: v1 +kind: Service +metadata: + name: test + namespace: hello + labels: + run: test +spec: + ports: + - port: 8080 + protocol: TCP + targetPort: 8080 + selector: + app: test + type: LoadBalancer +""" + ) + + +def test_k8s_yaml_build_args_default(k8s_default_manifest): + build_args = K8sYamlBuildArgs() + assert _generate_k8s_manifest(**build_args.dict()) == k8s_default_manifest + + +def test_k8s_yaml_build_args(k8s_manifest): + build_args = K8sYamlBuildArgs( + namespace="hello", + image_name="test", + image_uri="test:latest", + image_pull_policy=ImagePullPolicy.never, + port=8080, + service_type=LoadBalancerService(), + ) + assert _generate_k8s_manifest(**build_args.dict()) == k8s_manifest + + +def test_k8s_yaml_generator(k8s_manifest): + kwargs = { + "namespace": "hello", + "image_name": "test", + "image_uri": "test:latest", + "image_pull_policy": "Never", + "port": 8080, + "service_type": LoadBalancerService(), + } + assert _generate_k8s_manifest(**kwargs) == k8s_manifest + + +def _generate_k8s_manifest(**kwargs): + return _cut_empty_lines(K8sYamlGenerator(**kwargs).generate()) diff --git a/tests/contrib/test_kubernetes/utils.py b/tests/contrib/test_kubernetes/utils.py new file mode 100644 index 00000000..5568e038 --- /dev/null +++ b/tests/contrib/test_kubernetes/utils.py @@ -0,0 +1,34 @@ +import subprocess +import threading + + +class Command: + """ + Enables to run subprocess commands in a different thread + with TIMEOUT option! 
+ Based on jcollado's solution: + http://stackoverflow.com/questions/1191374/subprocess-with-timeout/4825933#4825933 + """ + + def __init__(self, cmd): + self.cmd = cmd + self.process = None + + def run(self, timeout=0, **kwargs): + def target(**kwargs): + self.process = ( + subprocess.Popen( # pylint: disable=consider-using-with + self.cmd, **kwargs + ) + ) + self.process.communicate() + + thread = threading.Thread(target=target, kwargs=kwargs) + thread.start() + + thread.join(timeout) + if thread.is_alive(): + self.process.terminate() + thread.join() + + return self.process.returncode From a18fa6812f5e19241835cc105de29ac1a30115e5 Mon Sep 17 00:00:00 2001 From: Alexander Guschin <1aguschin@gmail.com> Date: Thu, 15 Sep 2022 18:20:27 +0700 Subject: [PATCH 4/4] initial draft --- mlem/cli/apply.py | 16 ---------------- mlem/cli/build.py | 8 -------- mlem/cli/checkenv.py | 10 +--------- mlem/cli/clone.py | 7 ------- mlem/cli/config.py | 13 ++----------- mlem/cli/declare.py | 4 ---- mlem/cli/deployment.py | 35 ++++------------------------------- mlem/cli/dev.py | 3 --- mlem/cli/import_object.py | 13 +------------ mlem/cli/info.py | 15 +-------------- mlem/cli/init.py | 8 +------- mlem/cli/link.py | 7 ------- mlem/cli/main.py | 28 ++++++---------------------- mlem/cli/serve.py | 3 --- mlem/cli/types.py | 7 ------- mlem/cli/utils.py | 12 ------------ 16 files changed, 16 insertions(+), 173 deletions(-) diff --git a/mlem/cli/apply.py b/mlem/cli/apply.py index 37fc6c27..2c5a83bb 100644 --- a/mlem/cli/apply.py +++ b/mlem/cli/apply.py @@ -74,18 +74,6 @@ def apply( ): """Apply a model to data. The result will be saved as a MLEM object to `output` if provided. Otherwise, it will be printed to `stdout`. 
- - Examples: - Apply local mlem model to local mlem data - $ mlem apply mymodel mydata --method predict --output myprediction - - Apply local mlem model to local data file - $ mlem apply mymodel data.csv --method predict --import --import-type pandas[csv] --output myprediction - - Apply a version of remote model to a version of remote data - $ mlem apply models/logreg --project https://github.com/iterative/example-mlem --rev main - data/test_x --data-project https://github.com/iterative/example-mlem --data-rev main - --method predict --output myprediction """ from mlem.api import apply @@ -133,10 +121,6 @@ def apply( help="""Apply a deployed-model (possibly remotely) to data. The results will be saved as a MLEM object to `output` if provided. Otherwise, it will be printed to `stdout`. - - Examples: - Apply hosted mlem model to local mlem data - $ mlem apply-remote http mydata -c host="0.0.0.0" -c port=8080 --output myprediction """, cls=mlem_group("runtime"), subcommand_metavar="client", diff --git a/mlem/cli/build.py b/mlem/cli/build.py index 41561ea0..dc916cc2 100644 --- a/mlem/cli/build.py +++ b/mlem/cli/build.py @@ -26,14 +26,6 @@ help=""" Build models to create re-usable, ship-able entities such as a Docker image or Python package. 
- - Examples: - Build docker image from model - $ mlem build mymodel docker -c server.type=fastapi -c image.name=myimage - - Create build docker_dir declaration and build it - $ mlem declare builder docker_dir -c server=fastapi -c target=build build_dock - $ mlem build mymodel --load build_dock """, cls=mlem_group("runtime", aliases=["export"]), subcommand_metavar="builder", diff --git a/mlem/cli/checkenv.py b/mlem/cli/checkenv.py index b475ca98..15ab4348 100644 --- a/mlem/cli/checkenv.py +++ b/mlem/cli/checkenv.py @@ -14,15 +14,7 @@ def checkenv( project: Optional[str] = option_project, rev: Optional[str] = option_rev, ): - """Check that current environment satisfies object requrements - - Examples: - Check local object - $ mlem checkenv mymodel - - Check remote object - $ mlem checkenv https://github.com/iterative/example-mlem/models/logreg - """ + """Check that current environment satisfies object requrements""" meta = load_meta(path, project, rev, follow_links=True, load_value=False) if isinstance(meta, (MlemModel, MlemData)): meta.checkenv() diff --git a/mlem/cli/clone.py b/mlem/cli/clone.py index 3ae9b1a2..58c7b280 100644 --- a/mlem/cli/clone.py +++ b/mlem/cli/clone.py @@ -24,13 +24,6 @@ def clone( ): """Copy a MLEM Object from `uri` and saves a copy of it to `target` path. 
- - Examples: - Copy remote model to local directory - $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel - - Copy remote model to remote MLEM project - $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main mymodel --tp s3://mybucket/mymodel """ from mlem.api.commands import clone diff --git a/mlem/cli/config.py b/mlem/cli/config.py index a4f04ebe..7e73a061 100644 --- a/mlem/cli/config.py +++ b/mlem/cli/config.py @@ -31,11 +31,7 @@ def config_set( True, help="Whether to validate config schema after" ), ): - """Set configuration value - - Examples: - $ mlem config set pandas.default_format csv - """ + """Set configuration value""" fs, path = get_fs(project or "") project = find_project_root(path, fs=fs) try: @@ -68,12 +64,7 @@ def config_get( name: str = Argument(..., help="Dotted name of option"), project: Optional[str] = option_project, ): - """Get configuration value - - Examples: - $ mlem config get pandas.default_format - $ mlem config get pandas.default_format --project https://github.com/iterative/example-mlem/ - """ + """Get configuration value""" fs, path = get_fs(project or "") project = find_project_root(path, fs=fs) with fs.open(posixpath.join(project, MLEM_DIR, CONFIG_FILE_NAME)) as f: diff --git a/mlem/cli/declare.py b/mlem/cli/declare.py index 9acc2952..9d83217d 100644 --- a/mlem/cli/declare.py +++ b/mlem/cli/declare.py @@ -24,10 +24,6 @@ declare = Typer( name="declare", help="""Declares a new MLEM Object metafile from config args and config files. - - Examples: - Create heroku deployment - $ mlem declare env heroku production --api_key <...> """, cls=mlem_group("object"), subcommand_metavar="subtype", diff --git a/mlem/cli/deployment.py b/mlem/cli/deployment.py index 724a2011..3365a570 100644 --- a/mlem/cli/deployment.py +++ b/mlem/cli/deployment.py @@ -51,16 +51,6 @@ def deploy_run( ): """Deploy a model to a target environment. 
Can use an existing deployment declaration or create a new one on-the-fly. - - Examples: - Create new deployment - $ mlem declare env heroku staging -c api_key=... - $ mlem deploy run service_name -m model -t staging -c name=my_service - - Deploy existing meta - $ mlem declare env heroku staging -c api_key=... - $ mlem declare deployment heroku service_name -c app_name=my_service -c model=model -c env=staging - $ mlem deploy run service_name """ from mlem.api.commands import deploy @@ -84,11 +74,7 @@ def deploy_remove( path: str = Argument(..., help="Path to deployment meta"), project: Optional[str] = option_project, ): - """Stop and destroy deployed instance. - - Examples: - $ mlem deployment remove service_name - """ + """Stop and destroy deployed instance.""" deploy_meta = load_meta(path, project=project, force_type=MlemDeployment) deploy_meta.remove() @@ -98,11 +84,7 @@ def deploy_status( path: str = Argument(..., help="Path to deployment meta"), project: Optional[str] = option_project, ): - """Print status of deployed service. - - Examples: - $ mlem deployment status service_name - """ + """Print status of deployed service.""" with no_echo(): deploy_meta = load_meta( path, project=project, force_type=MlemDeployment @@ -131,11 +113,7 @@ def deploy_wait( 0, "-t", "--times", help="Number of attempts. 0 -> indefinite" ), ): - """Wait for status of deployed service - - Examples: - $ mlem deployment status service_name - """ + """Wait for status of deployed service""" with no_echo(): deploy_meta = load_meta( path, project=project, force_type=MlemDeployment @@ -161,12 +139,7 @@ def deploy_apply( index: bool = option_index, json: bool = option_json, ): - """Apply a deployed model to data. 
- - Examples: - $ mlem deployment apply service_name - """ - + """Apply a deployed model to data.""" with set_echo(None if json else ...): deploy_meta = load_meta( path, project=project, rev=rev, force_type=MlemDeployment diff --git a/mlem/cli/dev.py b/mlem/cli/dev.py index 526bb7ca..b5227901 100644 --- a/mlem/cli/dev.py +++ b/mlem/cli/dev.py @@ -23,9 +23,6 @@ def find_implementations_diff( ): """Loads `root` module or package and finds implementations of MLEM base classes Shows differences between what was found and what is registered in entrypoints - - Examples: - $ mlem dev fi """ exts = {e.entry for e in load_entrypoints().values()} impls = set(find_abc_implementations(root)[MLEM_ENTRY_POINT]) diff --git a/mlem/cli/import_object.py b/mlem/cli/import_object.py index 7eae2def..3070710b 100644 --- a/mlem/cli/import_object.py +++ b/mlem/cli/import_object.py @@ -29,18 +29,7 @@ def import_object( index: bool = option_index, external: bool = option_external, ): - """Create a `.mlem` metafile for a model or data in any file or directory. - - Examples: - Create MLEM data from local csv - $ mlem import data/data.csv data/imported_data --type pandas[csv] - - Create MLEM model from local pickle file - $ mlem import data/model.pkl data/imported_model - - Create MLEM model from remote pickle file - $ mlem import models/logreg --project https://github.com/iterative/example-mlem --rev no-dvc data/imported_model --type pickle - """ + """Create a `.mlem` metafile for a model or data in any file or directory.""" from mlem.api.commands import import_object import_object( diff --git a/mlem/cli/info.py b/mlem/cli/info.py index 9aea8315..c6d81403 100644 --- a/mlem/cli/info.py +++ b/mlem/cli/info.py @@ -59,13 +59,7 @@ def ls( False, "-i", "--ignore-errors", help="Ignore corrupted objects" ), ): - """List MLEM objects inside a MLEM project (location should be [initialized](/doc/command-reference/init)). 
- - - Examples: - $ mlem list https://github.com/iterative/example-mlem - $ mlem list -t models - """ + """List MLEM objects inside a MLEM project (location should be [initialized](/doc/command-reference/init)).""" from mlem.api.commands import ls if type_filter == "all": @@ -112,13 +106,6 @@ def pretty_print( ): """Display all details about a specific MLEM Object from an existing MLEM project. - - Examples: - Print local object - $ mlem pprint mymodel - - Print remote object - $ mlem pprint https://github.com/iterative/example-mlem/models/logreg """ with set_echo(None if json else ...): meta = load_meta( diff --git a/mlem/cli/init.py b/mlem/cli/init.py index 8160f21b..24ac5eeb 100644 --- a/mlem/cli/init.py +++ b/mlem/cli/init.py @@ -7,13 +7,7 @@ def init( path: str = Argument(".", help="Where to init project", show_default=False) ): - """Initialize a MLEM project. - - Examples: - $ mlem init - $ mlem init some/local/path - $ mlem init s3://bucket/path/in/cloud - """ + """Initialize a MLEM project.""" from mlem.api.commands import init init(path) diff --git a/mlem/cli/link.py b/mlem/cli/link.py index 691fa70c..078ae715 100644 --- a/mlem/cli/link.py +++ b/mlem/cli/link.py @@ -40,13 +40,6 @@ def link( ): """Create a link (read alias) for an existing MLEM Object, including from remote MLEM projects. 
- - Examples: - Add alias to local object - $ mlem link my_model latest - - Add remote object to your project without copy - $ mlem link models/logreg --source-project https://github.com/iteartive/example-mlem remote_model """ from mlem.api.commands import link diff --git a/mlem/cli/main.py b/mlem/cli/main.py index ebba1d01..aa5a19f4 100644 --- a/mlem/cli/main.py +++ b/mlem/cli/main.py @@ -28,7 +28,6 @@ LOAD_PARAM_NAME, NOT_SET, CallContext, - _extract_examples, _format_validation_error, get_extra_keys, ) @@ -47,13 +46,11 @@ class MlemMixin(Command): def __init__( self, *args, - examples: Optional[str], section: str = "other", aliases: List[str] = None, **kwargs, ): super().__init__(*args, **kwargs) - self.examples = examples self.section = section self.aliases = aliases self.rich_help_panel = section.capitalize() @@ -72,12 +69,6 @@ def get_help(self, ctx: Context) -> str: self.format_help(ctx, formatter) return formatter.getvalue().rstrip("\n") - def format_epilog(self, ctx: Context, formatter: HelpFormatter) -> None: - super().format_epilog(ctx, formatter) - if self.examples: - with formatter.section("Examples"): - formatter.write(self.examples) - class MlemCommand( MlemMixin, @@ -99,7 +90,8 @@ def __init__( ): self.dynamic_metavar = dynamic_metavar self.dynamic_options_generator = dynamic_options_generator - examples, help = _extract_examples(help) + if help is not None and "Documentation" not in help: + help = f"{help}\n\nDocumentation: " self._help = help self.lazy_help = lazy_help self.pass_from_parent = pass_from_parent @@ -107,7 +99,6 @@ def __init__( name=name, section=section, aliases=aliases, - examples=examples, help=help, **kwargs, ) @@ -197,11 +188,11 @@ def __init__( help: str = None, **attrs: Any, ) -> None: - examples, help = _extract_examples(help) + if help is not None and "Documentation" not in help: + help = f"{help}\n\nDocumentation: " super().__init__( name=name, help=help, - examples=examples, aliases=aliases, section=section, 
commands=commands, @@ -318,15 +309,8 @@ def mlem_callback( * Serialize any model trained in Python into ready-to-deploy format * Model lifecycle management using Git and GitOps principles * Provider-agnostic deployment - - Examples: - $ mlem init - $ mlem list https://github.com/iterative/example-mlem - $ mlem clone models/logreg --project https://github.com/iterative/example-mlem --rev main logreg - $ mlem link logreg latest - $ mlem apply latest https://github.com/iterative/example-mlem/data/test_x -o pred - $ mlem serve latest fastapi -c port=8001 - $ mlem build latest docker_dir -c target=build/ -c server.type=fastapi + \b + Documentation: """ if ctx.invoked_subcommand is None and show_version: with cli_echo(): diff --git a/mlem/cli/serve.py b/mlem/cli/serve.py index 5d6ff54e..961567bd 100644 --- a/mlem/cli/serve.py +++ b/mlem/cli/serve.py @@ -26,9 +26,6 @@ name="serve", help="""Deploy the model locally using a server implementation and expose its methods as endpoints. - - Examples: - $ mlem serve fastapi https://github.com/iterative/example-mlem/models/logreg """, cls=mlem_group("runtime"), subcommand_metavar="server", diff --git a/mlem/cli/types.py b/mlem/cli/types.py index 52bba21c..5f80db37 100644 --- a/mlem/cli/types.py +++ b/mlem/cli/types.py @@ -77,13 +77,6 @@ def list_types( ): """List different implementations available for a particular MLEM type. If a subtype is not provided, list all available MLEM types. 
- - Examples: - List ABCs - $ mlem types - - List available server implementations - $ mlem types server """ if abc is None: for at in MlemABC.abs_types.values(): diff --git a/mlem/cli/utils.py b/mlem/cli/utils.py index d92b0044..6ab578ee 100644 --- a/mlem/cli/utils.py +++ b/mlem/cli/utils.py @@ -616,15 +616,3 @@ def config_arg( ) with wrap_build_error(subtype, model): return build_mlem_object(model, subtype, conf, file_conf, kwargs) - - -def _extract_examples( - help_str: Optional[str], -) -> Tuple[Optional[str], Optional[str]]: - if help_str is None: - return None, None - try: - examples = help_str.index("Examples:") - except ValueError: - return None, help_str - return help_str[examples + len("Examples:") + 1 :], help_str[:examples]