Skip to content

Commit

Permalink
Fix deploy arguments passing; Fix style checks
Browse files Browse the repository at this point in the history
  • Loading branch information
nieznanysprawiciel committed Jan 16, 2024
1 parent db32400 commit 5bf31ce
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 84 deletions.
101 changes: 54 additions & 47 deletions examples/transfer-progress/progress.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,26 @@
#!/usr/bin/env python3

import asyncio
import os
import pathlib
import sys
import os
from dataclasses import dataclass
from datetime import datetime

from alive_progress import alive_bar
from dataclasses import dataclass

import yapapi.script.command
from yapapi import Golem
from yapapi.payload import vm
from yapapi.payload.vm import _VmPackage
from yapapi.props.base import constraint
from yapapi.script import ProgressArgs
from yapapi.services import Service

import asyncio
from datetime import datetime

import colorama # type: ignore

from yapapi import Golem

examples_dir = pathlib.Path(__file__).resolve().parent.parent
sys.path.append(str(examples_dir))

from utils import (
build_parser,
format_usage,
print_env_info,
run_golem_example,
)
from utils import build_parser, run_golem_example


def command_key(event: "yapapi.events.CommandProgress") -> str:
Expand All @@ -48,28 +40,38 @@ def progress_bar(self, event: "yapapi.events.CommandProgress"):
if event.message is not None:
print(f"{event.message}")

if event.progress is not None:
if event.progress is not None and event.progress[1] is not None:
progress = event.progress
key = command_key(event)

if progress[1] is not None:
key = command_key(event)
if self._transfers_ctx.get(key) is None:
bar = alive_bar(total=progress[1], manual=True, title="Progress", unit=event.unit, scale=True,
dual_line=True)
bar_ctx = bar.__enter__()
if self._transfers_ctx.get(key) is None:
self.create_progress_bar(event)

if isinstance(event.command, yapapi.script.command.Deploy):
bar_ctx.text = f"Deploying image"
elif isinstance(event.command, yapapi.script.command._SendContent):
bar_ctx.text = f"Uploading file: {event.command._src.download_url} -> {event.command._dst_path}"
elif isinstance(event.command, yapapi.script.command._ReceiveContent):
bar_ctx.text = f"Downloading file: {event.command._src_path} -> {event.command._dst_path}"
bar = self._transfers_ctx.get(key)
bar(progress[0] / progress[1])

self._transfers_bars[key] = bar
self._transfers_ctx[key] = bar_ctx
def create_progress_bar(self, event: "yapapi.events.CommandProgress"):
key = command_key(event)
bar = alive_bar(
total=event.progress[1],
manual=True,
title="Progress",
unit=event.unit,
scale=True,
dual_line=True,
)
bar_ctx = bar.__enter__()

bar = self._transfers_ctx.get(key)
bar(progress[0] / progress[1])
command = event.command
if isinstance(command, yapapi.script.command.Deploy):
bar_ctx.text = "Deploying image"
elif isinstance(command, yapapi.script.command._SendContent):
bar_ctx.text = f"Uploading file: {command._src.download_url} -> {command._dst_path}"
elif isinstance(command, yapapi.script.command._ReceiveContent):
bar_ctx.text = f"Downloading file: {command._src_path} -> {command._dst_path}"

self._transfers_bars[key] = bar
self._transfers_ctx[key] = bar_ctx

def executed(self, event: "yapapi.events.CommandExecuted"):
key = command_key(event)
Expand All @@ -86,7 +88,9 @@ def executed(self, event: "yapapi.events.CommandExecuted"):

@dataclass
class ExamplePayload(_VmPackage):
progress_capability: bool = constraint("golem.activity.caps.transfer.report-progress", operator="=", default=True)
progress_capability: bool = constraint(
"golem.activity.caps.transfer.report-progress", operator="=", default=True
)


class ExampleService(Service):
Expand All @@ -97,11 +101,13 @@ async def get_payload():
min_mem_gib=0.5,
min_storage_gib=10.0,
)
return ExamplePayload(image_url=package.image_url, constraints=package.constraints, progress_capability=True)
return ExamplePayload(
image_url=package.image_url, constraints=package.constraints, progress_capability=True
)

async def start(self):
script = self._ctx.new_script(timeout=None)
script.deploy(progress={})
script.deploy(progress_args=ProgressArgs(updateInterval="300ms"))
script.start()

yield script
Expand All @@ -111,13 +117,14 @@ async def run(self):

script = self._ctx.new_script(timeout=None)
script.download_from_url(
"https://huggingface.co/cointegrated/rubert-tiny2/resolve/main/model.safetensors?download=true",
"/golem/resource/model-small", progress_args=progress)
"https://huggingface.co/cointegrated/rubert-tiny2/resolve/main/model.safetensors",
"/golem/resource/model-small",
progress_args=progress,
)
script.upload_bytes(
os.urandom(40 * 1024 * 1024),
"/golem/resource/bytes.bin", progress_args=progress)
script.download_file(
"/golem/resource/bytes.bin", "download.bin", progress_args=progress)
os.urandom(40 * 1024 * 1024), "/golem/resource/bytes.bin", progress_args=progress
)
script.download_file("/golem/resource/bytes.bin", "download.bin", progress_args=progress)
yield script

os.remove("download.bin")
Expand All @@ -129,16 +136,16 @@ async def run(self):

async def main(subnet_tag, driver=None, network=None):
async with Golem(
budget=50.0,
subnet_tag=subnet_tag,
payment_driver=driver,
payment_network=network,
stream_output=True,
budget=50.0,
subnet_tag=subnet_tag,
payment_driver=driver,
payment_network=network,
stream_output=True,
) as golem:
global shutdown

bar = ProgressDisplayer()
cluster = await golem.run_service(
await golem.run_service(
ExampleService,
num_instances=1,
)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ autoflake = "^1"
flake8 = "^5"
flake8-docstrings = "^1.6"
Flake8-pyproject = "^1.2.2"
pyproject-autoflake = "^1.0.2"

[tool.poe.tasks]
checks = {sequence = ["checks_codestyle", "checks_typing", "checks_license"], help = "Run all available code checks"}
Expand Down
18 changes: 12 additions & 6 deletions yapapi/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
Deploy,
DownloadBytes,
DownloadFile,
DownloadFileFromInternet,
DownloadJson,
ProgressArgs,
Run,
SendBytes,
SendFile,
SendJson,
Start,
Terminate, DownloadFileFromInternet, ProgressArgs,
Terminate,
)
from yapapi.storage import DOWNLOAD_BYTES_LIMIT_DEFAULT

Expand Down Expand Up @@ -128,9 +130,11 @@ def add(self, cmd: Command) -> Awaitable[CommandExecuted]:
cmd._set_script(self, len(self._commands) - 1)
return cmd._result

def deploy(self, **kwargs: dict) -> Awaitable[CommandExecuted]:
def deploy(
self, progress_args: Optional[ProgressArgs] = None, **kwargs: dict
) -> Awaitable[CommandExecuted]:
"""Schedule a :class:`Deploy` command on the provider."""
return self.add(Deploy(**kwargs))
return self.add(Deploy(progress_args=progress_args, **kwargs))

def start(self, *args: str) -> Awaitable[CommandExecuted]:
"""Schedule a :class:`Start` command on the provider."""
Expand Down Expand Up @@ -169,7 +173,7 @@ def download_bytes(
src_path: str,
on_download: Callable[[bytes], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
**kwargs
**kwargs,
) -> Awaitable[CommandExecuted]:
"""Schedule downloading a remote file from the provider as bytes.
Expand All @@ -192,7 +196,7 @@ def download_json(
src_path: str,
on_download: Callable[[Any], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
**kwargs
**kwargs,
) -> Awaitable[CommandExecuted]:
"""Schedule downloading a remote file from the provider as JSON.
Expand Down Expand Up @@ -226,7 +230,9 @@ def upload_json(self, data: dict, dst_path: str, **kwargs) -> Awaitable[CommandE
"""
return self.add(SendJson(data, dst_path, **kwargs))

def download_from_url(self, src_url: str, dst_path: str, progress_args: Optional[ProgressArgs] = None) -> Awaitable[CommandExecuted]:
def download_from_url(
self, src_url: str, dst_path: str, progress_args: Optional[ProgressArgs] = None
) -> Awaitable[CommandExecuted]:
"""Schedule sending a file to the provider.
:param src_url: remote (internet) source url
Expand Down
69 changes: 39 additions & 30 deletions yapapi/script/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Type, Union

import attr

from yapapi.events import CommandEventType, DownloadFinished, DownloadStarted
Expand Down Expand Up @@ -37,7 +38,7 @@ def _make_batch_command(cmd_name: str, **kwargs) -> BatchCommand:
def __init__(self):
self._result: asyncio.Future = asyncio.get_event_loop().create_future()
self._script: Optional["Script"] = None
self._index: int = None
self._index: int = 0

def _set_script(self, script: "Script", index: int) -> None:
assert self._script is None, f"Command {self} already belongs to a script {self._script}"
Expand All @@ -60,7 +61,8 @@ def __repr__(self):

@attr.s(auto_attribs=True, repr=False)
class ProgressArgs:
"""Interval represented as human-readable duration string (examples: '5s' '10min')"""
"""Interval represented as human-readable duration string (examples: '5s' '10min')."""

updateInterval: Optional[str] = attr.field(default=None)
updateStep: Optional[int] = attr.field(default=None)

Expand All @@ -77,7 +79,11 @@ def __repr__(self):
return f"{super().__repr__()} {self.kwargs}"

def evaluate(self):
kwargs = dict(self.kwargs, progress=attr.asdict(self._progress)) if self._progress else self.kwargs
kwargs = (
dict(self.kwargs, progress=attr.asdict(self._progress))
if self._progress
else self.kwargs
)
return self._make_batch_command("deploy", **kwargs)


Expand Down Expand Up @@ -160,7 +166,9 @@ def __init__(self, data: dict, dst_path: str, progress_args: Optional[ProgressAr
:param data: dictionary representing JSON data to send
:param dst_path: remote (provider) destination path
"""
super().__init__(json.dumps(data).encode(encoding="utf-8"), dst_path, progress_args=progress_args)
super().__init__(
json.dumps(data).encode(encoding="utf-8"), dst_path, progress_args=progress_args
)


class SendFile(_SendContent):
Expand All @@ -186,12 +194,12 @@ class Run(Command):
"""Command which schedules running a shell command on a provider."""

def __init__(
self,
cmd: str,
*args: str,
env: Optional[Dict[str, str]] = None,
stderr: CaptureContext = CaptureContext.build(mode="stream"),
stdout: CaptureContext = CaptureContext.build(mode="stream"),
self,
cmd: str,
*args: str,
env: Optional[Dict[str, str]] = None,
stderr: CaptureContext = CaptureContext.build(mode="stream"),
stdout: CaptureContext = CaptureContext.build(mode="stream"),
):
"""Create a new Run command.
Expand Down Expand Up @@ -222,11 +230,7 @@ def __repr__(self):


class _ReceiveContent(Command, abc.ABC):
def __init__(
self,
src_path: str,
progress_args: Optional[ProgressArgs] = None
):
def __init__(self, src_path: str, progress_args: Optional[ProgressArgs] = None):
super().__init__()
self._src_path: str = src_path
self._dst_slot: Optional[Destination] = None
Expand Down Expand Up @@ -259,10 +263,10 @@ class DownloadFile(_ReceiveContent):
"""Command which schedules downloading a file from a provider."""

def __init__(
self,
src_path: str,
dst_path: str,
progress_args: Optional[ProgressArgs] = None,
self,
src_path: str,
dst_path: str,
progress_args: Optional[ProgressArgs] = None,
):
"""Create a new DownloadFile command.
Expand All @@ -288,11 +292,11 @@ class DownloadBytes(_ReceiveContent):
"""Command which schedules downloading a file from a provider as bytes."""

def __init__(
self,
src_path: str,
on_download: Callable[[bytes], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
progress_args: Optional[ProgressArgs] = None,
self,
src_path: str,
on_download: Callable[[bytes], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
progress_args: Optional[ProgressArgs] = None,
):
"""Create a new DownloadBytes command.
Expand All @@ -317,19 +321,24 @@ class DownloadJson(DownloadBytes):
"""Command which schedules downloading a file from a provider as JSON data."""

def __init__(
self,
src_path: str,
on_download: Callable[[Any], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
progress_args: Optional[ProgressArgs] = None,
self,
src_path: str,
on_download: Callable[[Any], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
progress_args: Optional[ProgressArgs] = None,
):
"""Create a new DownloadJson command.
:param src_path: remote (provider) source path
:param on_download: the callable to run on the received data
:param limit: limit of bytes to be downloaded (expected size)
"""
super().__init__(src_path, partial(self.__on_json_download, on_download), limit, progress_args=progress_args)
super().__init__(
src_path,
partial(self.__on_json_download, on_download),
limit,
progress_args=progress_args,
)

@staticmethod
async def __on_json_download(on_download: Callable[[bytes], Awaitable], content: bytes):
Expand Down
1 change: 0 additions & 1 deletion yapapi/services/service_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
self._state.stop()
except statemachine.exceptions.TransitionNotAllowed:
"""The ServiceRunner is not running,"""
pass

logger.debug("%s is shutting down... state: %s", self, self.state)

Expand Down

0 comments on commit 5bf31ce

Please sign in to comment.