Skip to content

Commit

Permalink
Add ray-on-golem stats command (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucekdudek authored Nov 9, 2023
1 parent 9a34b45 commit 81d8445
Show file tree
Hide file tree
Showing 16 changed files with 650 additions and 275 deletions.
3 changes: 2 additions & 1 deletion golem-cluster.override.3-disable-stats.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
provider:
enable_registry_stats: false
params:
enable_registry_stats: false
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ click = "^8"
pydantic = "<2"

[tool.poetry.scripts]
ray-on-golem = "ray_on_golem.server.run:main"
ray-on-golem = "ray_on_golem.main:main"

[tool.poetry.group.dev.dependencies]
poethepoet = "^0.22.0"
Expand Down
3 changes: 3 additions & 0 deletions ray_on_golem/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ray_on_golem.main import main

main()
24 changes: 24 additions & 0 deletions ray_on_golem/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import click

from ray_on_golem.network_stats import main as network_stats
from ray_on_golem.server import main as webserver
from ray_on_golem.utils import prepare_tmp_dir


@click.group()
def cli():
pass


cli.add_command(network_stats)
cli.add_command(webserver)


def main():
prepare_tmp_dir()

cli()


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions ray_on_golem/network_stats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ray_on_golem.network_stats.main import main

__all__ = ("main",)
74 changes: 74 additions & 0 deletions ray_on_golem/network_stats/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
import logging
import logging.config
from contextlib import asynccontextmanager
from typing import Dict

import click
import yaml

from ray_on_golem.network_stats.services import NetworkStatsService
from ray_on_golem.provider.node_provider import GolemNodeProvider
from ray_on_golem.server.services import YagnaService
from ray_on_golem.server.settings import LOGGING_CONFIG, YAGNA_PATH
from ray_on_golem.utils import prepare_tmp_dir


@click.command(
name="network-stats",
short_help="Run Golem Network statistics.",
help="Run Golem Network statistics based on given cluster config file.",
context_settings={"show_default": True},
)
@click.argument("cluster-config-file", type=click.Path(exists=True))
@click.option(
"-d",
"--duration",
type=int,
default=5,
help="Set the duration of the statistics gathering process, in minutes.",
)
@click.option(
"--enable-logging",
is_flag=True,
default=False,
help="Enable verbose logging.",
)
def main(cluster_config_file: str, duration: int, enable_logging: bool):
if enable_logging:
logging.config.dictConfig(LOGGING_CONFIG)

with open(cluster_config_file) as file:
config = yaml.safe_load(file.read())

GolemNodeProvider._apply_config_defaults(config)

asyncio.run(_network_stats(config, duration))


async def _network_stats(config: Dict, duration: int):
provider_params = config["provider"]["parameters"]

async with network_stats_service(provider_params["enable_registry_stats"]) as stats_service:
await stats_service.run(provider_params, duration)


@asynccontextmanager
async def network_stats_service(registry_stats: bool) -> NetworkStatsService:
network_stats_service: NetworkStatsService = NetworkStatsService(registry_stats)
yagna_service = YagnaService(
yagna_path=YAGNA_PATH,
)

await yagna_service.init()
await network_stats_service.init(yagna_appkey=yagna_service.yagna_appkey)

yield network_stats_service

await network_stats_service.shutdown()
await yagna_service.shutdown()


if __name__ == "__main__":
prepare_tmp_dir()
main()
1 change: 1 addition & 0 deletions ray_on_golem/network_stats/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ray_on_golem.network_stats.services.network_stats import NetworkStatsService
231 changes: 231 additions & 0 deletions ray_on_golem/network_stats/services/network_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import asyncio
import logging
import re
from collections import defaultdict
from datetime import timedelta
from typing import Dict, Optional, Sequence

from golem.managers import (
AddChosenPaymentPlatform,
BlacklistProviderIdPlugin,
Buffer,
DefaultProposalManager,
NegotiatingPlugin,
PayAllPaymentManager,
ProposalManagerPlugin,
RefreshingDemandManager,
ScoringBuffer,
)
from golem.managers.base import ProposalNegotiator
from golem.node import GolemNode
from golem.payload import PayloadSyntaxParser
from golem.resources import DemandData, Proposal
from golem.resources.proposal.exceptions import ProposalRejected
from ya_market import ApiException

from ray_on_golem.server.models import NodeConfigData
from ray_on_golem.server.services.golem.helpers.demand_config import DemandConfigHelper
from ray_on_golem.server.services.golem.helpers.manager_stack import ManagerStackNodeConfigHelper
from ray_on_golem.server.services.golem.manager_stack import ManagerStack
from ray_on_golem.server.services.golem.provider_data import PROVIDERS_BLACKLIST

logger = logging.getLogger(__name__)


class ProposalCounterPlugin(ProposalManagerPlugin):
def __init__(self) -> None:
self._count = 0

async def get_proposal(self) -> Proposal:
while True:
proposal: Proposal = await self._get_proposal()
self._count += 1
return proposal


class StatsNegotiatingPlugin(NegotiatingPlugin):
def __init__(
self,
demand_offer_parser: Optional[PayloadSyntaxParser] = None,
proposal_negotiators: Optional[Sequence[ProposalNegotiator]] = None,
*args,
**kwargs,
) -> None:
super().__init__(demand_offer_parser, proposal_negotiators, *args, **kwargs)
self.fails = defaultdict(int)

async def get_proposal(self) -> Proposal:
while True:
proposal = await self._get_proposal()

demand_data = await self._get_demand_data_from_proposal(proposal)

try:
negotiated_proposal = await self._negotiate_proposal(demand_data, proposal)
return negotiated_proposal
except ProposalRejected as err:
self.fails[err.reason] += 1
except Exception as err:
self.fails[str(err)] += 1

async def _send_demand_proposal(
self, offer_proposal: Proposal, demand_data: DemandData
) -> Proposal:
try:
return await offer_proposal.respond(
demand_data.properties,
demand_data.constraints,
)
except ApiException as e:
error_msg = re.sub(r"\[.*?\]", "[***]", str(e.body))
raise RuntimeError(f"Failed to send proposal response! {e.status}: {error_msg}") from e
except asyncio.TimeoutError as e:
raise RuntimeError(f"Failed to send proposal response! Request timed out") from e


class StatsPluginFactory:
_stats_negotiating_plugin: StatsNegotiatingPlugin

def __init__(self) -> None:
self._counters = {}

def print_gathered_stats(self) -> None:
print("\nProposals count:")
[print(f"{tag}: {counter._count}") for tag, counter in self._counters.items()]
print("\nNegotiation errors:")
[
print(f"{err}: {count}")
for err, count in sorted(
self._stats_negotiating_plugin.fails.items(), key=lambda item: item[1], reverse=True
)
]

def create_negotiating_plugin(self) -> StatsNegotiatingPlugin:
self._stats_negotiating_plugin = StatsNegotiatingPlugin(
proposal_negotiators=(AddChosenPaymentPlatform(),),
)
return self._stats_negotiating_plugin

def create_counter_plugin(self, tag: str) -> ProposalCounterPlugin:
self._counters[tag] = ProposalCounterPlugin()
return self._counters[tag]


class NetworkStatsService:
def __init__(self, registry_stats: bool) -> None:
self._registry_stats = registry_stats

self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats)

self._golem: Optional[GolemNode] = None
self._yagna_appkey: Optional[str] = None

self._stats_plugin_factory = StatsPluginFactory()

async def init(self, yagna_appkey: str) -> None:
logger.info("Starting NetworkStatsService...")

self._golem = GolemNode(app_key=yagna_appkey)
self._yagna_appkey = yagna_appkey
await self._golem.start()

logger.info("Starting NetworkStatsService done")

async def shutdown(self) -> None:
logger.info("Stopping NetworkStatsService...")

await self._golem.aclose()
self._golem = None

logger.info("Stopping NetworkStatsService done")

async def run(self, provider_parameters: Dict, duration_minutes: int) -> None:
network: str = provider_parameters["network"]
budget: int = provider_parameters["budget"]
node_config: NodeConfigData = NodeConfigData(**provider_parameters["node_config"])

stack = await self._create_stack(node_config, budget, network)
await stack.start()

print(f"Gathering stats data for {duration_minutes} minutes...")
consume_proposals_task = asyncio.create_task(self._consume_draft_proposals(stack))
try:
await asyncio.wait(
[consume_proposals_task],
timeout=timedelta(minutes=duration_minutes).total_seconds(),
)
finally:
consume_proposals_task.cancel()
await consume_proposals_task

await stack.stop()
print("Gathering stats data done")
self._stats_plugin_factory.print_gathered_stats()

async def _consume_draft_proposals(self, stack: ManagerStack) -> None:
drafts = []
try:
while True:
draft = await stack.proposal_manager.get_draft_proposal()
drafts.append(draft)
except asyncio.CancelledError:
return
finally:
await asyncio.gather(
# FIXME better reason message
*[draft.reject(reason="No more needed") for draft in drafts],
return_exceptions=True,
)

async def _create_stack(
self, node_config: NodeConfigData, budget: float, network: str
) -> ManagerStack:
stack = ManagerStack()

payload = await self._demand_config_helper.get_payload_from_demand_config(
node_config.demand
)

ManagerStackNodeConfigHelper.apply_cost_management_avg_usage(stack, node_config)
ManagerStackNodeConfigHelper.apply_cost_management_hard_limits(stack, node_config)

stack.payment_manager = PayAllPaymentManager(self._golem, budget=budget, network=network)
stack.demand_manager = RefreshingDemandManager(
self._golem,
stack.payment_manager.get_allocation,
payload,
demand_expiration_timeout=timedelta(hours=8),
)

plugins = [
self._stats_plugin_factory.create_counter_plugin("Initial"),
BlacklistProviderIdPlugin(PROVIDERS_BLACKLIST.get(network, set())),
self._stats_plugin_factory.create_counter_plugin("Not blacklisted"),
]

for plugin_tag, plugin in stack.extra_proposal_plugins.items():
plugins.append(plugin)
plugins.append(self._stats_plugin_factory.create_counter_plugin(f"Passed {plugin_tag}"))

plugins.extend(
[
ScoringBuffer(
min_size=50,
max_size=1000,
fill_at_start=True,
proposal_scorers=(*stack.extra_proposal_scorers.values(),),
update_interval=timedelta(seconds=10),
),
self._stats_plugin_factory.create_counter_plugin("Scored"),
self._stats_plugin_factory.create_negotiating_plugin(),
self._stats_plugin_factory.create_counter_plugin("Negotiated"),
Buffer(min_size=1, max_size=50, fill_concurrency_size=10),
]
)
stack.proposal_manager = DefaultProposalManager(
self._golem,
stack.demand_manager.get_initial_proposal,
plugins=plugins,
)

return stack
Loading

0 comments on commit 81d8445

Please sign in to comment.