diff --git a/golem-cluster.override.3-disable-stats.yaml b/golem-cluster.override.3-disable-stats.yaml index 127ea2a2..87dbbd58 100644 --- a/golem-cluster.override.3-disable-stats.yaml +++ b/golem-cluster.override.3-disable-stats.yaml @@ -1,2 +1,3 @@ provider: - enable_registry_stats: false + params: + enable_registry_stats: false diff --git a/pyproject.toml b/pyproject.toml index ce123313..ba5b6e10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ click = "^8" pydantic = "<2" [tool.poetry.scripts] -ray-on-golem = "ray_on_golem.server.run:main" +ray-on-golem = "ray_on_golem.main:main" [tool.poetry.group.dev.dependencies] poethepoet = "^0.22.0" diff --git a/ray_on_golem/__main__.py b/ray_on_golem/__main__.py new file mode 100644 index 00000000..a9fc2409 --- /dev/null +++ b/ray_on_golem/__main__.py @@ -0,0 +1,3 @@ +from ray_on_golem.main import main + +main() diff --git a/ray_on_golem/main.py b/ray_on_golem/main.py new file mode 100644 index 00000000..824e1ea6 --- /dev/null +++ b/ray_on_golem/main.py @@ -0,0 +1,24 @@ +import click + +from ray_on_golem.network_stats import main as network_stats +from ray_on_golem.server import main as webserver +from ray_on_golem.utils import prepare_tmp_dir + + +@click.group() +def cli(): + pass + + +cli.add_command(network_stats) +cli.add_command(webserver) + + +def main(): + prepare_tmp_dir() + + cli() + + +if __name__ == "__main__": + main() diff --git a/ray_on_golem/network_stats/__init__.py b/ray_on_golem/network_stats/__init__.py new file mode 100644 index 00000000..6d1f632c --- /dev/null +++ b/ray_on_golem/network_stats/__init__.py @@ -0,0 +1,3 @@ +from ray_on_golem.network_stats.main import main + +__all__ = ("main",) diff --git a/ray_on_golem/network_stats/main.py b/ray_on_golem/network_stats/main.py new file mode 100644 index 00000000..dd26fc62 --- /dev/null +++ b/ray_on_golem/network_stats/main.py @@ -0,0 +1,74 @@ +import asyncio +import logging +import logging.config +from contextlib import asynccontextmanager +from typing import Dict + +import click +import yaml + +from ray_on_golem.network_stats.services import NetworkStatsService +from ray_on_golem.provider.node_provider import GolemNodeProvider +from ray_on_golem.server.services import YagnaService +from ray_on_golem.server.settings import LOGGING_CONFIG, YAGNA_PATH +from ray_on_golem.utils import prepare_tmp_dir + + +@click.command( + name="network-stats", + short_help="Run Golem Network statistics.", + help="Run Golem Network statistics based on given cluster config file.", + context_settings={"show_default": True}, +) +@click.argument("cluster-config-file", type=click.Path(exists=True)) +@click.option( + "-d", + "--duration", + type=int, + default=5, + help="Set the duration of the statistics gathering process, in minutes.", +) +@click.option( + "--enable-logging", + is_flag=True, + default=False, + help="Enable verbose logging.", +) +def main(cluster_config_file: str, duration: int, enable_logging: bool): + if enable_logging: + logging.config.dictConfig(LOGGING_CONFIG) + + with open(cluster_config_file) as file: + config = yaml.safe_load(file.read()) + + GolemNodeProvider._apply_config_defaults(config) + + asyncio.run(_network_stats(config, duration)) + + +async def _network_stats(config: Dict, duration: int): + provider_params = config["provider"]["parameters"] + + async with network_stats_service(provider_params["enable_registry_stats"]) as stats_service: + await stats_service.run(provider_params, duration) + + +@asynccontextmanager +async def 
network_stats_service(registry_stats: bool) -> NetworkStatsService: + network_stats_service: NetworkStatsService = NetworkStatsService(registry_stats) + yagna_service = YagnaService( + yagna_path=YAGNA_PATH, + ) + + await yagna_service.init() + await network_stats_service.init(yagna_appkey=yagna_service.yagna_appkey) + + yield network_stats_service + + await network_stats_service.shutdown() + await yagna_service.shutdown() + + +if __name__ == "__main__": + prepare_tmp_dir() + main() diff --git a/ray_on_golem/network_stats/services/__init__.py b/ray_on_golem/network_stats/services/__init__.py new file mode 100644 index 00000000..a15706d3 --- /dev/null +++ b/ray_on_golem/network_stats/services/__init__.py @@ -0,0 +1 @@ +from ray_on_golem.network_stats.services.network_stats import NetworkStatsService diff --git a/ray_on_golem/network_stats/services/network_stats.py b/ray_on_golem/network_stats/services/network_stats.py new file mode 100644 index 00000000..f2697f4f --- /dev/null +++ b/ray_on_golem/network_stats/services/network_stats.py @@ -0,0 +1,231 @@ +import asyncio +import logging +import re +from collections import defaultdict +from datetime import timedelta +from typing import Dict, Optional, Sequence + +from golem.managers import ( + AddChosenPaymentPlatform, + BlacklistProviderIdPlugin, + Buffer, + DefaultProposalManager, + NegotiatingPlugin, + PayAllPaymentManager, + ProposalManagerPlugin, + RefreshingDemandManager, + ScoringBuffer, +) +from golem.managers.base import ProposalNegotiator +from golem.node import GolemNode +from golem.payload import PayloadSyntaxParser +from golem.resources import DemandData, Proposal +from golem.resources.proposal.exceptions import ProposalRejected +from ya_market import ApiException + +from ray_on_golem.server.models import NodeConfigData +from ray_on_golem.server.services.golem.helpers.demand_config import DemandConfigHelper +from ray_on_golem.server.services.golem.helpers.manager_stack import ManagerStackNodeConfigHelper +from ray_on_golem.server.services.golem.manager_stack import ManagerStack +from ray_on_golem.server.services.golem.provider_data import PROVIDERS_BLACKLIST + +logger = logging.getLogger(__name__) + + +class ProposalCounterPlugin(ProposalManagerPlugin): + def __init__(self) -> None: + self._count = 0 + + async def get_proposal(self) -> Proposal: + while True: + proposal: Proposal = await self._get_proposal() + self._count += 1 + return proposal + + +class StatsNegotiatingPlugin(NegotiatingPlugin): + def __init__( + self, + demand_offer_parser: Optional[PayloadSyntaxParser] = None, + proposal_negotiators: Optional[Sequence[ProposalNegotiator]] = None, + *args, + **kwargs, + ) -> None: + super().__init__(demand_offer_parser, proposal_negotiators, *args, **kwargs) + self.fails = defaultdict(int) + + async def get_proposal(self) -> Proposal: + while True: + proposal = await self._get_proposal() + + demand_data = await self._get_demand_data_from_proposal(proposal) + + try: + negotiated_proposal = await self._negotiate_proposal(demand_data, proposal) + return negotiated_proposal + except ProposalRejected as err: + self.fails[err.reason] += 1 + except Exception as err: + self.fails[str(err)] += 1 + + async def _send_demand_proposal( + self, offer_proposal: Proposal, demand_data: DemandData + ) -> Proposal: + try: + return await offer_proposal.respond( + demand_data.properties, + demand_data.constraints, + ) + except ApiException as e: + error_msg = re.sub(r"\[.*?\]", "[***]", str(e.body)) + raise RuntimeError(f"Failed to send proposal 
response! {e.status}: {error_msg}") from e + except asyncio.TimeoutError as e: + raise RuntimeError(f"Failed to send proposal response! Request timed out") from e + + +class StatsPluginFactory: + _stats_negotiating_plugin: StatsNegotiatingPlugin + + def __init__(self) -> None: + self._counters = {} + + def print_gathered_stats(self) -> None: + print("\nProposals count:") + [print(f"{tag}: {counter._count}") for tag, counter in self._counters.items()] + print("\nNegotiation errors:") + [ + print(f"{err}: {count}") + for err, count in sorted( + self._stats_negotiating_plugin.fails.items(), key=lambda item: item[1], reverse=True + ) + ] + + def create_negotiating_plugin(self) -> StatsNegotiatingPlugin: + self._stats_negotiating_plugin = StatsNegotiatingPlugin( + proposal_negotiators=(AddChosenPaymentPlatform(),), + ) + return self._stats_negotiating_plugin + + def create_counter_plugin(self, tag: str) -> ProposalCounterPlugin: + self._counters[tag] = ProposalCounterPlugin() + return self._counters[tag] + + +class NetworkStatsService: + def __init__(self, registry_stats: bool) -> None: + self._registry_stats = registry_stats + + self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats) + + self._golem: Optional[GolemNode] = None + self._yagna_appkey: Optional[str] = None + + self._stats_plugin_factory = StatsPluginFactory() + + async def init(self, yagna_appkey: str) -> None: + logger.info("Starting NetworkStatsService...") + + self._golem = GolemNode(app_key=yagna_appkey) + self._yagna_appkey = yagna_appkey + await self._golem.start() + + logger.info("Starting NetworkStatsService done") + + async def shutdown(self) -> None: + logger.info("Stopping NetworkStatsService...") + + await self._golem.aclose() + self._golem = None + + logger.info("Stopping NetworkStatsService done") + + async def run(self, provider_parameters: Dict, duration_minutes: int) -> None: + network: str = provider_parameters["network"] + budget: int = provider_parameters["budget"] + node_config: NodeConfigData = NodeConfigData(**provider_parameters["node_config"]) + + stack = await self._create_stack(node_config, budget, network) + await stack.start() + + print(f"Gathering stats data for {duration_minutes} minutes...") + consume_proposals_task = asyncio.create_task(self._consume_draft_proposals(stack)) + try: + await asyncio.wait( + [consume_proposals_task], + timeout=timedelta(minutes=duration_minutes).total_seconds(), + ) + finally: + consume_proposals_task.cancel() + await consume_proposals_task + + await stack.stop() + print("Gathering stats data done") + self._stats_plugin_factory.print_gathered_stats() + + async def _consume_draft_proposals(self, stack: ManagerStack) -> None: + drafts = [] + try: + while True: + draft = await stack.proposal_manager.get_draft_proposal() + drafts.append(draft) + except asyncio.CancelledError: + return + finally: + await asyncio.gather( + # FIXME better reason message + *[draft.reject(reason="No more needed") for draft in drafts], + return_exceptions=True, + ) + + async def _create_stack( + self, node_config: NodeConfigData, budget: float, network: str + ) -> ManagerStack: + stack = ManagerStack() + + payload = await self._demand_config_helper.get_payload_from_demand_config( + node_config.demand + ) + + ManagerStackNodeConfigHelper.apply_cost_management_avg_usage(stack, node_config) + ManagerStackNodeConfigHelper.apply_cost_management_hard_limits(stack, node_config) + + stack.payment_manager = PayAllPaymentManager(self._golem, budget=budget, network=network) + 
stack.demand_manager = RefreshingDemandManager( + self._golem, + stack.payment_manager.get_allocation, + payload, + demand_expiration_timeout=timedelta(hours=8), + ) + + plugins = [ + self._stats_plugin_factory.create_counter_plugin("Initial"), + BlacklistProviderIdPlugin(PROVIDERS_BLACKLIST.get(network, set())), + self._stats_plugin_factory.create_counter_plugin("Not blacklisted"), + ] + + for plugin_tag, plugin in stack.extra_proposal_plugins.items(): + plugins.append(plugin) + plugins.append(self._stats_plugin_factory.create_counter_plugin(f"Passed {plugin_tag}")) + + plugins.extend( + [ + ScoringBuffer( + min_size=50, + max_size=1000, + fill_at_start=True, + proposal_scorers=(*stack.extra_proposal_scorers.values(),), + update_interval=timedelta(seconds=10), + ), + self._stats_plugin_factory.create_counter_plugin("Scored"), + self._stats_plugin_factory.create_negotiating_plugin(), + self._stats_plugin_factory.create_counter_plugin("Negotiated"), + Buffer(min_size=1, max_size=50, fill_concurrency_size=10), + ] + ) + stack.proposal_manager = DefaultProposalManager( + self._golem, + stack.demand_manager.get_initial_proposal, + plugins=plugins, + ) + + return stack diff --git a/ray_on_golem/provider/node_provider.py b/ray_on_golem/provider/node_provider.py index 89f6a194..c9d0f6a8 100644 --- a/ray_on_golem/provider/node_provider.py +++ b/ray_on_golem/provider/node_provider.py @@ -14,7 +14,6 @@ from ray_on_golem.client.client import RayOnGolemClient from ray_on_golem.provider.ssh_command_runner import SSHCommandRunner from ray_on_golem.server.models import NodeConfigData, NodeId, ShutdownState -from ray_on_golem.server.run import prepare_tmp_dir from ray_on_golem.server.settings import ( LOGGING_DEBUG_PATH, RAY_ON_GOLEM_CHECK_DEADLINE, @@ -26,6 +25,7 @@ get_default_ssh_key_name, get_last_lines_from_file, is_running_on_golem_network, + prepare_tmp_dir, ) logger = logging.getLogger(__name__) @@ -57,38 +57,27 @@ def __init__(self, provider_config: dict, cluster_name: str): def bootstrap_config(cls, cluster_config: Dict[str, Any]) -> Dict[str, Any]: config = deepcopy(cluster_config) - provider_parameters: Dict = config["provider"]["parameters"] - provider_parameters.setdefault("webserver_port", 4578) - provider_parameters.setdefault("enable_registry_stats", True) - provider_parameters.setdefault("network", "goerli") - provider_parameters.setdefault("budget", 1) + cls._apply_config_defaults(config) + provider_parameters = config["provider"]["parameters"] ray_on_golem_client = cls._get_ray_on_golem_client_instance( provider_parameters["webserver_port"], provider_parameters["enable_registry_stats"], ) - auth: Dict = config["auth"] - auth.setdefault("ssh_user", "root") - - if "ssh_private_key" not in auth: - ssh_key_path = TMP_PATH / get_default_ssh_key_name(config["cluster_name"]) - auth["ssh_private_key"] = str(ssh_key_path) - - if not ssh_key_path.exists(): + auth = config["auth"] + default_ssh_private_key = TMP_PATH / get_default_ssh_key_name(config["cluster_name"]) + if auth["ssh_private_key"] == str(default_ssh_private_key): + if not default_ssh_private_key.exists(): ssh_key_base64 = ray_on_golem_client.get_or_create_default_ssh_key( config["cluster_name"] ) # FIXME: mitigate double file writing on local machine as get_or_create_default_ssh_key creates the file - ssh_key_path.parent.mkdir(parents=True, exist_ok=True) - with ssh_key_path.open("w") as f: + default_ssh_private_key.parent.mkdir(parents=True, exist_ok=True) + with default_ssh_private_key.open("w") as f: f.write(ssh_key_base64) - # 
copy ssh details to provider namespace for cluster creation in __init__ - provider_parameters["_ssh_private_key"] = auth["ssh_private_key"] - provider_parameters["_ssh_user"] = auth["ssh_user"] - global_event_system.execute_callback( CreateClusterEvent.ssh_keypair_downloaded, {"ssh_key_path": auth["ssh_private_key"]}, @@ -173,6 +162,26 @@ def external_ip(self, node_id: NodeId) -> str: def set_node_tags(self, node_id: NodeId, tags: Dict) -> None: self._ray_on_golem_client.set_node_tags(node_id, tags) + @staticmethod + def _apply_config_defaults(config: Dict[str, Any]) -> None: + provider_parameters: Dict = config["provider"]["parameters"] + provider_parameters.setdefault("webserver_port", 4578) + provider_parameters.setdefault("enable_registry_stats", True) + provider_parameters.setdefault("network", "goerli") + provider_parameters.setdefault("budget", 1) + + auth: Dict = config.setdefault("auth", {}) + auth.setdefault("ssh_user", "root") + + if "ssh_private_key" not in auth: + auth["ssh_private_key"] = str( + TMP_PATH / get_default_ssh_key_name(config["cluster_name"]) + ) + + # copy ssh details to provider namespace for cluster creation in __init__ + provider_parameters["_ssh_private_key"] = auth["ssh_private_key"] + provider_parameters["_ssh_user"] = auth["ssh_user"] + @staticmethod def _start_webserver( ray_on_golem_client: RayOnGolemClient, @@ -189,6 +198,7 @@ def _start_webserver( ) args = [ RAY_ON_GOLEM_PATH, + "webserver", "-p", str(port), "--registry-stats" if registry_stats else "--no-registry-stats", diff --git a/ray_on_golem/server/__init__.py b/ray_on_golem/server/__init__.py index e69de29b..ebdf1f69 100644 --- a/ray_on_golem/server/__init__.py +++ b/ray_on_golem/server/__init__.py @@ -0,0 +1,3 @@ +from ray_on_golem.server.main import main + +__all__ = ("main",) diff --git a/ray_on_golem/server/run.py b/ray_on_golem/server/main.py similarity index 66% rename from ray_on_golem/server/run.py rename to ray_on_golem/server/main.py index d1d34e19..4fca5d8c 100644 --- a/ray_on_golem/server/run.py +++ b/ray_on_golem/server/main.py @@ -1,7 +1,7 @@ -import argparse import logging import logging.config +import click from aiohttp import web from ray_on_golem.server.middlewares import error_middleware, trace_id_middleware @@ -14,44 +14,51 @@ YAGNA_PATH, ) from ray_on_golem.server.views import routes +from ray_on_golem.utils import prepare_tmp_dir logger = logging.getLogger(__name__) -def parse_sys_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Ray on Golem's webserver.") - parser.add_argument( - "-p", - "--port", - type=int, - default=4578, - help="port for Ray on Golem's webserver to listen on, default: %(default)s", - ) - parser.add_argument( - "--self-shutdown", - action="store_true", - help="flag to enable self-shutdown after last node termination, default: %(default)s", - ) - parser.add_argument("--no-self-shutdown", action="store_false", dest="self_shutdown") - parser.add_argument( - "--registry-stats", - action="store_true", - help="flag to enable collection of Golem Registry stats about resolved images, default: %(default)s", - ) - parser.add_argument( - "--no-registry-stats", - action="store_false", - dest="registry_stats", - ) - parser.set_defaults(self_shutdown=False, registry_stats=True) - return parser.parse_args() +@click.command( + name="webserver", + help="Run Ray on Golem's webserver.", + context_settings={"show_default": True}, +) +@click.option( + "-p", + "--port", + type=int, + default=4578, + help="Port for Ray on Golem's webserver to 
listen on.", +) +@click.option( + "--self-shutdown", + is_flag=True, + help="Enable self-shutdown after last node termination.", +) +@click.option( + "--registry-stats/--no-registry-stats", + default=True, + help="Enable collection of Golem Registry stats about resolved images.", +) +def main(port: int, self_shutdown: bool, registry_stats: bool): + logging.config.dictConfig(LOGGING_CONFIG) + app = create_application(port, self_shutdown, registry_stats) + + logger.info(f"Starting server... {port=}, {self_shutdown=}, {registry_stats=}") -def prepare_tmp_dir(): try: - TMP_PATH.mkdir(parents=True, exist_ok=True) - except OSError: - pass + web.run_app( + app, + port=app["port"], + print=None, + shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_DEADLINE.total_seconds(), + ) + except Exception: + logger.info("Server unexpectedly died, bye!") + else: + logger.info("Stopping server done, bye!") def create_application(port: int, self_shutdown: bool, registry_stats: bool) -> web.Application: @@ -96,6 +103,7 @@ async def startup_print(app: web.Application) -> None: async def shutdown_print(app: web.Application) -> None: + print("") # explicit new line to console to visually better handle ^C logger.info("Stopping server gracefully, forcing after `%s`...", RAY_ON_GOLEM_SHUTDOWN_DEADLINE) @@ -128,30 +136,6 @@ async def ray_service_ctx(app: web.Application) -> None: await ray_service.shutdown() -def main(): - prepare_tmp_dir() - args = parse_sys_args() - - logging.config.dictConfig(LOGGING_CONFIG) - - app = create_application(args.port, args.self_shutdown, args.registry_stats) - - logger.info( - "Starting server... {}".format(", ".join(f"{k}={v}" for k, v in args.__dict__.items())) - ) - - try: - web.run_app( - app, - port=app["port"], - print=None, - shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_DEADLINE.total_seconds(), - ) - except Exception: - logger.info("Server unexpectedly died, bye!") - else: - logger.info("Stopping server done, bye!") - - if __name__ == "__main__": + prepare_tmp_dir() main() diff --git a/ray_on_golem/server/services/golem/golem.py b/ray_on_golem/server/services/golem/golem.py index 59993c75..fbbda120 100644 --- a/ray_on_golem/server/services/golem/golem.py +++ b/ray_on_golem/server/services/golem/golem.py @@ -1,105 +1,44 @@ import asyncio -import base64 import hashlib -import json import logging -import platform from collections import defaultdict -from dataclasses import dataclass from datetime import timedelta from functools import partial from pathlib import Path -from typing import AsyncIterator, Dict, List, Optional, Tuple +from typing import AsyncIterator, Dict, Optional, Tuple -import aiohttp -import ray from golem.managers import ( - ActivityManager, AddChosenPaymentPlatform, - AgreementManager, BlacklistProviderIdPlugin, Buffer, DefaultAgreementManager, DefaultProposalManager, - DemandManager, - LinearAverageCostPricing, MapScore, NegotiatingPlugin, PayAllPaymentManager, - PaymentManager, - ProposalManager, - ProposalManagerPlugin, - ProposalScorer, RefreshingDemandManager, - RejectIfCostsExceeds, ScoringBuffer, WorkContext, ) -from golem.managers.proposal.plugins.linear_coeffs import LinearCoeffsCost -from golem.node import SUBNET, GolemNode -from golem.payload import ManifestVmPayload, Payload, constraint, prop +from golem.node import GolemNode from golem.resources import Activity, Network, ProposalData -from pydantic import BaseModel, Field from yarl import URL -from ray_on_golem.server.exceptions import RayOnGolemServerError, RegistryRequestError -from ray_on_golem.server.models 
import DemandConfigData, NodeConfigData -from ray_on_golem.server.services.golem.manifest import get_manifest +from ray_on_golem.server.models import NodeConfigData +from ray_on_golem.server.services.golem.helpers.demand_config import DemandConfigHelper +from ray_on_golem.server.services.golem.helpers.manager_stack import ManagerStackNodeConfigHelper +from ray_on_golem.server.services.golem.manager_stack import ManagerStack from ray_on_golem.server.services.golem.provider_data import PROVIDERS_BLACKLIST, PROVIDERS_SCORED from ray_on_golem.server.services.utils import get_ssh_command logger = logging.getLogger(__name__) -class ManagerStack(BaseModel): - payment_manager: Optional[PaymentManager] - demand_manager: Optional[DemandManager] - proposal_manager: Optional[ProposalManager] - agreement_manager: Optional[AgreementManager] - activity_manager: Optional[ActivityManager] - extra_proposal_plugins: List[ProposalManagerPlugin] = Field(default_factory=list) - extra_proposal_scorers: List[ProposalScorer] = Field(default_factory=list) - - class Config: - arbitrary_types_allowed = True - - @property - def _managers(self): - return [ - self.payment_manager, - self.demand_manager, - self.proposal_manager, - self.agreement_manager, - self.activity_manager, - ] - - async def start(self) -> None: - logger.info("Starting stack managers...") - - for manager in self._managers: - if manager is not None: - await manager.start() - - logger.info("Starting stack managers done") - - async def stop(self) -> None: - logger.info("Stopping stack managers...") - - for manager in reversed(self._managers): - if manager is not None: - try: - await manager.stop() - except Exception: - logger.exception(f"{manager} stop failed!") - - logger.info("Stopping stack managers done") - - class GolemService: def __init__(self, websocat_path: Path, registry_stats: bool): self._websocat_path = websocat_path - self._registry_stats = registry_stats + self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats) self._golem: Optional[GolemNode] = None self._network: Optional[Network] = None self._yagna_appkey: Optional[str] = None @@ -169,10 +108,12 @@ async def _create_stack( ) -> ManagerStack: stack = ManagerStack() - payload = await self._get_payload_from_demand_config(node_config.demand) + payload = await self._demand_config_helper.get_payload_from_demand_config( + node_config.demand + ) - self._apply_cost_management_avg_usage(stack, node_config) - self._apply_cost_management_hard_limits(stack, node_config) + ManagerStackNodeConfigHelper.apply_cost_management_avg_usage(stack, node_config) + ManagerStackNodeConfigHelper.apply_cost_management_hard_limits(stack, node_config) stack.payment_manager = PayAllPaymentManager(self._golem, budget=budget, network=network) stack.demand_manager = RefreshingDemandManager( @@ -186,13 +127,13 @@ async def _create_stack( stack.demand_manager.get_initial_proposal, plugins=( BlacklistProviderIdPlugin(PROVIDERS_BLACKLIST.get(network, set())), - *stack.extra_proposal_plugins, + *stack.extra_proposal_plugins.values(), ScoringBuffer( min_size=50, max_size=1000, fill_at_start=True, proposal_scorers=( - *stack.extra_proposal_scorers, + *stack.extra_proposal_scorers.values(), MapScore(partial(self._score_with_provider_data, network=network)), ), update_interval=timedelta(seconds=10), @@ -209,129 +150,6 @@ async def _create_stack( return stack - async def _get_payload_from_demand_config(self, demand_config: DemandConfigData) -> Payload: - @dataclass - class 
CustomManifestVmPayload(ManifestVmPayload): - subnet_constraint: str = constraint("golem.node.debug.subnet", "=", default=SUBNET) - debit_notes_accept_timeout: int = prop( - "golem.com.payment.debit-notes.accept-timeout?", default=240 - ) - - image_url, image_hash = await self._get_image_url_and_hash(demand_config) - - manifest = get_manifest(image_url, image_hash) - manifest = base64.b64encode(json.dumps(manifest).encode("utf-8")).decode("utf-8") - - params = demand_config.dict(exclude={"image_hash", "image_tag"}) - params["manifest"] = manifest - - payload = CustomManifestVmPayload(**params) - - return payload - - async def _get_image_url_and_hash(self, demand_config: DemandConfigData) -> Tuple[URL, str]: - image_tag = demand_config.image_tag - image_hash = demand_config.image_hash - - if image_tag is not None and image_hash is not None: - raise RayOnGolemServerError( - "Only one of `image_tag` and `image_hash` parameter should be defined!" - ) - - if image_hash is not None: - image_url = await self._get_image_url_from_hash(image_hash) - return image_url, image_hash - - if image_tag is None: - python_version = platform.python_version() - ray_version = ray.__version__ - image_tag = f"golem/ray-on-golem:py{python_version}-ray{ray_version}" - - return await self._get_image_url_and_hash_from_tag(image_tag) - - async def _get_image_url_from_hash(self, image_hash: str) -> URL: - async with aiohttp.ClientSession() as session: - async with session.get( - f"https://registry.golem.network/v1/image/info", - params={"hash": image_hash, "count": str(self._registry_stats).lower()}, - ) as response: - response_data = await response.json() - - if response.status == 200: - return URL(response_data["http"]) - elif response.status == 404: - raise RegistryRequestError(f"Image hash `{image_hash}` does not exist") - else: - raise RegistryRequestError("Can't access Golem Registry for image lookup!") - - async def _get_image_url_and_hash_from_tag(self, image_tag: str) -> Tuple[URL, str]: - async with aiohttp.ClientSession() as session: - async with session.get( - f"https://registry.golem.network/v1/image/info", - params={"tag": image_tag, "count": str(self._registry_stats).lower()}, - ) as response: - response_data = await response.json() - - if response.status == 200: - return response_data["http"], response_data["sha3"] - elif response.status == 404: - raise RegistryRequestError(f"Image tag `{image_tag}` does not exist") - else: - raise RegistryRequestError("Can't access Golem Registry for image lookup!") - - def _apply_cost_management_avg_usage( - self, stack: ManagerStack, node_config: NodeConfigData - ) -> None: - cost_management = node_config.cost_management - - if cost_management is None or not cost_management.is_average_usage_cost_enabled(): - logger.debug("Cost management based on average usage is not enabled") - return - - linear_average_cost = LinearAverageCostPricing( - average_cpu_load=node_config.cost_management.average_cpu_load, - average_duration=timedelta( - minutes=node_config.cost_management.average_duration_minutes - ), - ) - - stack.extra_proposal_scorers.append( - MapScore(linear_average_cost, normalize=True, normalize_flip=True), - ) - - max_average_usage_cost = node_config.cost_management.max_average_usage_cost - if max_average_usage_cost is not None: - stack.extra_proposal_plugins.append( - RejectIfCostsExceeds(max_average_usage_cost, linear_average_cost), - ) - logger.debug("Cost management based on average usage applied with max limits") - else: - logger.debug("Cost management based 
on average usage applied without max limits") - - def _apply_cost_management_hard_limits( - self, stack: ManagerStack, node_config: NodeConfigData - ) -> None: - # TODO: Consider creating RejectIfCostsExceeds variant for multiple values - proposal_plugins = [] - field_names = { - "max_initial_price": "price_initial", - "max_cpu_sec_price": "price_cpu_sec", - "max_duration_sec_price": "price_duration_sec", - } - - for cost_field_name, coef_field_name in field_names.items(): - cost_max_value = getattr(node_config.cost_management, cost_field_name, None) - if cost_max_value is not None: - proposal_plugins.append( - RejectIfCostsExceeds(cost_max_value, LinearCoeffsCost(coef_field_name)), - ) - - if proposal_plugins: - stack.extra_proposal_plugins.extend(proposal_plugins) - logger.debug("Cost management based on max limits applied") - else: - logger.debug("Cost management based on max limits is not enabled") - def _score_with_provider_data( self, proposal_data: ProposalData, network: str ) -> Optional[float]: diff --git a/ray_on_golem/server/services/golem/helpers/demand_config.py b/ray_on_golem/server/services/golem/helpers/demand_config.py new file mode 100644 index 00000000..b4e1e048 --- /dev/null +++ b/ray_on_golem/server/services/golem/helpers/demand_config.py @@ -0,0 +1,93 @@ +import base64 +import json +import logging +import platform +from dataclasses import dataclass +from typing import Tuple + +import aiohttp +import ray +from golem.node import SUBNET +from golem.payload import ManifestVmPayload, Payload, constraint, prop +from yarl import URL + +from ray_on_golem.server.exceptions import RayOnGolemServerError, RegistryRequestError +from ray_on_golem.server.models import DemandConfigData +from ray_on_golem.server.services.golem.manifest import get_manifest + +logger = logging.getLogger(__name__) + + +class DemandConfigHelper: + def __init__(self, registry_stats: bool): + self._registry_stats = registry_stats + + async def get_payload_from_demand_config(self, demand_config: DemandConfigData) -> Payload: + @dataclass + class CustomManifestVmPayload(ManifestVmPayload): + subnet_constraint: str = constraint("golem.node.debug.subnet", "=", default=SUBNET) + debit_notes_accept_timeout: int = prop( + "golem.com.payment.debit-notes.accept-timeout?", default=240 + ) + + image_url, image_hash = await self._get_image_url_and_hash(demand_config) + + manifest = get_manifest(image_url, image_hash) + manifest = base64.b64encode(json.dumps(manifest).encode("utf-8")).decode("utf-8") + + params = demand_config.dict(exclude={"image_hash", "image_tag"}) + params["manifest"] = manifest + + payload = CustomManifestVmPayload(**params) + + return payload + + async def _get_image_url_and_hash(self, demand_config: DemandConfigData) -> Tuple[URL, str]: + image_tag = demand_config.image_tag + image_hash = demand_config.image_hash + + if image_tag is not None and image_hash is not None: + raise RayOnGolemServerError( + "Only one of `image_tag` and `image_hash` parameter should be defined!" 
+ ) + + if image_hash is not None: + image_url = await self._get_image_url_from_hash(image_hash) + return image_url, image_hash + + if image_tag is None: + python_version = platform.python_version() + ray_version = ray.__version__ + image_tag = f"golem/ray-on-golem:py{python_version}-ray{ray_version}" + + return await self._get_image_url_and_hash_from_tag(image_tag) + + async def _get_image_url_from_hash(self, image_hash: str) -> URL: + async with aiohttp.ClientSession() as session: + async with session.get( + f"https://registry.golem.network/v1/image/info", + params={"hash": image_hash, "count": str(self._registry_stats).lower()}, + ) as response: + response_data = await response.json() + + if response.status == 200: + return URL(response_data["http"]) + elif response.status == 404: + raise RegistryRequestError(f"Image hash `{image_hash}` does not exist") + else: + raise RegistryRequestError("Can't access Golem Registry for image lookup!") + + async def _get_image_url_and_hash_from_tag(self, image_tag: str) -> Tuple[URL, str]: + async with aiohttp.ClientSession() as session: + async with session.get( + f"https://registry.golem.network/v1/image/info", + params={"tag": image_tag, "count": str(self._registry_stats).lower()}, + ) as response: + response_data = await response.json() + + if response.status == 200: + return response_data["http"], response_data["sha3"] + elif response.status == 404: + raise RegistryRequestError(f"Image tag `{image_tag}` does not exist") + else: + raise RegistryRequestError("Can't access Golem Registry for image lookup!") diff --git a/ray_on_golem/server/services/golem/helpers/manager_stack.py b/ray_on_golem/server/services/golem/helpers/manager_stack.py new file mode 100644 index 00000000..48c52c13 --- /dev/null +++ b/ray_on_golem/server/services/golem/helpers/manager_stack.py @@ -0,0 +1,63 @@ +import logging +from datetime import timedelta + +from golem.managers import LinearAverageCostPricing, MapScore, RejectIfCostsExceeds +from golem.managers.proposal.plugins.linear_coeffs import LinearCoeffsCost + +from ray_on_golem.server.models import NodeConfigData +from ray_on_golem.server.services.golem.manager_stack import ManagerStack + +logger = logging.getLogger(__name__) + + +class ManagerStackNodeConfigHelper: + @staticmethod + def apply_cost_management_avg_usage(stack: ManagerStack, node_config: NodeConfigData) -> None: + cost_management = node_config.cost_management + + if cost_management is None or not cost_management.is_average_usage_cost_enabled(): + logger.debug("Cost management based on average usage is not enabled") + return + + linear_average_cost = LinearAverageCostPricing( + average_cpu_load=node_config.cost_management.average_cpu_load, + average_duration=timedelta( + minutes=node_config.cost_management.average_duration_minutes + ), + ) + + stack.extra_proposal_scorers["Sort by linear average cost"] = MapScore( + linear_average_cost, normalize=True, normalize_flip=True + ) + + max_average_usage_cost = node_config.cost_management.max_average_usage_cost + if max_average_usage_cost is not None: + stack.extra_proposal_plugins[ + f"Reject if max_average_usage_cost exceeds {node_config.cost_management.max_average_usage_cost}" + ] = RejectIfCostsExceeds(max_average_usage_cost, linear_average_cost) + logger.debug("Cost management based on average usage applied with max limits") + else: + logger.debug("Cost management based on average usage applied without max limits") + + @staticmethod + def apply_cost_management_hard_limits(stack: ManagerStack, node_config: 
NodeConfigData) -> None: + # TODO: Consider creating RejectIfCostsExceeds variant for multiple values + proposal_plugins = {} + field_names = { + "max_initial_price": "price_initial", + "max_cpu_sec_price": "price_cpu_sec", + "max_duration_sec_price": "price_duration_sec", + } + + for cost_field_name, coef_field_name in field_names.items(): + cost_max_value = getattr(node_config.cost_management, cost_field_name, None) + if cost_max_value is not None: + proposal_plugins[ + f"Reject if {coef_field_name} exceeds {cost_max_value}" + ] = RejectIfCostsExceeds(cost_max_value, LinearCoeffsCost(coef_field_name)) + + if proposal_plugins: + stack.extra_proposal_plugins.update(proposal_plugins) + logger.debug("Cost management based on max limits applied") + else: + logger.debug("Cost management based on max limits is not enabled") diff --git a/ray_on_golem/server/services/golem/manager_stack.py b/ray_on_golem/server/services/golem/manager_stack.py new file mode 100644 index 00000000..e22bea16 --- /dev/null +++ b/ray_on_golem/server/services/golem/manager_stack.py @@ -0,0 +1,59 @@ +import logging +from typing import Dict, Optional + +from golem.managers import ( + ActivityManager, + AgreementManager, + DemandManager, + PaymentManager, + ProposalManager, + ProposalManagerPlugin, + ProposalScorer, +) +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class ManagerStack(BaseModel): + payment_manager: Optional[PaymentManager] + demand_manager: Optional[DemandManager] + proposal_manager: Optional[ProposalManager] + agreement_manager: Optional[AgreementManager] + activity_manager: Optional[ActivityManager] + extra_proposal_plugins: Dict[str, ProposalManagerPlugin] = Field(default_factory=dict) + extra_proposal_scorers: Dict[str, ProposalScorer] = Field(default_factory=dict) + + class Config: + arbitrary_types_allowed = True + + @property + def _managers(self): + return [ + self.payment_manager, + self.demand_manager, + self.proposal_manager, + self.agreement_manager, + self.activity_manager, + ] + + async def start(self) -> None: + logger.info("Starting stack managers...") + + for manager in self._managers: + if manager is not None: + await manager.start() + + logger.info("Starting stack managers done") + + async def stop(self) -> None: + logger.info("Stopping stack managers...") + + for manager in reversed(self._managers): + if manager is not None: + try: + await manager.stop() + except Exception: + logger.exception(f"{manager} stop failed!") + + logger.info("Stopping stack managers done") diff --git a/ray_on_golem/utils.py b/ray_on_golem/utils.py index 8193e18a..48156cf0 100644 --- a/ray_on_golem/utils.py +++ b/ray_on_golem/utils.py @@ -11,6 +11,7 @@ from aiohttp.web_runner import GracefulExit from ray_on_golem.exceptions import RayOnGolemError +from ray_on_golem.server.settings import TMP_PATH async def run_subprocess( @@ -77,3 +78,10 @@ def rolloverLogFiles(): def get_last_lines_from_file(file_path: Path, max_lines: int) -> str: with file_path.open() as file: return "".join(deque(file, max_lines)) + + +def prepare_tmp_dir(): + try: + TMP_PATH.mkdir(parents=True, exist_ok=True) + except OSError: + pass
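
A minimal usage sketch (not part of the patch) for the new click-based entry point in ray_on_golem/main.py; the subcommand names and options come from this diff, while the surrounding environment (an installed package, click's test runner, the cluster config path) is assumed.

from click.testing import CliRunner

from ray_on_golem.main import cli

runner = CliRunner()

# The top-level group registers the two subcommands wired up in main.py:
# `webserver` (ray_on_golem.server) and `network-stats` (ray_on_golem.network_stats).
print(runner.invoke(cli, ["--help"]).output)

# Equivalent shell invocations via the updated poetry script `ray-on-golem`
# (the cluster config file name below is illustrative):
#   ray-on-golem webserver -p 4578 --no-registry-stats
#   ray-on-golem network-stats golem-cluster.yaml -d 10 --enable-logging
print(runner.invoke(cli, ["webserver", "--help"]).output)
print(runner.invoke(cli, ["network-stats", "--help"]).output)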
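
A second, equally hypothetical sketch of what the extracted GolemNodeProvider._apply_config_defaults is expected to do to a bare cluster config dict; the default values mirror the ones set in this patch, while the toy config itself is made up.

from ray_on_golem.provider.node_provider import GolemNodeProvider

config = {
    "cluster_name": "golem-cluster",
    "provider": {"parameters": {}},
}

GolemNodeProvider._apply_config_defaults(config)

# Per the defaults added in this patch, provider.parameters now holds
# webserver_port=4578, enable_registry_stats=True, network="goerli", budget=1,
# plus _ssh_user/_ssh_private_key copied from the defaulted auth section
# (ssh_user="root", ssh_private_key=TMP_PATH / the cluster's default key name).
print(config["provider"]["parameters"])
print(config["auth"])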