Skip to content

Commit

Permalink
Make WorkContext backward compatible with Script
Browse files Browse the repository at this point in the history
  • Loading branch information
kmazurek committed Aug 27, 2021
1 parent efaf2c3 commit 6f5fcfd
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 36 deletions.
68 changes: 39 additions & 29 deletions yapapi/ctx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import abc
from copy import copy
from dataclasses import dataclass, field
from datetime import timedelta, datetime
from deprecated import deprecated # type: ignore
import enum
import json
import logging
Expand Down Expand Up @@ -327,6 +329,7 @@ def __init__(
self._started: bool = False

self.__payment_model: Optional[ComLinear] = None
self.__script: Optional[Script] = None

@property
def id(self) -> str:
Expand All @@ -353,29 +356,32 @@ def _payment_model(self) -> ComLinear:
return self.__payment_model

def __prepare(self):
if not self._started and self._implicit_init:
self.deploy()
self.start()
self._started = True
if not self.__script:
self.__script = Script(self)

def new_script(self):
"""Stuff."""
return Script(self)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def deploy(self):
"""Schedule a Deploy command."""
self._implicit_init = False
self._pending_steps.append(_Deploy())
self.__prepare()
self.__script.deploy()

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def start(self, *args: str):
"""Schedule a Start command."""
self._implicit_init = False
self._pending_steps.append(_Start(*args))
self.__prepare()
self.__script.start(*args)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def terminate(self):
"""Schedule a Terminate command."""
self._pending_steps.append(_Terminate())
self.__prepare()
self.__script.terminate()

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def send_json(self, json_path: str, data: dict):
"""Schedule sending JSON data to the provider.
Expand All @@ -384,8 +390,9 @@ def send_json(self, json_path: str, data: dict):
:return: None
"""
self.__prepare()
self._pending_steps.append(_SendJson(self._storage, data, json_path))
self.__script.send_json(data, json_path)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def send_bytes(self, dst_path: str, data: bytes):
"""Schedule sending bytes data to the provider.
Expand All @@ -394,8 +401,9 @@ def send_bytes(self, dst_path: str, data: bytes):
:return: None
"""
self.__prepare()
self._pending_steps.append(_SendBytes(self._storage, data, dst_path))
self.__script.send_bytes(data, dst_path)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def send_file(self, src_path: str, dst_path: str):
"""Schedule sending file to the provider.
Expand All @@ -404,8 +412,9 @@ def send_file(self, src_path: str, dst_path: str):
:return: None
"""
self.__prepare()
self._pending_steps.append(_SendFile(self._storage, src_path, dst_path))
self.__script.send_file(src_path, dst_path)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def run(
self,
cmd: str,
Expand All @@ -419,12 +428,10 @@ def run(
:param env: optional dictionary with environmental variables
:return: None
"""
stdout = CaptureContext.build(mode="stream")
stderr = CaptureContext.build(mode="stream")

self.__prepare()
self._pending_steps.append(_Run(cmd, *args, env=env, stdout=stdout, stderr=stderr))
self.__script.run(cmd, *args, env=env)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def download_file(self, src_path: str, dst_path: str):
"""Schedule downloading remote file from the provider.
Expand All @@ -433,51 +440,54 @@ def download_file(self, src_path: str, dst_path: str):
:return: None
"""
self.__prepare()
self._pending_steps.append(_ReceiveFile(self._storage, src_path, dst_path, self._emitter))
self.__script.download_file(src_path, dst_path)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def download_bytes(
self,
src_path: str,
on_download: Callable[[bytes], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
):
"""Schedule downloading a remote file as bytes
:param src_path: remote (provider) path
:param on_download: the callable to run on the received data
:param limit: the maximum length of the expected byte string
:return None
"""
self.__prepare()
self._pending_steps.append(
_ReceiveBytes(self._storage, src_path, on_download, limit, self._emitter)
)
self.__script.download_bytes(src_path, on_download, limit)

@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def download_json(
self,
src_path: str,
on_download: Callable[[Any], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
):
"""Schedule downloading a remote file as JSON
:param src_path: remote (provider) path
:param on_download: the callable to run on the received JSON data
:param limit: the maximum length of the expected remote file
:return None
"""
self.__prepare()
self._pending_steps.append(
_ReceiveJson(self._storage, src_path, on_download, limit, self._emitter)
)
self.__script.download_json(src_path, on_download, limit)

def commit(self, timeout: Optional[timedelta] = None) -> Work:
@deprecated(version="0.7.0", reason="please use a Script object via WorkContext.new_script")
def commit(self, timeout: Optional[timedelta] = None) -> Script:
"""Creates a sequence of commands to be sent to provider.
:return: Work object containing the sequence of commands
scheduled within this work context before calling this method)
:return: Script object containing the sequence of commands
scheduled within this work context before calling this method
"""
steps = self._pending_steps
self._pending_steps = []
return Steps(*steps, timeout=timeout)
if timeout:
self.__script.timeout = timeout
script_to_commit = copy(self.__script)
self.__script = None
return script_to_commit

async def get_raw_usage(self) -> yaa_ActivityUsage:
"""Get the raw usage vector for the activity bound to this work context.
Expand Down
4 changes: 3 additions & 1 deletion yapapi/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from yapapi.rest.activity import CommandExecutionError, Activity
from yapapi.rest.market import Agreement, AgreementDetails, OfferProposal, Subscription
from yapapi.script import Script
from yapapi.script.command import BatchCommand
from yapapi.storage import gftp
from yapapi.strategy import (
DecreaseScoreForUnconfirmedAgreement,
Expand Down Expand Up @@ -592,7 +593,8 @@ async def process_batches(

try:
await script._before()
remote = await activity.send(script._evaluate(), deadline=batch_deadline)
batch: List[BatchCommand] = script._evaluate()
remote = await activity.send(batch, deadline=batch_deadline)
except Exception:
item = await command_generator.athrow(*sys.exc_info())
continue
Expand Down
14 changes: 8 additions & 6 deletions yapapi/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ async def _after(self):
async def _before(self):
"""Hook which is executed before the script is evaluated and sent to the provider."""
if not self._ctx._started and self._ctx._implicit_init:
# TODO: maybe check if first two steps already cover this?
loop = asyncio.get_event_loop()
self._commands.insert(0, (Deploy(), loop.create_future()))
self._commands.insert(1, (Start(), loop.create_future()))
Expand All @@ -48,7 +47,10 @@ async def _before(self):
await cmd.before(self._ctx)

def _set_cmd_result(self, result: CommandExecuted) -> None:
self._commands[result.cmd_idx][1].set_result(result)
cmd = self._commands[result.cmd_idx]
cmd[1].set_result(result)
if isinstance(cmd, Start):
self._ctx._started = True

def add(self, cmd: Command) -> Awaitable[CommandExecuted]:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -82,15 +84,15 @@ def send_bytes(self, data: bytes, dst_path: str) -> Awaitable[CommandExecuted]:
:param data: bytes to send
:param dst_path: remote (provider) destination path
"""
return self.add(SendBytes(data, dst_path)
return self.add(SendBytes(data, dst_path))

def send_file(self, src_path: str, dst_path: str) -> Awaitable[CommandExecuted]:
"""Schedule sending a file to the provider.
:param src_path: local (requestor) source path
:param dst_path: remote (provider) destination path
"""
return self.add(SendFile(src_path, dst_path)
return self.add(SendFile(src_path, dst_path))

def run(
self,
Expand Down Expand Up @@ -123,7 +125,7 @@ def download_bytes(
src_path: str,
on_download: Callable[[bytes], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
) -> Awaitable[CommandExecuted]:
) -> Awaitable[CommandExecuted]:
"""Schedule downloading a remote file from the provider as bytes.
:param src_path: remote (provider) source path
Expand All @@ -137,7 +139,7 @@ def download_json(
src_path: str,
on_download: Callable[[Any], Awaitable],
limit: int = DOWNLOAD_BYTES_LIMIT_DEFAULT,
) -> Awaitable[CommandExecuted]:
) -> Awaitable[CommandExecuted]:
"""Schedule downloading a remote file from the provider as JSON.
:param src_path: remote (provider) source path
Expand Down

0 comments on commit 6f5fcfd

Please sign in to comment.