Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ray-on-golem stats command #106

Merged
merged 11 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion golem-cluster.override.3-disable-stats.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
provider:
enable_registry_stats: false
params:
enable_registry_stats: false
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ click = "^8"
pydantic = "<2"

[tool.poetry.scripts]
ray-on-golem = "ray_on_golem.server.run:main"
ray-on-golem = "ray_on_golem.main:main"

[tool.poetry.group.dev.dependencies]
poethepoet = "^0.22.0"
Expand Down
3 changes: 3 additions & 0 deletions ray_on_golem/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ray_on_golem.main import main

main()
24 changes: 24 additions & 0 deletions ray_on_golem/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import click

from ray_on_golem.network_stats import main as network_stats
from ray_on_golem.server import main as webserver
from ray_on_golem.utils import prepare_tmp_dir


@click.group()
def cli():
pass


cli.add_command(network_stats)
cli.add_command(webserver)


def main():
prepare_tmp_dir()

cli()


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions ray_on_golem/network_stats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ray_on_golem.network_stats.main import main

__all__ = ("main",)
74 changes: 74 additions & 0 deletions ray_on_golem/network_stats/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
import logging
import logging.config
from contextlib import asynccontextmanager
from typing import Dict

import click
import yaml

from ray_on_golem.network_stats.services import NetworkStatsService
from ray_on_golem.provider.node_provider import GolemNodeProvider
from ray_on_golem.server.services import YagnaService
from ray_on_golem.server.settings import LOGGING_CONFIG, YAGNA_PATH
from ray_on_golem.utils import prepare_tmp_dir


@click.command(
name="network-stats",
short_help="Run Golem Network statistics.",
help="Run Golem Network statistics based on given cluster config file.",
approxit marked this conversation as resolved.
Show resolved Hide resolved
approxit marked this conversation as resolved.
Show resolved Hide resolved
context_settings={"show_default": True},
)
@click.argument("cluster-config-file", type=click.Path(exists=True))
@click.option(
"-d",
"--duration",
type=int,
default=5,
help="Set the duration of the statistics gathering process, in minutes.",
)
@click.option(
"--enable-logging",
is_flag=True,
default=False,
help="Enable verbose logging.",
)
def main(cluster_config_file: str, duration: int, enable_logging: bool):
if enable_logging:
logging.config.dictConfig(LOGGING_CONFIG)

with open(cluster_config_file) as file:
config = yaml.safe_load(file.read())

GolemNodeProvider._apply_config_defaults(config)

asyncio.run(_network_stats(config, duration))


async def _network_stats(config: Dict, duration: int):
provider_params = config["provider"]["parameters"]

async with network_stats_service(provider_params["enable_registry_stats"]) as stats_service:
await stats_service.run(provider_params, duration)


@asynccontextmanager
async def network_stats_service(registry_stats: bool) -> NetworkStatsService:
network_stats_service: NetworkStatsService = NetworkStatsService(registry_stats)
yagna_service = YagnaService(
yagna_path=YAGNA_PATH,
)

await yagna_service.init()
await network_stats_service.init(yagna_appkey=yagna_service.yagna_appkey)

yield network_stats_service

await network_stats_service.shutdown()
await yagna_service.shutdown()


if __name__ == "__main__":
prepare_tmp_dir()
main()
1 change: 1 addition & 0 deletions ray_on_golem/network_stats/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ray_on_golem.network_stats.services.network_stats import NetworkStatsService
231 changes: 231 additions & 0 deletions ray_on_golem/network_stats/services/network_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import asyncio
import logging
import re
from collections import defaultdict
from datetime import timedelta
from typing import Dict, Optional, Sequence

from golem.managers import (
AddChosenPaymentPlatform,
BlacklistProviderIdPlugin,
Buffer,
DefaultProposalManager,
NegotiatingPlugin,
PayAllPaymentManager,
ProposalManagerPlugin,
RefreshingDemandManager,
ScoringBuffer,
)
from golem.managers.base import ProposalNegotiator
from golem.node import GolemNode
from golem.payload import PayloadSyntaxParser
from golem.resources import DemandData, Proposal
from golem.resources.proposal.exceptions import ProposalRejected
from ya_market import ApiException

from ray_on_golem.server.models import NodeConfigData
from ray_on_golem.server.services.golem.helpers.demand_config import DemandConfigHelper
from ray_on_golem.server.services.golem.helpers.manager_stack import ManagerStackNodeConfigHelper
from ray_on_golem.server.services.golem.manager_stack import ManagerStack
from ray_on_golem.server.services.golem.provider_data import PROVIDERS_BLACKLIST

logger = logging.getLogger(__name__)


class ProposalCounterPlugin(ProposalManagerPlugin):
def __init__(self) -> None:
self._count = 0

async def get_proposal(self) -> Proposal:
while True:
proposal: Proposal = await self._get_proposal()
self._count += 1
return proposal


class StatsNegotiatingPlugin(NegotiatingPlugin):
def __init__(
self,
demand_offer_parser: Optional[PayloadSyntaxParser] = None,
proposal_negotiators: Optional[Sequence[ProposalNegotiator]] = None,
*args,
**kwargs,
) -> None:
super().__init__(demand_offer_parser, proposal_negotiators, *args, **kwargs)
self.fails = defaultdict(int)

async def get_proposal(self) -> Proposal:
while True:
proposal = await self._get_proposal()

demand_data = await self._get_demand_data_from_proposal(proposal)

try:
negotiated_proposal = await self._negotiate_proposal(demand_data, proposal)
return negotiated_proposal
except ProposalRejected as err:
self.fails[err.reason] += 1
except Exception as err:
self.fails[str(err)] += 1

async def _send_demand_proposal(
self, offer_proposal: Proposal, demand_data: DemandData
) -> Proposal:
try:
return await offer_proposal.respond(
demand_data.properties,
demand_data.constraints,
)
except ApiException as e:
error_msg = re.sub(r"\[.*?\]", "[***]", str(e.body))
raise RuntimeError(f"Failed to send proposal response! {e.status}: {error_msg}") from e
except asyncio.TimeoutError as e:
raise RuntimeError(f"Failed to send proposal response! Request timed out") from e


class StatsPluginFactory:
_stats_negotiating_plugin: StatsNegotiatingPlugin

def __init__(self) -> None:
self._counters = {}

def print_gathered_stats(self) -> None:
print("\nProposals count:")
[print(f"{tag}: {counter._count}") for tag, counter in self._counters.items()]
print("\nNegotiation errors:")
[
print(f"{err}: {count}")
for err, count in sorted(
self._stats_negotiating_plugin.fails.items(), key=lambda item: item[1], reverse=True
)
]

def create_negotiating_plugin(self) -> StatsNegotiatingPlugin:
self._stats_negotiating_plugin = StatsNegotiatingPlugin(
proposal_negotiators=(AddChosenPaymentPlatform(),),
)
return self._stats_negotiating_plugin

def create_counter_plugin(self, tag: str) -> ProposalCounterPlugin:
self._counters[tag] = ProposalCounterPlugin()
return self._counters[tag]


class NetworkStatsService:
def __init__(self, registry_stats: bool) -> None:
self._registry_stats = registry_stats

self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats)

self._golem: Optional[GolemNode] = None
self._yagna_appkey: Optional[str] = None

self._stats_plugin_factory = StatsPluginFactory()

async def init(self, yagna_appkey: str) -> None:
logger.info("Starting NetworkStatsService...")

self._golem = GolemNode(app_key=yagna_appkey)
self._yagna_appkey = yagna_appkey
await self._golem.start()

logger.info("Starting NetworkStatsService done")

async def shutdown(self) -> None:
logger.info("Stopping NetworkStatsService...")

await self._golem.aclose()
self._golem = None

logger.info("Stopping NetworkStatsService done")

async def run(self, provider_parameters: Dict, duration_minutes: int) -> None:
network: str = provider_parameters["network"]
budget: int = provider_parameters["budget"]
node_config: NodeConfigData = NodeConfigData(**provider_parameters["node_config"])

stack = await self._create_stack(node_config, budget, network)
await stack.start()

print(f"Gathering stats data for {duration_minutes} minutes...")
consume_proposals_task = asyncio.create_task(self._consume_draft_proposals(stack))
try:
await asyncio.wait(
[consume_proposals_task],
timeout=timedelta(minutes=duration_minutes).total_seconds(),
)
finally:
consume_proposals_task.cancel()
await consume_proposals_task

await stack.stop()
print("Gathering stats data done")
self._stats_plugin_factory.print_gathered_stats()

async def _consume_draft_proposals(self, stack: ManagerStack) -> None:
drafts = []
try:
while True:
draft = await stack.proposal_manager.get_draft_proposal()
drafts.append(draft)
except asyncio.CancelledError:
return
finally:
await asyncio.gather(
# FIXME better reason message
*[draft.reject(reason="No more needed") for draft in drafts],
return_exceptions=True,
)

async def _create_stack(
self, node_config: NodeConfigData, budget: float, network: str
) -> ManagerStack:
stack = ManagerStack()

payload = await self._demand_config_helper.get_payload_from_demand_config(
node_config.demand
)

ManagerStackNodeConfigHelper.apply_cost_management_avg_usage(stack, node_config)
ManagerStackNodeConfigHelper.apply_cost_management_hard_limits(stack, node_config)

stack.payment_manager = PayAllPaymentManager(self._golem, budget=budget, network=network)
stack.demand_manager = RefreshingDemandManager(
self._golem,
stack.payment_manager.get_allocation,
payload,
demand_expiration_timeout=timedelta(hours=8),
)

plugins = [
self._stats_plugin_factory.create_counter_plugin("Initial"),
BlacklistProviderIdPlugin(PROVIDERS_BLACKLIST.get(network, set())),
self._stats_plugin_factory.create_counter_plugin("Not blacklisted"),
]

for plugin_tag, plugin in stack.extra_proposal_plugins.items():
plugins.append(plugin)
plugins.append(self._stats_plugin_factory.create_counter_plugin(f"Passed {plugin_tag}"))

plugins.extend(
[
ScoringBuffer(
min_size=50,
max_size=1000,
fill_at_start=True,
proposal_scorers=(*stack.extra_proposal_scorers.values(),),
update_interval=timedelta(seconds=10),
),
self._stats_plugin_factory.create_counter_plugin("Scored"),
self._stats_plugin_factory.create_negotiating_plugin(),
self._stats_plugin_factory.create_counter_plugin("Negotiated"),
Buffer(min_size=1, max_size=50, fill_concurrency_size=10),
]
)
stack.proposal_manager = DefaultProposalManager(
self._golem,
stack.demand_manager.get_initial_proposal,
plugins=plugins,
)

return stack
Loading
Loading