diff --git a/examples/erigon/erigon.py b/examples/erigon/erigon.py new file mode 100644 index 000000000..ede420648 --- /dev/null +++ b/examples/erigon/erigon.py @@ -0,0 +1,108 @@ +import asyncio + +from dataclasses import dataclass + +from yapapi.props.base import prop, constraint +from yapapi.props import inf + +from yapapi.payload import Payload +from yapapi.executor import Golem +from yapapi.executor.services import Service + +from yapapi.log import enable_default_logger, log_summary, log_event_repr # noqa + + +TURBOGETH_RUNTIME_NAME = "turbogeth-managed" +PROP_ERIGON_ETHEREUM_NETWORK = "golem.srv.app.eth.network" + + +@dataclass +class ErigonPayload(Payload): + network: str = prop(PROP_ERIGON_ETHEREUM_NETWORK) + + runtime: str = constraint(inf.INF_RUNTIME_NAME, default=TURBOGETH_RUNTIME_NAME) + min_mem_gib: float = constraint(inf.INF_MEM, operator=">=", default=16) + min_storage_gib: float = constraint(inf.INF_STORAGE, operator=">=", default=1024) + + +class ErigonService(Service): + credentials = None + + def post_init(self): + self.credentials = {} + + def __repr__(self): + srv_repr = super().__repr__() + return f"{srv_repr}, credentials: {self.credentials}" + + @staticmethod + async def get_payload(): + return ErigonPayload(network="rinkeby") + + async def start(self): + deploy_idx = self.ctx.deploy() + self.ctx.start() + future_results = yield self.ctx.commit() + results = await future_results + self.credentials = "RECEIVED" or results[deploy_idx] # (NORMALLY, WOULD BE PARSED) + + async def run(self): + + while True: + print(f"service {self.ctx.id} running on {self.ctx.provider_name} ... 
") + signal = self._listen_nowait() + if signal and signal.message == "go": + self.ctx.run("go!") + yield self.ctx.commit() + else: + await asyncio.sleep(1) + yield + + async def shutdown(self): + self.ctx.download_file("some/service/state", "temp/path") + yield self.ctx.commit() + + +async def main(subnet_tag, driver=None, network=None): + + async with Golem( + budget=10.0, + subnet_tag=subnet_tag, + driver=driver, + network=network, + event_consumer=log_summary(log_event_repr), + ) as golem: + cluster = await golem.run_service( + ErigonService, + num_instances=1, + ) + + def instances(): + return [{s.ctx.id, s.state.value} for s in cluster.instances] + + def still_running(): + return any([s for s in cluster.instances if s.is_available]) + + cnt = 0 + while cnt < 10: + print(f"instances: {instances()}") + await asyncio.sleep(3) + cnt += 1 + if cnt == 3: + if len(cluster.instances) > 1: + cluster.instances[0].send_message_nowait("go") + + for s in cluster.instances: + cluster.stop_instance(s) + + print(f"instances: {instances()}") + + cnt = 0 + while cnt < 10 and still_running(): + print(f"instances: {instances()}") + await asyncio.sleep(1) + + print(f"instances: {instances()}") + + +asyncio.run(main(None)) diff --git a/examples/simple-service-poc/simple_service.py b/examples/simple-service-poc/simple_service.py new file mode 100644 index 000000000..2a5fa1fd9 --- /dev/null +++ b/examples/simple-service-poc/simple_service.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +""" +the requestor agent controlling and interacting with the "simple service" +""" +import asyncio +from datetime import datetime, timedelta, timezone +import pathlib +import random +import string +import sys + + +from yapapi import ( + NoPaymentAccountError, + __version__ as yapapi_version, + windows_event_loop_fix, +) +from yapapi.executor import Golem +from yapapi.executor.services import Service, ServiceState + +from yapapi.log import enable_default_logger, log_summary, log_event_repr, pluralize # 
noqa +from yapapi.payload import vm + +examples_dir = pathlib.Path(__file__).resolve().parent.parent +sys.path.append(str(examples_dir)) + +from utils import ( + build_parser, + TEXT_COLOR_CYAN, + TEXT_COLOR_DEFAULT, + TEXT_COLOR_RED, + TEXT_COLOR_YELLOW, +) + +NUM_INSTANCES = 1 +STARTING_TIMEOUT = timedelta(minutes=4) + + +class SimpleService(Service): + STATS_PATH = "/golem/out/stats" + PLOT_INFO_PATH = "/golem/out/plot" + SIMPLE_SERVICE = "/golem/run/simple_service.py" + SIMPLE_SERVICE_CTL = "/golem/run/simulate_observations_ctl.py" + + @staticmethod + async def get_payload(): + return await vm.repo( + image_hash="8b11df59f84358d47fc6776d0bb7290b0054c15ded2d6f54cf634488", + min_mem_gib=0.5, + min_storage_gib=2.0, + ) + + async def start(self): + # handler responsible for starting the service + self._ctx.run(self.SIMPLE_SERVICE_CTL, "--start") + yield self._ctx.commit() + + async def run(self): + # handler responsible for providing the required interactions while the service is running + while True: + await asyncio.sleep(10) + self._ctx.run(self.SIMPLE_SERVICE, "--stats") # idx 0 + self._ctx.run(self.SIMPLE_SERVICE, "--plot", "dist") # idx 1 + + future_results = yield self._ctx.commit() + results = await future_results + stats = results[0].stdout.strip() + plot = results[1].stdout.strip().strip('"') + + print(f"{TEXT_COLOR_CYAN}stats: {stats}{TEXT_COLOR_DEFAULT}") + + plot_filename = "".join(random.choice(string.ascii_letters) for _ in range(10)) + ".png" + print( + f"{TEXT_COLOR_CYAN}downloading plot: {plot} to {plot_filename}{TEXT_COLOR_DEFAULT}" + ) + self._ctx.download_file( + plot, str(pathlib.Path(__file__).resolve().parent / plot_filename) + ) + + steps = self._ctx.commit() + yield steps + + async def shutdown(self): + # handler reponsible for executing operations on shutdown + self._ctx.run(self.SIMPLE_SERVICE_CTL, "--stop") + yield self._ctx.commit() + + +async def main(subnet_tag, driver=None, network=None): + async with Golem( + budget=1.0, + 
subnet_tag=subnet_tag, + driver=driver, + network=network, + event_consumer=log_summary(log_event_repr), + ) as golem: + + print( + f"yapapi version: {TEXT_COLOR_YELLOW}{yapapi_version}{TEXT_COLOR_DEFAULT}\n" + f"Using subnet: {TEXT_COLOR_YELLOW}{subnet_tag}{TEXT_COLOR_DEFAULT}, " + f"payment driver: {TEXT_COLOR_YELLOW}{golem.driver}{TEXT_COLOR_DEFAULT}, " + f"and network: {TEXT_COLOR_YELLOW}{golem.network}{TEXT_COLOR_DEFAULT}\n" + ) + + commissioning_time = datetime.now() + + print( + f"{TEXT_COLOR_YELLOW}starting {pluralize(NUM_INSTANCES, 'instance')}{TEXT_COLOR_DEFAULT}" + ) + + # start the service + + cluster = await golem.run_service( + SimpleService, + num_instances=NUM_INSTANCES, + expiration=datetime.now(timezone.utc) + timedelta(minutes=120), + ) + + # helper functions to display / filter instances + + def instances(): + return [(s.provider_name, s.state.value) for s in cluster.instances] + + def still_running(): + return any([s for s in cluster.instances if s.is_available]) + + def still_starting(): + return len(cluster.instances) < NUM_INSTANCES or any( + [s for s in cluster.instances if s.state == ServiceState.starting] + ) + + # wait until instances are started + + while still_starting() and datetime.now() < commissioning_time + STARTING_TIMEOUT: + print(f"instances: {instances()}") + await asyncio.sleep(5) + + if still_starting(): + raise Exception(f"Failed to start instances before {STARTING_TIMEOUT} elapsed :( ...") + + print("All instances started :)") + + # allow the service to run for a short while + # (and allowing its requestor-end handlers to interact with it) + + start_time = datetime.now() + + while datetime.now() < start_time + timedelta(minutes=2): + print(f"instances: {instances()}") + await asyncio.sleep(5) + + print(f"{TEXT_COLOR_YELLOW}stopping instances{TEXT_COLOR_DEFAULT}") + cluster.stop() + + # wait for instances to stop + + cnt = 0 + while cnt < 10 and still_running(): + print(f"instances: {instances()}") + await asyncio.sleep(5) 
+ + print(f"instances: {instances()}") + + +if __name__ == "__main__": + parser = build_parser( + "A very simple / POC example of a service running on Golem, utilizing the VM runtime" + ) + now = datetime.now().strftime("%Y-%m-%d_%H.%M.%S") + parser.set_defaults(log_file=f"simple-service-yapapi-{now}.log") + args = parser.parse_args() + + # This is only required when running on Windows with Python prior to 3.8: + windows_event_loop_fix() + + enable_default_logger( + log_file=args.log_file, + debug_activity_api=True, + debug_market_api=True, + debug_payment_api=True, + ) + + loop = asyncio.get_event_loop() + task = loop.create_task( + main(subnet_tag=args.subnet_tag, driver=args.driver, network=args.network) + ) + + try: + loop.run_until_complete(task) + except NoPaymentAccountError as e: + handbook_url = ( + "https://handbook.golem.network/requestor-tutorials/" + "flash-tutorial-of-requestor-development" + ) + print( + f"{TEXT_COLOR_RED}" + f"No payment account initialized for driver `{e.required_driver}` " + f"and network `{e.required_network}`.\n\n" + f"See {handbook_url} on how to initialize payment accounts for a requestor node." + f"{TEXT_COLOR_DEFAULT}" + ) + except KeyboardInterrupt: + print( + f"{TEXT_COLOR_YELLOW}" + "Shutting down gracefully, please wait a short while " + "or press Ctrl+C to exit immediately..." 
+ f"{TEXT_COLOR_DEFAULT}" + ) + task.cancel() + try: + loop.run_until_complete(task) + print( + f"{TEXT_COLOR_YELLOW}Shutdown completed, thank you for waiting!{TEXT_COLOR_DEFAULT}" + ) + except (asyncio.CancelledError, KeyboardInterrupt): + pass diff --git a/examples/simple-service-poc/simple_service/README.md b/examples/simple-service-poc/simple_service/README.md new file mode 100644 index 000000000..692f5156a --- /dev/null +++ b/examples/simple-service-poc/simple_service/README.md @@ -0,0 +1,38 @@ +This directory contains files used to construct the application Docker image +that's then converted to a GVMI file (a Golem Virtual Machine Image file) and uploaded +to the Yagna image repository. + +All Python scripts here are run within a VM on the Provider's end. + +The example (`../simple_service.py`) already contains the appropriate image hash +but if you'd like to experiment with it, feel free to re-build it. + +## Building the image + +You'll need: + +* Docker: https://www.docker.com/products/docker-desktop +* gvmkit-build: `pip install gvmkit-build` + +Once you have those installed, run the following from this directory: + +```bash +docker build -f simple_service.Dockerfile -t simple-service . +gvmkit-build simple-service:latest +gvmkit-build simple-service:latest --push +``` + +Note the hash link that's presented after the upload finishes. + +e.g. 
`b742b6cb04123d07bacb36a2462f8b2347b20c32223c1ac49664635f` + +and update the service's `get_payload` method to point to this image: + +```python + async def get_payload(): + return await vm.repo( + image_hash="b742b6cb04123d07bacb36a2462f8b2347b20c32223c1ac49664635f", + min_mem_gib=0.5, + min_storage_gib=2.0, + ) +``` diff --git a/examples/simple-service-poc/simple_service/simple_service.Dockerfile b/examples/simple-service-poc/simple_service/simple_service.Dockerfile new file mode 100644 index 000000000..8626fd87e --- /dev/null +++ b/examples/simple-service-poc/simple_service/simple_service.Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.8-slim +VOLUME /golem/in /golem/out +COPY simple_service.py /golem/run/simple_service.py +COPY simulate_observations.py /golem/run/simulate_observations.py +COPY simulate_observations_ctl.py /golem/run/simulate_observations_ctl.py +RUN pip install numpy matplotlib +RUN chmod +x /golem/run/* +RUN /golem/run/simple_service.py --init +ENTRYPOINT ["sh"] diff --git a/examples/simple-service-poc/simple_service/simple_service.py b/examples/simple-service-poc/simple_service/simple_service.py new file mode 100644 index 000000000..cdde5cf98 --- /dev/null +++ b/examples/simple-service-poc/simple_service/simple_service.py @@ -0,0 +1,131 @@ +#!/usr/local/bin/python +""" +a very basic "stub" that exposes a few commands of an imagined, very simple CLI-based +service that is able to accumulate some linear, time-based values and present it stats +(characteristics of the statistical distribution of the data collected so far) or provide +distribution and time-series plots of the collected data. + +[ part of the VM image that's deployed by the runtime on the Provider's end. 
] +""" +import argparse +from datetime import datetime +import enum +import contextlib +import json +import matplotlib.pyplot as plt +import numpy +import random +import sqlite3 +import string +from pathlib import Path + +DB_PATH = Path(__file__).absolute().parent / "service.db" +PLOT_PATH = Path("/golem/out").absolute() + + +class PlotType(enum.Enum): + time = "time" + dist = "dist" + + +@contextlib.contextmanager +def _connect_db(): + db = sqlite3.connect(DB_PATH) + db.row_factory = sqlite3.Row + try: + yield db.cursor() + db.commit() + finally: + db.close() + + +def init(): + with _connect_db() as db: + db.execute( + "create table observations(" + "id integer primary key autoincrement not null, " + "val float not null," + "time_added timestamp default current_timestamp not null" + ")" + ) + + +def add(val): + with _connect_db() as db: + db.execute("insert into observations (val) values (?)", [val]) + + +def plot(plot_type): + data = _get_data() + + if not data: + print(json.dumps("")) + return + + y = [r["val"] for r in data] + + if plot_type == PlotType.dist.value: + plt.hist(y) + elif plot_type == PlotType.time.value: + x = [datetime.strptime(r["time_added"], "%Y-%m-%d %H:%M:%S") for r in data] + plt.plot(x, y) + + plot_filename = PLOT_PATH / ( + "".join(random.choice(string.ascii_letters) for _ in range(10)) + ".png" + ) + plt.savefig(plot_filename) + print(json.dumps(str(plot_filename))) + + +def dump(): + print(json.dumps(_get_data())) + + +def _get_data(): + with _connect_db() as db: + db.execute("select val, time_added from observations order by time_added asc") + return list(map(dict, db.fetchall())) + + +def _get_stats(data=None): + data = data or [r["val"] for r in _get_data()] + return { + "min": min(data) if data else None, + "max": max(data) if data else None, + "median": numpy.median(data) if data else None, + "mean": numpy.mean(data) if data else None, + "variance": numpy.var(data) if data else None, + "std dev": numpy.std(data) if data else None, 
+ "size": len(data), + } + + +def stats(): + print(json.dumps(_get_stats())) + + +def get_arg_parser(): + parser = argparse.ArgumentParser(description="simple service") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--add", type=float) + group.add_argument("--init", action="store_true") + group.add_argument("--plot", choices=[pt.value for pt in list(PlotType)]) + group.add_argument("--dump", action="store_true") + group.add_argument("--stats", action="store_true") + return parser + + +if __name__ == "__main__": + arg_parser = get_arg_parser() + args = arg_parser.parse_args() + + if args.init: + init() + elif args.add: + add(args.add) + elif args.plot: + plot(args.plot) + elif args.dump: + dump() + elif args.stats: + stats() diff --git a/examples/simple-service-poc/simple_service/simulate_observations.py b/examples/simple-service-poc/simple_service/simulate_observations.py new file mode 100644 index 000000000..430807dc1 --- /dev/null +++ b/examples/simple-service-poc/simple_service/simulate_observations.py @@ -0,0 +1,26 @@ +#!/usr/local/bin/python +""" +the "hello world" service here just adds randomized numbers with normal distribution + +in a real-world example, this could be e.g. a thermometer connected to the provider's +machine providing its inputs into the database or some other piece of information +from some external source that changes over time and which can be expressed as a +singular value + +[ part of the VM image that's deployed by the runtime on the Provider's end. 
] +""" +import os +from pathlib import Path +import random +import time + +MU = 14 +SIGMA = 3 + +SERVICE_PATH = Path(__file__).absolute().parent / "simple_service.py" + + +while True: + v = random.normalvariate(MU, SIGMA) + os.system(f"{SERVICE_PATH} --add {v}") + time.sleep(1) diff --git a/examples/simple-service-poc/simple_service/simulate_observations_ctl.py b/examples/simple-service-poc/simple_service/simulate_observations_ctl.py new file mode 100644 index 000000000..274063a54 --- /dev/null +++ b/examples/simple-service-poc/simple_service/simulate_observations_ctl.py @@ -0,0 +1,35 @@ +#!/usr/local/bin/python +""" +a helper, control script that starts and stops our example `simulate_observations` service + +[ part of the VM image that's deployed by the runtime on the Provider's end. ] +""" +import argparse +import os +import subprocess +import signal + +PIDFILE = "/var/run/simulate_observations.pid" +SCRIPT_FILE = "/golem/run/simulate_observations.py" + +parser = argparse.ArgumentParser("start/stop simulation") +group = parser.add_mutually_exclusive_group(required=True) +group.add_argument("--start", action="store_true") +group.add_argument("--stop", action="store_true") + +args = parser.parse_args() + +if args.start: + if os.path.exists(PIDFILE): + raise Exception(f"Cannot start process, {PIDFILE} exists.") + p = subprocess.Popen([SCRIPT_FILE]) + with open(PIDFILE, "w") as pidfile: + pidfile.write(str(p.pid)) +elif args.stop: + if not os.path.exists(PIDFILE): + raise Exception(f"Could not find pidfile: {PIDFILE}.") + with open(PIDFILE, "r") as pidfile: + pid = int(pidfile.read()) + + os.kill(pid, signal.SIGKILL) + os.remove(PIDFILE) diff --git a/examples/turbogeth/turbogeth.py b/examples/turbogeth/turbogeth.py deleted file mode 100644 index 29dfa8916..000000000 --- a/examples/turbogeth/turbogeth.py +++ /dev/null @@ -1,31 +0,0 @@ -import asyncio - -from dataclasses import dataclass - -from yapapi.props.builder import DemandBuilder -from yapapi.props.base import 
prop, constraint -from yapapi.props import inf - -from yapapi.payload import Payload - - -TURBOGETH_RUNTIME_NAME = "turbogeth-managed" -PROP_TURBOGETH_RPC_PORT = "golem.srv.app.eth.rpc-port" - - -@dataclass -class TurbogethPayload(Payload): - rpc_port: int = prop(PROP_TURBOGETH_RPC_PORT, None) - - runtime: str = constraint(inf.INF_RUNTIME_NAME, "=", TURBOGETH_RUNTIME_NAME) - min_mem_gib: float = constraint(inf.INF_MEM, ">=", 16) - min_storage_gib: float = constraint(inf.INF_STORAGE, ">=", 1024) - - -async def main(): - builder = DemandBuilder() - await builder.decorate(TurbogethPayload(rpc_port=1234)) - print(builder) - - -asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 43e5ed3fb..854ad6080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ colorama = "^0.4.4" # would not work: see https://github.com/python-poetry/poetry/issues/129. goth = { version = "^0.3", optional = true, python = "^3.8.0" } Deprecated = "^1.2.12" +python-statemachine = "^0.8.0" [tool.poetry.extras] cli = ['fire', 'rich'] diff --git a/tests/goth/test_run_blender.py b/tests/goth/test_run_blender.py index a5167ab93..24c005fca 100644 --- a/tests/goth/test_run_blender.py +++ b/tests/goth/test_run_blender.py @@ -51,7 +51,7 @@ async def assert_all_tasks_started(output_lines: EventStream[str]): async def assert_all_tasks_computed(output_lines: EventStream[str]): """Assert that for every task a line with `Task computed by provider` will appear.""" - await assert_all_tasks_processed("computed by provider", output_lines) + await assert_all_tasks_processed("finished by provider", output_lines) async def assert_all_invoices_accepted(output_lines: EventStream[str]): diff --git a/tests/goth/test_run_yacat.py b/tests/goth/test_run_yacat.py index d5f758758..d92c71265 100644 --- a/tests/goth/test_run_yacat.py +++ b/tests/goth/test_run_yacat.py @@ -51,7 +51,7 @@ async def assert_all_tasks_started(output_lines: EventStream[str]): async def 
assert_all_tasks_computed(output_lines: EventStream[str]): """Assert that for every task a line with `Task computed by provider` will appear.""" - await assert_all_tasks_processed("computed by provider", output_lines) + await assert_all_tasks_processed("finished by provider", output_lines) async def assert_all_invoices_accepted(output_lines: EventStream[str]): diff --git a/yapapi/executor/__init__.py b/yapapi/executor/__init__.py index 302d4d0c0..84c84249d 100644 --- a/yapapi/executor/__init__.py +++ b/yapapi/executor/__init__.py @@ -22,7 +22,9 @@ Optional, Set, Tuple, + Type, TypeVar, + TYPE_CHECKING, Union, cast, overload, @@ -48,6 +50,9 @@ from ..rest.market import OfferProposal, Subscription from ..storage import gftp from ._smartq import Consumer, Handle, SmartQueue + +if TYPE_CHECKING: + from .services import Cluster, Service from .strategy import ( DecreaseScoreForUnconfirmedAgreement, LeastExpensiveLinearPayuMS, @@ -259,57 +264,61 @@ def report_shutdown(*exc_info): self._storage_manager = await stack.enter_async_context(gftp.provider()) + stack.push_async_exit(self._shutdown) + return self except: await self.__aexit__(*sys.exc_info()) raise - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def _shutdown(self, *exc_info): + """Shutdown this Golem instance.""" + # Importing this at the beginning would cause circular dependencies from ..log import pluralize - try: - logger.debug("Golem is shutting down...") - # Wait until all computations are finished - await asyncio.gather(*[job.finished.wait() for job in self._jobs]) + logger.info("Golem is shutting down...") + # Wait until all computations are finished + await asyncio.gather(*[job.finished.wait() for job in self._jobs]) + logger.info("All jobs have finished") - self._payment_closing = True + self._payment_closing = True - for task in self._services: - if task is not self._process_invoices_job: - task.cancel() + for task in self._services: + if task is not self._process_invoices_job: + 
task.cancel() - if self._process_invoices_job and not any( - True for job in self._jobs if job.agreements_pool.confirmed > 0 - ): - logger.debug("No need to wait for invoices.") - self._process_invoices_job.cancel() + if self._process_invoices_job and not any( + True for job in self._jobs if job.agreements_pool.confirmed > 0 + ): + logger.debug("No need to wait for invoices.") + self._process_invoices_job.cancel() - try: - logger.info("Waiting for Golem services to finish...") - _, pending = await asyncio.wait( - self._services, timeout=10, return_when=asyncio.ALL_COMPLETED - ) - if pending: - logger.debug( - "%s still running: %s", pluralize(len(pending), "service"), pending - ) - except Exception: - logger.debug("Got error when waiting for services to finish", exc_info=True) + try: + logger.info("Waiting for Golem services to finish...") + _, pending = await asyncio.wait( + self._services, timeout=10, return_when=asyncio.ALL_COMPLETED + ) + if pending: + logger.debug("%s still running: %s", pluralize(len(pending), "service"), pending) + except Exception: + logger.debug("Got error when waiting for services to finish", exc_info=True) + + if self._agreements_to_pay and self._process_invoices_job: + logger.info( + "%s still unpaid, waiting for invoices...", + pluralize(len(self._agreements_to_pay), "agreement"), + ) + await asyncio.wait( + {self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED + ) + if self._agreements_to_pay: + logger.warning("Unpaid agreements: %s", self._agreements_to_pay) - if self._agreements_to_pay and self._process_invoices_job: - logger.info( - "%s still unpaid, waiting for invoices...", - pluralize(len(self._agreements_to_pay), "agreement"), - ) - await asyncio.wait( - {self._process_invoices_job}, timeout=30, return_when=asyncio.ALL_COMPLETED - ) - if self._agreements_to_pay: - logger.warning("Unpaid agreements: %s", self._agreements_to_pay) + await asyncio.gather(*[job.finished.wait() for job in self._jobs]) - finally: 
- await self._stack.aclose() + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stack.aclose() async def _create_allocations(self) -> rest.payment.MarketDecoration: @@ -566,6 +575,33 @@ async def execute_tasks( async for t in executor.submit(worker, data): yield t + async def run_service( + self, + service_class: Type["Service"], + num_instances: int = 1, + payload: Optional[Payload] = None, + expiration: Optional[datetime] = None, + ) -> "Cluster": + from .services import Cluster # avoid circular dependency + + payload = payload or await service_class.get_payload() + + if not payload: + raise ValueError( + f"No payload returned from {service_class.__name__}.get_payload() nor given in the `payload` argument." + ) + + cluster = Cluster( + engine=self, + service_class=service_class, + payload=payload, + num_instances=num_instances, + expiration=expiration, + ) + await self._stack.enter_async_context(cluster) + cluster.spawn_instances() + return cluster + class Job: """Functionality related to a single job.""" diff --git a/yapapi/executor/ctx.py b/yapapi/executor/ctx.py index 0e225ff81..c4fbc15a1 100644 --- a/yapapi/executor/ctx.py +++ b/yapapi/executor/ctx.py @@ -51,9 +51,13 @@ def timeout(self) -> Optional[timedelta]: return None -class _InitStep(Work): +class _Deploy(Work): def register(self, commands: CommandContainer): commands.deploy() + + +class _Start(Work): + def register(self, commands: CommandContainer): commands.start() @@ -277,6 +281,7 @@ def __init__( node_info: NodeInfo, storage: StorageProvider, emitter: Optional[Callable[[StorageEvent], None]] = None, + implicit_init: bool = True, ): self.id = ctx_id self._node_info = node_info @@ -284,6 +289,7 @@ def __init__( self._pending_steps: List[Work] = [] self._started: bool = False self._emitter: Optional[Callable[[StorageEvent], None]] = emitter + self._implicit_init = implicit_init @property def provider_name(self) -> Optional[str]: @@ -291,13 +297,22 @@ def provider_name(self) -> 
Optional[str]: return self._node_info.name def __prepare(self): - if not self._started: - self._pending_steps.append(_InitStep()) + if not self._started and self._implicit_init: + self.deploy() + self.start() self._started = True def begin(self): pass + def deploy(self): + self._implicit_init = False + self._pending_steps.append(_Deploy()) + + def start(self): + self._implicit_init = False + self._pending_steps.append(_Start()) + def send_json(self, json_path: str, data: dict): """Schedule sending JSON data to the provider. diff --git a/yapapi/executor/services.py b/yapapi/executor/services.py new file mode 100644 index 000000000..f25461583 --- /dev/null +++ b/yapapi/executor/services.py @@ -0,0 +1,435 @@ +import asyncio +import itertools +from dataclasses import dataclass, field +from datetime import timedelta, datetime, timezone +import enum +import logging +from typing import Any, AsyncContextManager, List, Optional, Set, Type +import statemachine # type: ignore +import sys + +if sys.version_info >= (3, 7): + from contextlib import AsyncExitStack +else: + from async_exit_stack import AsyncExitStack # type: ignore + +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + + +from .. import rest +from ..executor import Golem, Job, Task +from ..executor.ctx import WorkContext +from ..payload import Payload +from ..props import NodeInfo +from . import events + +logger = logging.getLogger(__name__) + +# current default for yagna providers as of yagna 0.6.x +DEFAULT_SERVICE_EXPIRATION: Final[timedelta] = timedelta(minutes=175) + +cluster_ids = itertools.count(1) + + +class ServiceState(statemachine.StateMachine): + """ + State machine describing the state and lifecycle of a Service instance. 
+ """ + + # states + starting = statemachine.State("starting", initial=True) + running = statemachine.State("running") + stopping = statemachine.State("stopping") + terminated = statemachine.State("terminated") + unresponsive = statemachine.State("unresponsive") + + # transitions + ready = starting.to(running) + stop = running.to(stopping) + terminate = terminated.from_(starting, running, stopping, terminated) + mark_unresponsive = unresponsive.from_(starting, running, stopping, terminated) + + lifecycle = ready | stop | terminate + + # just a helper set of states in which the service can be interacted-with + AVAILABLE = (starting, running, stopping) + + +@dataclass +class ServiceSignal: + """ + Simple container to carry information between the client code and the Service instance. + """ + + message: Any + response_to: Optional["ServiceSignal"] = None + + +class Service: + """ + Base Service class to be extended by application developers to define their own, + specialized Service specifications. 
+ """ + + def __init__(self, cluster: "Cluster", ctx: WorkContext): + self._cluster: "Cluster" = cluster + self._ctx: WorkContext = ctx + + self.__inqueue: asyncio.Queue[ServiceSignal] = asyncio.Queue() + self.__outqueue: asyncio.Queue[ServiceSignal] = asyncio.Queue() + self.post_init() + + def post_init(self): + pass + + @property + def id(self): + return self._ctx.id + + @property + def provider_name(self): + return self._ctx.provider_name + + def __repr__(self): + return f"<{self.__class__.__name__}: {self.id}>" + + async def send_message(self, message: Any = None): + await self.__inqueue.put(ServiceSignal(message=message)) + + def send_message_nowait(self, message: Optional[Any] = None): + self.__inqueue.put_nowait(ServiceSignal(message=message)) + + async def receive_message(self) -> ServiceSignal: + return await self.__outqueue.get() + + def receive_message_nowait(self) -> Optional[ServiceSignal]: + try: + return self.__outqueue.get_nowait() + except asyncio.QueueEmpty: + return None + + async def _listen(self) -> ServiceSignal: + return await self.__inqueue.get() + + def _listen_nowait(self) -> Optional[ServiceSignal]: + try: + return self.__inqueue.get_nowait() + except asyncio.QueueEmpty: + return None + + async def _respond(self, message: Optional[Any], response_to: Optional[ServiceSignal] = None): + await self.__outqueue.put(ServiceSignal(message=message, response_to=response_to)) + + def _respond_nowait(self, message: Optional[Any], response_to: Optional[ServiceSignal] = None): + self.__outqueue.put_nowait(ServiceSignal(message=message, response_to=response_to)) + + @staticmethod + async def get_payload() -> Optional[Payload]: + """Return the payload (runtime) definition for this service. + + If `get_payload` is not implemented, the payload will need to be provided in the + `Golem.run_service` call. 
+ """ + pass + + async def start(self): + self._ctx.deploy() + self._ctx.start() + yield self._ctx.commit() + + async def run(self): + while True: + await asyncio.sleep(10) + yield + + async def shutdown(self): + yield + + @property + def is_available(self): + return self._cluster.get_state(self) in ServiceState.AVAILABLE + + @property + def state(self): + return self._cluster.get_state(self) + + +class ControlSignal(enum.Enum): + """ + Control signal, used to request an instance's state change from the controlling Cluster. + """ + + stop = "stop" + + +@dataclass +class ServiceInstance: + """Cluster's service instance. + + A binding between the instance of the Service, its control queue and its state, + used by the Cluster to hold the complete state of each instance of a service. + """ + + service: Service + control_queue: "asyncio.Queue[ControlSignal]" = field(default_factory=asyncio.Queue) + service_state: ServiceState = field(default_factory=ServiceState) + + @property + def state(self) -> ServiceState: + return self.service_state.current_state + + def get_control_signal(self) -> Optional[ControlSignal]: + try: + return self.control_queue.get_nowait() + except asyncio.QueueEmpty: + return None + + def send_control_signal(self, signal: ControlSignal): + self.control_queue.put_nowait(signal) + + +class Cluster(AsyncContextManager): + """ + Golem's sub-engine used to spawn and control instances of a single Service. 
+ """ + + def __init__( + self, + engine: "Golem", + service_class: Type[Service], + payload: Payload, + num_instances: int = 1, + expiration: Optional[datetime] = None, + ): + self.id = str(next(cluster_ids)) + + self._engine = engine + self._service_class = service_class + self._payload = payload + self._num_instances = num_instances + self._expiration = expiration or datetime.now(timezone.utc) + DEFAULT_SERVICE_EXPIRATION + + self.__instances: List[ServiceInstance] = [] + """List of Service instances""" + + self._task_ids = itertools.count(1) + + self._stack = AsyncExitStack() + + def __repr__(self): + return f"Cluster {self.id}: {self._num_instances}x[Service: {self._service_class.__name__}, Payload: {self._payload}]" + + async def __aenter__(self): + self.__services: Set[asyncio.Task] = set() + """Asyncio tasks running within this cluster""" + + logger.debug("Starting new %s", self) + + self._job = Job(self._engine, expiration_time=self._expiration, payload=self._payload) + self._engine.add_job(self._job) + + loop = asyncio.get_event_loop() + self.__services.add(loop.create_task(self._job.find_offers())) + + async def agreements_pool_cycler(): + # shouldn't this be part of the Agreement pool itself? (or a task within Job?) 
            while True:
                await asyncio.sleep(2)
                await self._job.agreements_pool.cycle()

        self.__services.add(loop.create_task(agreements_pool_cycler()))

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Shut the Cluster down: terminate agreements, cancel internal tasks, finalize the Job."""
        logger.debug("%s is shutting down...", self)

        # TODO: should be different if we stop due to an error
        termination_reason = {
            "message": "Successfully finished all work",
            "golem.requestor.code": "Success",
        }

        try:
            logger.debug("Terminating agreements...")
            await self._job.agreements_pool.terminate_all(reason=termination_reason)
        except Exception:
            # Best-effort: termination failures are only logged so that the
            # rest of the shutdown sequence still runs.
            logger.debug("Couldn't terminate agreements", exc_info=True)

        for task in self.__services:
            if not task.done():
                logger.debug("Cancelling task: %s", task)
                task.cancel()
        # Wait for all internal tasks; swallow their CancelledErrors/exceptions.
        await asyncio.gather(*self.__services, return_exceptions=True)

        self._engine.finalize_job(self._job)

    def emit(self, event: events.Event) -> None:
        """Forward an event to the engine's event emitter."""
        self._engine.emit(event)

    @property
    def instances(self) -> List[Service]:
        """Return the list of Service objects (without their control/state wrappers)."""
        return [i.service for i in self.__instances]

    def __get_service_instance(self, service: Service) -> ServiceInstance:
        # Reverse lookup: find the ServiceInstance wrapper for a given Service.
        for i in self.__instances:
            if i.service == service:
                return i
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # an explicit error if this lookup can ever legitimately fail.
        assert False, f"No instance found for {service}"

    def get_state(self, service: Service) -> ServiceState:
        """Return the current lifecycle state of the given service's instance."""
        instance = self.__get_service_instance(service)
        return instance.state

    @staticmethod
    def _get_handler(instance: ServiceInstance):
        # Map the instance's current state to the Service handler driving it.
        # Calling the handler returns a fresh async-generator object; states
        # without an entry (e.g. terminal ones) yield None implicitly.
        _handlers = {
            ServiceState.starting: instance.service.start,
            ServiceState.running: instance.service.run,
            ServiceState.stopping: instance.service.shutdown,
        }
        handler = _handlers.get(instance.state, None)
        if handler:
            return handler()

    async def _run_instance(self, ctx: WorkContext):
        """Drive a single service instance through its lifecycle.

        This is itself an async generator: it yields the work batches produced
        by the instance's state handlers and receives back awaitables of their
        results (see the loop body in the continuation of this method).
        """
        loop = asyncio.get_event_loop()
        instance = ServiceInstance(service=self._service_class(self, ctx))
        self.__instances.append(instance)

        # NOTE(review): prefer lazy %-style logging args over f-strings here.
        logger.info(f"{instance.service} commissioned")

        handler = self._get_handler(instance)
        batch = None

        while handler:
            try:
                if batch:
                    # Yield the batch out to the engine (process_batches), get
                    # back an awaitable of its results, and feed the resolved
                    # results into the handler as the value of its `yield`.
                    r = yield batch
                    fr = loop.create_future()
                    fr.set_result(await r)
                    batch = await handler.asend(fr)
                else:
                    batch = await handler.__anext__()
            except StopAsyncIteration:
                # Handler exhausted: advance the state machine to the next
                # lifecycle state and pick up the matching handler (if any).
                instance.service_state.lifecycle()
                handler = self._get_handler(instance)
                batch = None
                logger.debug(f"{instance.service} state changed to {instance.state.value}")

            # TODO
            #
            # two potential issues:
            # * awaiting a batch makes us lose an ability to interpret a signal (await on generator won't return)
            # * we may be losing a `batch` when we act on the control signal
            #
            # potential solution:
            # * use `aiostream.stream.merge`

            ctl = instance.get_control_signal()
            if ctl == ControlSignal.stop:
                # Graceful stop from `running`; otherwise terminate immediately.
                if instance.state == ServiceState.running:
                    instance.service_state.stop()
                else:
                    instance.service_state.terminate()

                logger.debug(f"{instance.service} state changed to {instance.state.value}")

                handler = self._get_handler(instance)
                batch = None

        # NOTE(review): typo in the log message — should be "decommissioned".
        logger.info(f"{instance.service} decomissioned")

    async def spawn_instance(self):
        """Spawn a single instance: obtain an agreement, create an activity on it
        and run the service there, retrying with new agreements until one succeeds."""
        logger.debug("spawning instance within %s", self)
        spawned = False

        async def start_worker(agreement: rest.market.Agreement, node_info: NodeInfo) -> None:
            # Runs the service on one provider agreement; flips `spawned` only
            # after an activity has been successfully created.
            nonlocal spawned
            self.emit(events.WorkerStarted(agr_id=agreement.id))
            try:
                act = await self._engine.create_activity(agreement.id)
            except Exception:
                self.emit(
                    events.ActivityCreateFailed(
                        agr_id=agreement.id, exc_info=sys.exc_info()  # type: ignore
                    )
                )
                self.emit(events.WorkerFinished(agr_id=agreement.id))
                raise

            async with act:
                spawned = True
                self.emit(events.ActivityCreated(act_id=act.id, agr_id=agreement.id))
                self._engine.approve_agreement_payments(agreement.id)
                work_context = WorkContext(
                    act.id, node_info, self._engine.storage_manager, emitter=self.emit
                )
                task_id = f"{self.id}:{next(self._task_ids)}"
                self.emit(
                    events.TaskStarted(
                        agr_id=agreement.id,
                        task_id=task_id,
                        task_data=f"Service: 
{self._service_class.__name__}",
                    )
                )

                try:
                    # Drive the instance's lifecycle; the batches it yields are
                    # executed on the provider by the engine.
                    instance_batches = self._run_instance(work_context)
                    try:
                        await self._engine.process_batches(agreement.id, act, instance_batches)
                    except StopAsyncIteration:
                        pass

                    self.emit(
                        events.TaskFinished(
                            agr_id=agreement.id,
                            task_id=task_id,
                        )
                    )
                    self.emit(events.WorkerFinished(agr_id=agreement.id))
                except Exception:
                    self.emit(
                        events.WorkerFinished(
                            agr_id=agreement.id, exc_info=sys.exc_info()  # type: ignore
                        )
                    )
                    raise
                finally:
                    # Payment is accepted whether the run succeeded or failed.
                    await self._engine.accept_payment_for_agreement(agreement.id)

        loop = asyncio.get_event_loop()

        # Keep trying agreements until a worker actually created an activity.
        # NOTE(review): the sleep runs before the *first* attempt too, adding a
        # fixed 1 s delay to every spawn — confirm whether that's intended.
        while not spawned:
            await asyncio.sleep(1.0)
            task = await self._job.agreements_pool.use_agreement(
                lambda agreement, node: loop.create_task(start_worker(agreement, node))
            )
            if task:
                await task

    def stop_instance(self, service: Service):
        """Ask the given service instance to stop (asynchronously, via its control queue)."""
        instance = self.__get_service_instance(service)
        instance.send_control_signal(ControlSignal.stop)

    def spawn_instances(self, num_instances: Optional[int] = None) -> None:
        """
        Spawn new instances within this Cluster.

        :param num_instances: number of instances to commission.
            if not given, spawns the number that the Cluster has been initialized with.
+ """ + if num_instances: + self._num_instances += num_instances + else: + num_instances = self._num_instances + + loop = asyncio.get_event_loop() + for i in range(num_instances): + loop.create_task(self.spawn_instance()) + + def stop(self): + """Signal the whole cluster to stop.""" + for s in self.instances: + self.stop_instance(s) diff --git a/yapapi/log.py b/yapapi/log.py index adc8172f4..a12e41679 100644 --- a/yapapi/log.py +++ b/yapapi/log.py @@ -418,7 +418,7 @@ def _handle(self, event: events.Event): provider_info = self.agreement_provider_info[event.agr_id] data = self.task_data[event.task_id] self.logger.info( - "Task computed by provider '%s', task data: %s", + "Task finished by provider '%s', task data: %s", provider_info.name, str_capped(data, 200), ) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index d4abdbb97..65c55ec7f 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -225,8 +225,10 @@ async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: batch_id = self._batch_id last_idx = self._size - 1 + evt_src_endpoint = f"{host}/activity/{activity_id}/exec/{batch_id}" + async with sse_client.EventSource( - f"{host}/activity/{activity_id}/exec/{batch_id}", + evt_src_endpoint, headers=headers, timeout=self.seconds_left(), ) as event_source: