diff --git a/mlem/api/commands.py b/mlem/api/commands.py index 6b2c9431..328fbadd 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -420,9 +420,11 @@ def deploy( fs: Optional[AbstractFileSystem] = None, external: bool = None, index: bool = None, + env_kwargs: Dict[str, Any] = None, **deploy_kwargs, ) -> MlemDeployment: deploy_path = None + update = False if isinstance(deploy_meta_or_path, str): deploy_path = deploy_meta_or_path try: @@ -432,13 +434,13 @@ def deploy( fs=fs, force_type=MlemDeployment, ) + update = True except MlemObjectNotFound: deploy_meta = None else: deploy_meta = deploy_meta_or_path - if model is not None: - deploy_meta.replace_model(get_model_meta(model)) + update = True if deploy_meta is None: if model is None or env is None: @@ -451,14 +453,21 @@ def deploy( env_meta = ensure_meta(MlemEnv, env, allow_typename=True) if isinstance(env_meta, type): env = None + if env_kwargs: + env = env_meta(**env_kwargs) deploy_type = env_meta.deploy_type deploy_meta = deploy_type( model_cache=model_meta, + model=model_meta.make_link(), env=env, - model=model, **deploy_kwargs, ) deploy_meta.dump(deploy_path, fs, project, index, external) + else: + if model is not None: + deploy_meta.replace_model(get_model_meta(model, load_value=False)) + if update: + pass # todo update from deploy_args and env_args # ensuring links are working deploy_meta.get_env() deploy_meta.get_model() diff --git a/mlem/cli/deployment.py b/mlem/cli/deployment.py index c6ae8eda..01f404da 100644 --- a/mlem/cli/deployment.py +++ b/mlem/cli/deployment.py @@ -22,7 +22,7 @@ from mlem.core.data_type import DataAnalyzer from mlem.core.errors import DeploymentError from mlem.core.metadata import load_meta -from mlem.core.objects import DeployState, MlemDeployment +from mlem.core.objects import DeployState, DeployStatus, MlemDeployment from mlem.ui import echo, no_echo, set_echo deployment = Typer( @@ -67,6 +67,9 @@ def deploy_run( """ from mlem.api.commands import deploy + conf = conf 
or [] + env_conf = [c[len("env.") :] for c in conf if c.startswith("env.")] + conf = [c for c in conf if not c.startswith("env.")] deploy( path, model, @@ -74,6 +77,7 @@ def deploy_run( project, external=external, index=index, + env_kwargs=parse_string_conf(env_conf), **parse_string_conf(conf or []), ) @@ -110,6 +114,40 @@ def deploy_status( echo(status) +@mlem_command("wait", parent=deployment) +def deploy_wait( + path: str = Argument(..., help="Path to deployment meta"), + project: Optional[str] = option_project, + statuses: List[DeployStatus] = Option( + [DeployStatus.RUNNING], + "-s", + "--status", + help="statuses to wait for", + ), + intermediate: List[DeployStatus] = Option( + None, "-i", "--intermediate", help="Possible intermediate statuses" + ), + poll_timeout: float = Option( + 1.0, "-p", "--poll-timeout", help="Timeout between attempts" + ), + times: int = Option( + 0, "-t", "--times", help="Number of attempts. 0 -> indefinite" + ), +): + """Wait for status of deployed service + + Examples: + $ mlem deployment status service_name + """ + with no_echo(): + deploy_meta = load_meta( + path, project=project, force_type=MlemDeployment + ) + deploy_meta.wait_for_status( + statuses, poll_timeout, times, allowed_intermediate=intermediate + ) + + @mlem_command("apply", parent=deployment) def deploy_apply( path: str = Argument(..., help="Path to deployment meta"), @@ -144,7 +182,7 @@ def deploy_apply( raise DeploymentError( f"{deploy_meta.type} deployment has no state. Either {deploy_meta.type} is not deployed yet or has been un-deployed again." 
) - client = state.get_client() + client = deploy_meta.get_client(state) result = run_apply_remote( client, diff --git a/mlem/contrib/docker/base.py b/mlem/contrib/docker/base.py index 8ac3785a..39379c8f 100644 --- a/mlem/contrib/docker/base.py +++ b/mlem/contrib/docker/base.py @@ -187,10 +187,10 @@ def push(self, client, tag): if "error" in status: error_msg = status["error"] raise DeploymentError(f"Cannot push docker image: {error_msg}") - echo(EMOJI_OK + f"Pushed image {tag} to {self.host}") + echo(EMOJI_OK + f"Pushed image {tag} to {self.get_host()}") def uri(self, image: str): - return f"{self.host}/{image}" + return f"{self.get_host()}/{image}" def _get_digest(self, name, tag): r = requests.head( @@ -286,9 +286,6 @@ class DockerContainerState(DeployState): container_name: Optional[str] container_id: Optional[str] - def get_client(self): - raise NotImplementedError - class _DockerBuildMixin(BaseModel): server: Optional[Server] = None @@ -320,6 +317,9 @@ class DockerContainer(MlemDeployment, _DockerBuildMixin): def ensure_image_name(self): return self.image_name or self.container_name + def _get_client(self, state: DockerContainerState): + raise NotImplementedError + class DockerEnv(MlemEnv[DockerContainer]): """:class:`.MlemEnv` implementation for docker environment diff --git a/mlem/contrib/docker/copy.j2 b/mlem/contrib/docker/copy.j2 new file mode 100644 index 00000000..916bbf2c --- /dev/null +++ b/mlem/contrib/docker/copy.j2 @@ -0,0 +1 @@ +COPY . ./ diff --git a/mlem/contrib/docker/dockerfile.j2 b/mlem/contrib/docker/dockerfile.j2 index 447a2dc5..daa05601 100644 --- a/mlem/contrib/docker/dockerfile.j2 +++ b/mlem/contrib/docker/dockerfile.j2 @@ -1,12 +1,9 @@ FROM {{ base_image }} WORKDIR /app {% include "pre_install.j2" ignore missing %} -{% if packages %}RUN {{ package_install_cmd }} {{ packages|join(" ") }}{% endif %} -COPY requirements.txt . 
-RUN pip install -r requirements.txt -{{ mlem_install }} +{% include "install_req.j2" %} {% include "post_install.j2" ignore missing %} -COPY . ./ +{% include "copy.j2" %} {% for name, value in env.items() %}ENV {{ name }}={{ value }} {% endfor %} {% include "post_copy.j2" ignore missing %} diff --git a/mlem/contrib/docker/install_req.j2 b/mlem/contrib/docker/install_req.j2 new file mode 100644 index 00000000..35d8a35c --- /dev/null +++ b/mlem/contrib/docker/install_req.j2 @@ -0,0 +1,4 @@ +{% if packages %}RUN {{ package_install_cmd }} {{ packages|join(" ") }}{% endif %} +COPY requirements.txt . +RUN pip install -r requirements.txt +{{ mlem_install }} diff --git a/mlem/contrib/heroku/meta.py b/mlem/contrib/heroku/meta.py index b8838c2b..d2d98225 100644 --- a/mlem/contrib/heroku/meta.py +++ b/mlem/contrib/heroku/meta.py @@ -44,11 +44,6 @@ def ensured_app(self) -> HerokuAppMeta: raise ValueError("App is not created yet") return self.app - def get_client(self) -> Client: - return HTTPClient( - host=urlparse(self.ensured_app.web_url).netloc, port=80 - ) - class HerokuDeployment(MlemDeployment): type: ClassVar = "heroku" @@ -58,6 +53,11 @@ class HerokuDeployment(MlemDeployment): stack: str = "container" team: Optional[str] = None + def _get_client(self, state: HerokuState) -> Client: + return HTTPClient( + host=urlparse(state.ensured_app.web_url).netloc, port=80 + ) + class HerokuEnv(MlemEnv[HerokuDeployment]): type: ClassVar = "heroku" diff --git a/mlem/contrib/sagemaker/__init__.py b/mlem/contrib/sagemaker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mlem/contrib/sagemaker/build.py b/mlem/contrib/sagemaker/build.py new file mode 100644 index 00000000..48f83bdd --- /dev/null +++ b/mlem/contrib/sagemaker/build.py @@ -0,0 +1,123 @@ +import base64 +import os +from typing import ClassVar, Optional + +import boto3 +import sagemaker +from pydantic import BaseModel + +from ...core.objects import MlemModel +from ...ui import EMOJI_BUILD, EMOJI_KEY, 
echo, set_offset +from ..docker.base import DockerEnv, DockerImage, RemoteRegistry +from ..docker.helpers import build_model_image +from .runtime import SageMakerServer + +IMAGE_NAME = "mlem-sagemaker-runner" + + +class AWSVars(BaseModel): + profile: str + bucket: str + region: str + account: str + role_name: str + + @property + def role(self): + return f"arn:aws:iam::{self.account}:role/{self.role_name}" + + def get_sagemaker_session(self): + return sagemaker.Session( + self.get_session(), default_bucket=self.bucket + ) + + def get_session(self): + return boto3.Session( + profile_name=self.profile, region_name=self.region + ) + + +def ecr_repo_check(region, repository, session: boto3.Session): + client = session.client("ecr", region_name=region) + + repos = client.describe_repositories()["repositories"] + + if repository not in {r["repositoryName"] for r in repos}: + echo(EMOJI_BUILD + f"Creating ECR repository {repository}") + client.create_repository(repositoryName=repository) + + +class ECRegistry(RemoteRegistry): + class Config: + exclude = {"aws_vars"} + + type: ClassVar = "ecr" + account: str + region: str + + aws_vars: Optional[AWSVars] = None + + def login(self, client): + auth_data = self.ecr_client.get_authorization_token() + token = auth_data["authorizationData"][0]["authorizationToken"] + user, token = base64.b64decode(token).decode("utf8").split(":") + self._login(self.get_host(), client, user, token) + echo( + EMOJI_KEY + + f"Logged in to remote registry at host {self.get_host()}" + ) + + def get_host(self) -> Optional[str]: + return f"{self.account}.dkr.ecr.{self.region}.amazonaws.com" + + def image_exists(self, client, image: DockerImage): + images = self.ecr_client.list_images(repositoryName=image.name)[ + "imageIds" + ] + return len(images) > 0 + + def delete_image(self, client, image: DockerImage, force=False, **kwargs): + self.ecr_client.batch_delete_image( + repositoryName=image.name, + imageIds=[{"imageTag": image.tag}], + ) + + def 
with_aws_vars(self, aws_vars): + self.aws_vars = aws_vars + return self + + @property + def ecr_client(self): + return ( + self.aws_vars.get_session().client("ecr") + if self.aws_vars + else boto3.client("ecr", region_name=self.region) + ) + + +def build_sagemaker_docker( + meta: MlemModel, + method: str, + account: str, + region: str, + image_name: str, + repository: str, + aws_vars: AWSVars, +): + docker_env = DockerEnv( + registry=ECRegistry(account=account, region=region).with_aws_vars( + aws_vars + ) + ) + ecr_repo_check(region, repository, aws_vars.get_session()) + echo(EMOJI_BUILD + "Creating docker image for sagemaker") + with set_offset(2): + return build_model_image( + meta, + name=repository, + tag=image_name, + server=SageMakerServer(method=method), + env=docker_env, + force_overwrite=True, + templates_dir=[os.path.dirname(__file__)], + ) diff --git a/mlem/contrib/sagemaker/copy.j2 b/mlem/contrib/sagemaker/copy.j2 new file mode 100644 index 00000000..e69de29b diff --git a/mlem/contrib/sagemaker/env_setup.py b/mlem/contrib/sagemaker/env_setup.py new file mode 100644 index 00000000..1b10258b --- /dev/null +++ b/mlem/contrib/sagemaker/env_setup.py @@ -0,0 +1,93 @@ +import os +import shutil +import subprocess + +from mlem.ui import echo + +MLEM_TF = "mlem_sagemaker.tf" + + +def _tf_command(tf_dir, command, *flags, **args): + args = " ".join(f"-var='{k}={v}'" for k, v in args.items()) + return " ".join( + [ + "terraform", + f"-chdir={tf_dir}", + command, + *flags, + args, + ] + ) + + +def _tf_get_var(tf_dir, varname): + return ( + subprocess.check_output( + _tf_command(tf_dir, "output", varname), shell=True + ) + .decode("utf8") + .strip() + .strip('"') + ) + + +def sagemaker_terraform( + user_name: str = "mlem", + role_name: str = "mlem", + region_name: str = "us-east-1", + profile: str = "default", + plan: bool = False, + work_dir: str = ".", + export_secret: str = None, +): + if not os.path.exists(work_dir): + os.makedirs(work_dir, exist_ok=True) + + 
shutil.copy( + os.path.join(os.path.dirname(__file__), MLEM_TF), + os.path.join(work_dir, MLEM_TF), + ) + subprocess.check_output(_tf_command(work_dir, "init"), shell=True) + + flags = ["-auto-approve"] if not plan else [] + + echo( + subprocess.check_output( + _tf_command( + work_dir, + "plan" if plan else "apply", + *flags, + role_name=role_name, + user_name=user_name, + region_name=region_name, + profile=profile, + ), + shell=True, + ) + ) + + if not plan and export_secret: + if os.path.exists(export_secret): + print( + f"Creds already present at {export_secret}, please backup and remove them" + ) + return + key_id = _tf_get_var(work_dir, "access_key_id") + access_secret = _tf_get_var(work_dir, "secret_access_key") + region = _tf_get_var(work_dir, "region_name") + profile = _tf_get_var(work_dir, "aws_user") + print(profile, region) + if export_secret.endswith(".csv"): + secrets = f"""User Name,Access key ID,Secret access key +{profile},{key_id},{access_secret}""" + print( + f"Import new profile:\naws configure import --csv file://{export_secret}\naws configure set region {region} --profile {profile}" + ) + else: + secrets = f"""export AWS_ACCESS_KEY_ID={key_id} +export AWS_SECRET_ACCESS_KEY={access_secret} +export AWS_REGION={region} +""" + print(f"Source envs:\nsource {export_secret}") + with open(export_secret, "w", encoding="utf8") as f: + f.write(secrets) diff --git a/mlem/contrib/sagemaker/meta.py b/mlem/contrib/sagemaker/meta.py new file mode 100644 index 00000000..3c798ead --- /dev/null +++ b/mlem/contrib/sagemaker/meta.py @@ -0,0 +1,458 @@ +import os +import posixpath +import tarfile +import tempfile +from typing import ClassVar, Optional, Tuple + +import boto3 +import sagemaker +from pydantic import validator +from sagemaker.deserializers import JSONDeserializer +from sagemaker.serializers import JSONSerializer + +from mlem.config import MlemConfigBase, project_config +from mlem.contrib.docker.base import DockerDaemon, DockerImage +from 
mlem.contrib.sagemaker.build import ( + AWSVars, + ECRegistry, + build_sagemaker_docker, +) +from mlem.core.errors import WrongMethodError +from mlem.core.model import Signature +from mlem.core.objects import ( + DeployState, + DeployStatus, + MlemDeployment, + MlemEnv, + MlemModel, +) +from mlem.runtime.client import Client +from mlem.runtime.interface import InterfaceDescriptor +from mlem.ui import EMOJI_BUILD, EMOJI_UPLOAD, echo + +MODEL_TAR_FILENAME = "model.tar.gz" +DEFAULT_ECR_REPOSITORY = "mlem" + + +class AWSConfig(MlemConfigBase): + ROLE: Optional[str] + PROFILE: Optional[str] + + class Config: + section = "aws" + env_prefix = "AWS_" + + +def generate_model_file_name(deploy_id): + return f"mlem-model-{deploy_id}" + + +def generate_image_name(deploy_id): + return f"mlem-sagemaker-image-{deploy_id}" + + +class SagemakerClient(Client): + type: ClassVar = "sagemaker" + + endpoint_name: str + aws_vars: AWSVars + signature: Signature + + def _interface_factory(self) -> InterfaceDescriptor: + return InterfaceDescriptor(methods={"predict": self.signature}) + + def get_predictor(self): + sess = self.aws_vars.get_sagemaker_session() + predictor = sagemaker.Predictor( + endpoint_name=self.endpoint_name, + sagemaker_session=sess, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + return predictor + + def _call_method(self, name, args): + return self.get_predictor().predict(args) + + +class SagemakerDeployState(DeployState): + type: ClassVar = "sagemaker" + image: Optional[DockerImage] = None + image_tag: Optional[str] = None + model_location: Optional[str] = None + endpoint_name: Optional[str] = None + endpoint_model_hash: Optional[str] = None + method_signature: Optional[Signature] = None + region: Optional[str] = None + previous: Optional["SagemakerDeployState"] = None + + @property + def image_uri(self): + if self.image is None: + if self.image_tag is None: + raise ValueError( + "Cannot get image_uri: image not built or not specified prebuilt 
image uri" + ) + return self.image_tag + return self.image.uri + + def get_predictor(self, session: sagemaker.Session): + predictor = sagemaker.Predictor( + endpoint_name=self.endpoint_name, + sagemaker_session=session, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ) + return predictor + + +class SagemakerDeployment(MlemDeployment): + type: ClassVar = "sagemaker" + state_type: ClassVar = SagemakerDeployState + + method: str = "predict" + """Model method to be deployed""" + image_tag: Optional[str] = None + """Name of the docker image to use""" + use_prebuilt: bool = False + """Use pre-built docker image. If True, image_name should be set""" + model_arch_location: Optional[str] = None + """Path on s3 to store model archive (excluding bucket)""" + model_name: Optional[str] + """Name for SageMaker Model""" + endpoint_name: Optional[str] = None + """Name for SageMaker Endpoint""" + initial_instance_count: int = 1 + """Initial instance count for Endpoint""" + instance_type: str = "ml.t2.medium" + """Instance type for Endpoint""" + accelerator_type: Optional[str] = None + "The size of the Elastic Inference (EI) instance to use" + + @validator("use_prebuilt") + def ensure_image_name( # pylint: disable=no-self-argument + cls, value, values # noqa: B902 + ): + if value and "image_name" not in values: + raise ValueError( + "image_name should be set if use_prebuilt is true" + ) + return value + + def _get_client(self, state: "SagemakerDeployState"): + return SagemakerClient( + endpoint_name=state.endpoint_name, + aws_vars=self.get_env().get_session_and_aws_vars( + region=state.region + )[1], + signature=state.method_signature, + ) + + +ENDPOINT_STATUS_MAPPING = { + "Creating": DeployStatus.STARTING, + "Failed": DeployStatus.CRASHED, + "InService": DeployStatus.RUNNING, + "OutOfService": DeployStatus.STOPPED, + "Updating": DeployStatus.STARTING, + "SystemUpdating": DeployStatus.STARTING, + "RollingBack": DeployStatus.STARTING, + "Deleting": 
DeployStatus.STOPPED, +} + + +class SagemakerEnv(MlemEnv): + type: ClassVar = "sagemaker" + deploy_type: ClassVar = SagemakerDeployment + + role: Optional[str] = None + account: Optional[str] = None + region: Optional[str] = None + bucket: Optional[str] = None + profile: Optional[str] = None + ecr_repository: Optional[str] = None + + @property + def role_name(self): + return f"arn:aws:iam::{self.account}:role/{self.role}" + + @staticmethod + def _create_and_upload_model_arch( + session: sagemaker.Session, + model: MlemModel, + bucket: str, + model_arch_location: str, + ) -> str: + with tempfile.TemporaryDirectory() as dirname: + model.clone(os.path.join(dirname, "model", "model")) + arch_path = os.path.join(dirname, "arch", MODEL_TAR_FILENAME) + os.makedirs(os.path.dirname(arch_path)) + with tarfile.open(arch_path, "w:gz") as tar: + path = os.path.join(dirname, "model") + for file in os.listdir(path): + tar.add(os.path.join(path, file), arcname=file) + + model_location = session.upload_data( + os.path.dirname(arch_path), + bucket=bucket, + key_prefix=posixpath.join( + model_arch_location, model.meta_hash() + ), + ) + + return model_location + + @staticmethod + def _delete_model_file(session: sagemaker.Session, model_path: str): + s3_client = session.boto_session.client("s3") + if model_path.startswith("s3://"): + model_path = model_path[len("s3://") :] + bucket, *paths = model_path.split("/") + model_path = posixpath.join(*paths, MODEL_TAR_FILENAME) + s3_client.delete_object(Bucket=bucket, Key=model_path) + + def deploy(self, meta: SagemakerDeployment): + with meta.lock_state(): + state: SagemakerDeployState = meta.get_state() + redeploy = meta.model_changed() + state.previous = state.previous or SagemakerDeployState() + + session, aws_vars = self.get_session_and_aws_vars(state.region) + if state.region is None: + state.region = aws_vars.region + meta.update_state(state) + + if not meta.use_prebuilt and (state.image_tag is None or redeploy): + 
self._build_image(meta, state, aws_vars) + + if state.model_location is None or redeploy: + self._upload_model(meta, state, aws_vars, session) + + if ( + state.endpoint_name is None + or redeploy + or state.endpoint_model_hash is not None + and state.endpoint_model_hash != state.model_hash + ): + if state.endpoint_name is None: + self._deploy_model(meta, state, aws_vars, session) + else: + self._update_model(meta, state, aws_vars, session) + + def _update_model( + self, + meta: SagemakerDeployment, + state: SagemakerDeployState, + aws_vars: AWSVars, + session: sagemaker.Session, + ): + assert state.model_location is not None # TODO + sm_model = sagemaker.Model( + image_uri=state.image_uri, + model_data=posixpath.join( + state.model_location, MODEL_TAR_FILENAME + ), + name=meta.model_name, + role=aws_vars.role, + sagemaker_session=session, + ) + sm_model.create( + instance_type=meta.instance_type, + accelerator_type=meta.accelerator_type, + ) + prev_endpoint_conf = session.sagemaker_client.describe_endpoint( + EndpointName=state.endpoint_name + )["EndpointConfigName"] + prev_model_name = session.sagemaker_client.describe_endpoint_config( + EndpointConfigName=prev_endpoint_conf + )["ProductionVariants"][0]["ModelName"] + + predictor = state.get_predictor(session) + predictor.update_endpoint( + model_name=sm_model.name, + initial_instance_count=meta.initial_instance_count, + instance_type=meta.instance_type, + accelerator_type=meta.accelerator_type, + wait=True, + ) + session.sagemaker_client.delete_model(ModelName=prev_model_name) + prev = state.previous + if prev is not None: + if prev.image is not None: + self._delete_image(meta, prev, aws_vars) + if prev.model_location is not None: + self._delete_model_file(session, prev.model_location) + prev.model_location = None + session.sagemaker_client.delete_endpoint_config( + EndpointConfigName=prev_endpoint_conf + ) + state.endpoint_model_hash = state.model_hash + meta.update_state(state) + + def _delete_image(self, meta, 
state, aws_vars): + with DockerDaemon(host="").client() as client: + if isinstance(state.image.registry, ECRegistry): + state.image.registry.with_aws_vars(aws_vars) + state.image.delete(client) + state.image = None + meta.update_state(state) + + def _deploy_model( + self, + meta: SagemakerDeployment, + state: SagemakerDeployState, + aws_vars: AWSVars, + session: sagemaker.Session, + ): + assert state.model_location is not None # TODO + sm_model = sagemaker.Model( + image_uri=state.image_uri, + model_data=posixpath.join( + state.model_location, MODEL_TAR_FILENAME + ), + name=meta.model_name, + role=aws_vars.role, + sagemaker_session=session, + ) + echo( + EMOJI_BUILD + + f"Starting up sagemaker {meta.initial_instance_count} `{meta.instance_type}` instance(s)..." + ) + sm_model.deploy( + initial_instance_count=meta.initial_instance_count, + instance_type=meta.instance_type, + accelerator_type=meta.accelerator_type, + endpoint_name=meta.endpoint_name, + wait=False, + ) + state.endpoint_name = sm_model.endpoint_name + state.endpoint_model_hash = state.model_hash + meta.update_state(state) + + def _upload_model( + self, + meta: SagemakerDeployment, + state: SagemakerDeployState, + aws_vars: AWSVars, + session: sagemaker.Session, + ): + assert state.previous is not None # TODO + echo( + EMOJI_UPLOAD + + f"Uploading model distribution to {aws_vars.bucket}..." 
+ ) + if state.model_location is not None: + state.previous.model_location = state.model_location + state.model_location = self._create_and_upload_model_arch( + session, + meta.get_model(), + aws_vars.bucket, + meta.model_arch_location + or generate_model_file_name(meta.get_model().meta_hash()), + ) + meta.update_model_hash(state=state) + meta.update_state(state) + + def _build_image( + self, + meta: SagemakerDeployment, + state: SagemakerDeployState, + aws_vars: AWSVars, + ): + assert state.previous is not None # TODO + model = meta.get_model() + try: + state.method_signature = model.model_type.methods[meta.method] + except KeyError as e: + raise WrongMethodError( + f"Wrong method {meta.method} for model {model.name}" + ) from e + image_tag = meta.image_tag or model.meta_hash() + if state.image_tag is not None: + state.previous.image_tag = state.image_tag + state.previous.image = state.image + state.image = build_sagemaker_docker( + model, + meta.method, + aws_vars.account, + aws_vars.region, + image_tag, + self.ecr_repository or DEFAULT_ECR_REPOSITORY, + aws_vars, + ) + state.image_tag = image_tag + meta.update_state(state) + + def remove(self, meta: SagemakerDeployment): + with meta.lock_state(): + state: SagemakerDeployState = meta.get_state() + session, aws_vars = self.get_session_and_aws_vars(state.region) + if state.model_location is not None: + self._delete_model_file(session, state.model_location) + if state.endpoint_name is not None: + + client = session.sagemaker_client + endpoint_conf = session.sagemaker_client.describe_endpoint( + EndpointName=state.endpoint_name + )["EndpointConfigName"] + + model_name = client.describe_endpoint_config( + EndpointConfigName=endpoint_conf + )["ProductionVariants"][0]["ModelName"] + client.delete_model(ModelName=model_name) + client.delete_endpoint(EndpointName=state.endpoint_name) + client.delete_endpoint_config(EndpointConfigName=endpoint_conf) + if state.image is not None: + self._delete_image(meta, state, aws_vars) 
+ meta.purge_state() + + def get_status( + self, meta: SagemakerDeployment, raise_on_error=True + ) -> "DeployStatus": + with meta.lock_state(): + state: SagemakerDeployState = meta.get_state() + session = self.get_session(state.region) + + endpoint = session.sagemaker_client.describe_endpoint( + EndpointName=state.endpoint_name + ) + status = endpoint["EndpointStatus"] + return ENDPOINT_STATUS_MAPPING.get(status, DeployStatus.UNKNOWN) + + def get_session(self, region: str = None) -> sagemaker.Session: + return self.get_session_and_aws_vars(region)[0] + + def get_session_and_aws_vars( + self, region: str = None + ) -> Tuple[sagemaker.Session, AWSVars]: + return init_aws_vars( + self.profile, + self.role, + self.bucket, + region or self.region, + self.account, + ) + + +def init_aws_vars( + profile=None, role=None, bucket=None, region=None, account=None +): + boto_session = boto3.Session(profile_name=profile, region_name=region) + sess = sagemaker.Session(boto_session, default_bucket=bucket) + + bucket = ( + bucket or sess.default_bucket() + ) # Replace with your own bucket name if needed + region = region or boto_session.region_name + config = project_config(project="", section=AWSConfig) + role = role or config.ROLE or sagemaker.get_execution_role(sess) + account = account or boto_session.client("sts").get_caller_identity().get( + "Account" + ) + return sess, AWSVars( + bucket=bucket, + region=region, + account=account, + role_name=role, + profile=profile or config.PROFILE, + ) diff --git a/mlem/contrib/sagemaker/mlem_sagemaker.tf b/mlem/contrib/sagemaker/mlem_sagemaker.tf new file mode 100644 index 00000000..ffbb5a5d --- /dev/null +++ b/mlem/contrib/sagemaker/mlem_sagemaker.tf @@ -0,0 +1,82 @@ +variable "profile" { + description = "AWS Profile to use for API calls" + type = string + default = "default" +} + +variable "role_name" { + description = "AWS role name" + type = string + default = "mlem" +} + +variable "user_name" { + description = "AWS user name" + type 
= string + default = "mlem" +} + +variable "region_name" { + description = "AWS region name" + type = string + default = "us-east-1" +} + +provider "aws" { + region = var.region_name + profile = var.profile +} + +resource "aws_iam_user" "aws_user" { + name = var.user_name +} + +resource "aws_iam_access_key" "aws_user" { + user = aws_iam_user.aws_user.name +} + +resource "aws_iam_user_policy_attachment" "sagemaker_policy" { + user = aws_iam_user.aws_user.name + policy_arn = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess" +} + +resource "aws_iam_user_policy_attachment" "ecr_policy" { + user = aws_iam_user.aws_user.name + policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess" +} + +resource "aws_iam_role" "aws_role" { + name = var.role_name + description = "MLEM SageMaker Role" + assume_role_policy = < /usr/local/bin/serve && chmod +x /usr/local/bin/serve +ENTRYPOINT ["bash", "-c"] diff --git a/mlem/contrib/sagemaker/runtime.py b/mlem/contrib/sagemaker/runtime.py new file mode 100644 index 00000000..4ff780b4 --- /dev/null +++ b/mlem/contrib/sagemaker/runtime.py @@ -0,0 +1,63 @@ +import logging +from types import ModuleType +from typing import ClassVar, Dict, List + +import boto3 +import fastapi +import sagemaker +import uvicorn + +from mlem.config import MlemConfigBase, project_config +from mlem.contrib.fastapi import FastAPIServer +from mlem.runtime import Interface + +logger = logging.getLogger(__name__) + + +class SageMakerServerConfig(MlemConfigBase): + HOST: str = "0.0.0.0" + PORT: int = 8080 + METHOD: str = "predict" + + class Config: + section = "sagemaker" + + +local_config = project_config("", section=SageMakerServerConfig) + + +def ping(): + return "OK" + + +class SageMakerServer(FastAPIServer): + type: ClassVar = "sagemaker" + libraries: ClassVar[List[ModuleType]] = [ + uvicorn, + fastapi, + sagemaker, + boto3, + ] + method: str = local_config.METHOD + port: int = local_config.PORT + host: str = local_config.HOST + + def 
app_init(self, interface: Interface): + app = super().app_init(interface) + + handler, response_model = self._create_handler( + "invocations", + interface.get_method_signature(self.method), + interface.get_method_executor(self.method), + ) + app.add_api_route( + "/invocations", + handler, + methods=["POST"], + response_model=response_model, + ) + app.add_api_route("/ping", ping, methods=["GET"]) + return app + + def get_env_vars(self) -> Dict[str, str]: + return {"SAGAMAKER_METHOD": self.method} diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 7b2a5aaf..9314f846 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -119,6 +119,21 @@ def __init__(self, meta, force_type): ) +class WrongMetaSubType(TypeError, MlemError): + def __init__(self, meta, force_type): + loc = f"from {meta.loc.uri} " if meta.is_saved else "" + super().__init__( + f"Wrong type of meta loaded, got {meta.object_type} {meta.type} {loc}instead of {force_type.object_type} {force_type.type}" + ) + + +class WrongABCType(TypeError, MlemError): + def __init__(self, instance, expected_abc_type): + super().__init__( + f"Wrong implementation type, got {instance.type} instead of {expected_abc_type.type}" + ) + + class DeploymentError(MlemError): """Thrown if something goes wrong during deployment process""" diff --git a/mlem/core/meta_io.py b/mlem/core/meta_io.py index 565d0a93..5ceaeb15 100644 --- a/mlem/core/meta_io.py +++ b/mlem/core/meta_io.py @@ -2,6 +2,7 @@ Utils functions that parse and process supplied URI, serialize/derialize MLEM objects """ import contextlib +import os import posixpath from abc import ABC, abstractmethod from inspect import isabstract @@ -67,6 +68,8 @@ def abs(cls, path: str, fs: AbstractFileSystem): def update_path(self, path): if not self.uri.endswith(self.path): raise ValueError("cannot automatically update uri") + if os.path.isabs(self.path) and not os.path.isabs(path): + path = posixpath.join(posixpath.dirname(self.path), path) self.uri = self.uri[: 
-len(self.path)] + path self.path = path diff --git a/mlem/core/objects.py b/mlem/core/objects.py index 98331b96..964a6065 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -3,6 +3,7 @@ """ import contextlib import hashlib +import itertools import os import posixpath import time @@ -49,6 +50,8 @@ MlemObjectNotFound, MlemObjectNotSavedError, MlemProjectNotFound, + WrongABCType, + WrongMetaSubType, WrongMetaType, ) from mlem.core.meta_io import ( @@ -804,10 +807,6 @@ class Config: model_hash: Optional[str] = None - @abstractmethod - def get_client(self): - raise NotImplementedError - DT = TypeVar("DT", bound="MlemDeployment") @@ -1007,8 +1006,10 @@ def lock(self, deployment: "MlemDeployment"): EnvLink: TypeAlias = MlemLink.typed_link(MlemEnv) ModelLink: TypeAlias = MlemLink.typed_link(MlemModel) +ET = TypeVar("ET", bound=MlemEnv) + -class MlemDeployment(MlemObject, Generic[ST]): +class MlemDeployment(MlemObject, Generic[ST, ET]): """Base class for deployment metadata""" object_type: ClassVar = "deployment" @@ -1022,7 +1023,7 @@ class Config: abs_name: ClassVar = "deployment" type: ClassVar[str] state_type: ClassVar[Type[ST]] - env_type: ClassVar[MlemEnv] + env_type: ClassVar[Type[ET]] env: Union[str, MlemEnv, EnvLink, None] = None env_cache: Optional[MlemEnv] = None @@ -1059,8 +1060,14 @@ def update_state(self, state: ST): def purge_state(self): self._state_manager.purge_state(self) - def get_client(self) -> "Client": - return self.get_state().get_client() + def get_client(self, state: DeployState = None) -> "Client": + if state is not None and not isinstance(state, self.state_type): + raise WrongABCType(state, self.state_type) + return self._get_client(state or self.get_state()) + + @abstractmethod + def _get_client(self, state: ST) -> "Client": + raise NotImplementedError @validator("env") def validate_env(cls, value): # pylint: disable=no-self-argument @@ -1069,15 +1076,19 @@ def validate_env(cls, value): # pylint: disable=no-self-argument return 
value.path if not isinstance(value, EnvLink): return EnvLink(**value.dict()) + if isinstance(value, str): + return make_posix(value) return value - def get_env(self): + def get_env(self) -> ET: if self.env_cache is None: if isinstance(self.env, str): link = MlemLink( path=self.env, - project=self.loc.project, - rev=self.loc.rev, + project=self.loc.project + if not os.path.isabs(self.env) + else None, + rev=self.loc.rev if not os.path.isabs(self.env) else None, link_type=MlemEnv.object_type, ) self.env_cache = link.load_link(force_type=MlemEnv) @@ -1092,6 +1103,12 @@ def get_env(self): raise MlemError( f"{self.env_type} env does not have default value, please set `env` field" ) from e + else: + raise ValueError( + "env should be one of [str, MlemLink, MlemEnv]" + ) + if not isinstance(self.env_cache, self.env_type): + raise WrongMetaSubType(self.env_cache, self.env_type) return self.env_cache @validator("model") @@ -1101,6 +1118,8 @@ def validate_model(cls, value): # pylint: disable=no-self-argument return value.path if not isinstance(value, ModelLink): return ModelLink(**value.dict()) + if isinstance(value, str): + return make_posix(value) return value def get_model(self) -> MlemModel: @@ -1108,12 +1127,20 @@ def get_model(self) -> MlemModel: if isinstance(self.model, str): link = MlemLink( path=self.model, - project=self.loc.project, - rev=self.loc.rev, + project=self.loc.project + if not os.path.isabs(self.model) + else None, + rev=self.loc.rev + if not os.path.isabs(self.model) + else None, link_type=MlemModel.object_type, ) + if self.is_saved: + link.bind(self.loc) self.model_cache = link.load_link(force_type=MlemModel) elif isinstance(self.model, MlemLink): + if self.is_saved: + self.model.bind(self.loc) self.model_cache = self.model.load_link(force_type=MlemModel) else: raise ValueError( @@ -1139,7 +1166,7 @@ def wait_for_status( DeployStatus, Iterable[DeployStatus] ] = None, raise_on_timeout: bool = True, - ): + ) -> object: if isinstance(status, 
DeployStatus): statuses = {status} else: @@ -1151,7 +1178,12 @@ def wait_for_status( allowed = set(allowed_intermediate) current = DeployStatus.UNKNOWN - for _ in range(times): + iterator: Iterable + if times == 0: + iterator = itertools.count() + else: + iterator = range(times) + for _ in iterator: current = self.get_status(raise_on_error=False) if current in statuses: return True @@ -1163,13 +1195,14 @@ def wait_for_status( return False time.sleep(timeout) if raise_on_timeout: + # TODO: count actual time passed raise DeploymentError( f"Deployment status is still {current} after {times * timeout} seconds" ) return False - def model_changed(self): - state = self.get_state() + def model_changed(self, state: Optional[ST] = None): + state = state or self.get_state() if state.model_hash is None: return True return self.get_model().meta_hash() != state.model_hash diff --git a/mlem/core/requirements.py b/mlem/core/requirements.py index 26e7d6ee..048c72da 100644 --- a/mlem/core/requirements.py +++ b/mlem/core/requirements.py @@ -2,6 +2,7 @@ Base classes to work with requirements which come with ML models and data """ import base64 +import collections import contextlib import glob import itertools @@ -490,7 +491,8 @@ def resolve_requirements(other: "AnyRequirements") -> Requirements: if isinstance(other[0], str): return Requirements( __root__=[ - InstallableRequirement.from_str(r) for r in set(other) + InstallableRequirement.from_str(r) + for r in collections.OrderedDict.fromkeys(other) ] ) diff --git a/mlem/ext.py b/mlem/ext.py index 1aecf256..31150828 100644 --- a/mlem/ext.py +++ b/mlem/ext.py @@ -108,6 +108,7 @@ class ExtensionLoader: Extension("mlem.contrib.github", [], True), Extension("mlem.contrib.gitlabfs", [], True), Extension("mlem.contrib.bitbucketfs", [], True), + Extension("mlem.contrib.sagemaker", ["sagemaker", "boto3"], False), ) _loaded_extensions: Dict[Extension, ModuleType] = {} diff --git a/mlem/ui.py b/mlem/ui.py index e66aa010..a42fb100 100644 --- 
a/mlem/ui.py +++ b/mlem/ui.py @@ -100,3 +100,4 @@ def bold(text): EMOJI_BUILD = emoji("🛠") EMOJI_UPLOAD = emoji("🔼") EMOJI_STOP = emoji("🔻") +EMOJI_KEY = emoji("🗝") diff --git a/mlem/utils/fslock.py b/mlem/utils/fslock.py index f2850290..17c1edd2 100644 --- a/mlem/utils/fslock.py +++ b/mlem/utils/fslock.py @@ -27,7 +27,7 @@ def __init__( salt=None, ): self.fs = fs - self.dirpath = dirpath + self.dirpath = make_posix(str(dirpath)) self.name = name self.timeout = timeout self.retry_timeout = retry_timeout diff --git a/setup.py b/setup.py index 48178827..1d73db3b 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ "xgboost": ["xgboost"], "lightgbm": ["lightgbm"], "fastapi": ["uvicorn", "fastapi"], - # "sagemaker": ["boto3==1.19.12", "sagemaker"], + "sagemaker": ["boto3", "sagemaker"], "torch": ["torch"], "tensorflow": ["tensorflow"], "azure": ["adlfs>=2021.10.0", "azure-identity>=1.4.0", "knack"], @@ -187,6 +187,12 @@ "builder.whl = mlem.contrib.pip.base:WhlBuilder", "client.rmq = mlem.contrib.rabbitmq:RabbitMQClient", "server.rmq = mlem.contrib.rabbitmq:RabbitMQServer", + "docker_registry.ecr = mlem.contrib.sagemaker.build:ECRegistry", + "client.sagemaker = mlem.contrib.sagemaker.meta:SagemakerClient", + "deploy_state.sagemaker = mlem.contrib.sagemaker.meta:SagemakerDeployState", + "deployment.sagemaker = mlem.contrib.sagemaker.meta:SagemakerDeployment", + "env.sagemaker = mlem.contrib.sagemaker.meta:SagemakerEnv", + "server.sagemaker = mlem.contrib.sagemaker.runtime:SageMakerServer", "model_type.sklearn = mlem.contrib.sklearn:SklearnModel", "model_type.sklearn_pipeline = mlem.contrib.sklearn:SklearnPipelineType", "model_type.tf_keras = mlem.contrib.tensorflow:TFKerasModel", diff --git a/tests/cli/test_deployment.py b/tests/cli/test_deployment.py index 3e3a9593..87c5e77d 100644 --- a/tests/cli/test_deployment.py +++ b/tests/cli/test_deployment.py @@ -20,20 +20,9 @@ from tests.cli.conftest import Runner -@pytest.fixture -def mock_deploy_get_client(mocker, 
request_get_mock, request_post_mock): - return mocker.patch( - "tests.cli.test_deployment.DeployStateMock.get_client", - return_value=HTTPClient(host="", port=None), - ) - - class DeployStateMock(DeployState): allow_default: ClassVar = True - def get_client(self) -> Client: - pass - class MlemDeploymentMock(MlemDeployment): class Config: @@ -45,6 +34,9 @@ class Config: status: DeployStatus = DeployStatus.NOT_DEPLOYED param: str = "" + def _get_client(self, state) -> Client: + return HTTPClient(host="", port=None) + class MlemEnvMock(MlemEnv): type: ClassVar = "mock" @@ -98,12 +90,12 @@ def test_deploy_meta_str_model(mlem_project, model_meta, mock_env_path): "env": make_posix(mock_env_path), } - assert ( - load_meta( - "deployment", project=mlem_project, force_type=MlemDeployment - ) - == deployment + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) def test_deploy_meta_link_str_model(mlem_project, model_meta, mock_env_path): @@ -124,21 +116,22 @@ def test_deploy_meta_link_str_model(mlem_project, model_meta, mock_env_path): "env": make_posix(mock_env_path), } - assert ( - load_meta( - "deployment", project=mlem_project, force_type=MlemDeployment - ) - == deployment + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) def test_deploy_meta_link_model(mlem_project, model_meta, mock_env_path): model_meta.dump("model", project=mlem_project) + load_meta(mock_env_path).clone("project_env", project=mlem_project) deployment = MlemDeploymentMock( model=MlemLink(path="model", project=mlem_project, link_type="model"), env=MlemLink( - path=mock_env_path, project=mlem_project, link_type="env" + path="project_env", 
project=mlem_project, link_type="env" ), ) deployment.dump("deployment", project=mlem_project) @@ -146,21 +139,64 @@ def test_deploy_meta_link_model(mlem_project, model_meta, mock_env_path): with deployment.loc.open("r") as f: data = safe_load(f) assert data == { - "model": {"path": "model", "project": mlem_project}, + "model": {"path": "model", "project": make_posix(mlem_project)}, "object_type": "deployment", "type": "mock", "env": { - "path": make_posix(mock_env_path), - "project": mlem_project, + "path": "project_env", + "project": make_posix(mlem_project), }, } - assert ( - load_meta( - "deployment", project=mlem_project, force_type=MlemDeployment - ) - == deployment + deployment2 = load_meta( + "deployment", project=mlem_project, force_type=MlemDeployment + ) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_deploy_meta_link_model_no_project(tmpdir, model_meta, mock_env_path): + model_path = os.path.join(tmpdir, "model") + model_meta.dump(model_path) + + deployment = MlemDeploymentMock( + model=MlemLink(path="model", link_type="model"), + env=MlemLink(path=mock_env_path, link_type="env"), + ) + deployment_path = os.path.join(tmpdir, "deployment") + deployment.dump(deployment_path) + + with deployment.loc.open("r") as f: + data = safe_load(f) + assert data == { + "model": "model", + "object_type": "deployment", + "type": "mock", + "env": make_posix(mock_env_path), + } + + deployment2 = load_meta(deployment_path, force_type=MlemDeployment) + assert deployment2 == deployment + assert deployment2.get_model() == model_meta + assert deployment2.get_env() == load_meta(mock_env_path) + + +def test_read_relative_model_from_remote_deploy_meta(): + """TODO + path = "s3://..." 
+ model.dump(path / "model"); + deployment = MlemDeploymentMock( + model=model, + env=MlemLink( + path=mock_env_path, link_type="env" + ), ) + deployment.dump(path / deployment) + + deployment2 = load_meta(...) + deployment2.get_model() + """ def test_deploy_create_new( @@ -205,8 +241,9 @@ def test_deploy_apply( runner: Runner, mock_deploy_path, data_path, - mock_deploy_get_client, tmp_path, + request_get_mock, + request_post_mock, ): path = os.path.join(tmp_path, "output") result = runner.invoke( diff --git a/tests/core/test_objects.py b/tests/core/test_objects.py index c0f01e7d..0e84e5ce 100644 --- a/tests/core/test_objects.py +++ b/tests/core/test_objects.py @@ -44,13 +44,15 @@ def get_status(self): def destroy(self): pass - def get_client(self): + +class MyMlemDeployment(MlemDeployment): + def _get_client(self, state): pass @pytest.fixture() def meta(): - return MlemDeployment( + return MyMlemDeployment( env="", model=MlemLink(path="", link_type="model"), ) diff --git a/tests/core/test_requirements.py b/tests/core/test_requirements.py index 15c5a094..e88c1b4a 100644 --- a/tests/core/test_requirements.py +++ b/tests/core/test_requirements.py @@ -143,6 +143,12 @@ def test_req_collection_main(tmpdir, postfix): } +def test_consistent_resolve_order(): + reqs = ["a", "b", "c"] + for _ in range(10): + assert resolve_requirements(reqs).modules == reqs + + # Copyright 2019 Zyfra # Copyright 2021 Iterative # diff --git a/tests/test_config.py b/tests/test_config.py index 4a706dca..4a6c4e83 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -32,4 +32,4 @@ def test_loading_remote(s3_tmp_path, s3_storage_fs): def test_default_server(): - assert project_config().server == FastAPIServer() + assert project_config("").server == FastAPIServer() diff --git a/tests/test_ext.py b/tests/test_ext.py index 79802781..38e41c8f 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -1,3 +1,6 @@ +import re +from pathlib import Path + from mlem import ExtensionLoader from 
mlem.utils.entrypoints import ( MLEM_ENTRY_POINT, @@ -21,6 +24,19 @@ def test_find_implementations(): assert not i.startswith("None") + +def _write_entrypoints(impls_sorted): + setup_path = Path(__file__).parent.parent / "setup.py" + with open(setup_path, encoding="utf8") as f: + setup_py = f.read() + impls_string = ",\n".join(f' "{i}"' for i in impls_sorted) + new_entrypoints = f'"mlem.contrib": [\n{impls_string},\n ]' + setup_py = re.subn( + r'"mlem\.contrib": \[\n[^]]*]', new_entrypoints, setup_py + )[0] + with open(setup_path, "w", encoding="utf8") as f: + f.write(setup_py) + + def test_all_impls_in_entrypoints(): # if this test fails, add new entrypoints (take the result of find_implementations()) to setup.py and # reinstall your dev copy of mlem to re-populate them @@ -30,7 +46,12 @@ def test_all_impls_in_entrypoints(): impls_sorted = sorted( impls, key=lambda x: tuple(x.split(" = ")[1].split(":")) ) - assert exts == set(impls), str(impls_sorted) + impls_set = set(impls) + if exts != impls_set: + _write_entrypoints(impls_sorted) + assert ( + exts == impls_set + ), "New entrypoints written to setup.py, please reinstall" def test_all_ext_has_pip_extra(): diff --git a/tests/utils/test_fslock.py b/tests/utils/test_fslock.py index af7738ec..3f93bec9 100644 --- a/tests/utils/test_fslock.py +++ b/tests/utils/test_fslock.py @@ -5,6 +5,7 @@ from fsspec.implementations.local import LocalFileSystem from mlem.utils.fslock import LOCK_EXT, FSLock +from mlem.utils.path import make_posix NAME = "testlock" @@ -17,8 +18,10 @@ def test_fslock(tmpdir): with lock: assert lock._timestamp is not None assert lock._salt is not None - lock_path = os.path.join( - tmpdir, f"{NAME}.{lock._timestamp}.{lock._salt}.{LOCK_EXT}" + lock_path = make_posix( + os.path.join( + tmpdir, f"{NAME}.{lock._timestamp}.{lock._salt}.{LOCK_EXT}" + ) ) assert lock.lock_path == lock_path assert fs.exists(lock_path) @@ -29,7 +32,7 @@ def test_fslock(tmpdir): def _work(dirname, num): - time.sleep(0.2 + num / 10)
+ time.sleep(0.3 + num / 5) with FSLock(LocalFileSystem(), dirname, NAME, salt=num): path = os.path.join(dirname, NAME) if os.path.exists(path):