-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ray-on-golem stats command (#106)
- Loading branch information
1 parent
9a34b45
commit 81d8445
Showing
16 changed files
with
650 additions
and
275 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
provider: | ||
enable_registry_stats: false | ||
params: | ||
enable_registry_stats: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ray_on_golem.main import main | ||
|
||
# Invoke the CLI entry point imported from ray_on_golem.main.
main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import click | ||
|
||
from ray_on_golem.network_stats import main as network_stats | ||
from ray_on_golem.server import main as webserver | ||
from ray_on_golem.utils import prepare_tmp_dir | ||
|
||
|
||
# Root click command group; subcommands are attached with cli.add_command().
# NOTE: documented with a comment on purpose - click would use a docstring
# as the group's help text, which would change the `--help` output.
@click.group()
def cli():
    pass
|
||
|
||
# Register the imported subcommands on the root command group.
cli.add_command(network_stats)
cli.add_command(webserver)
|
||
|
||
def main():
    """Call ``prepare_tmp_dir()`` and then hand control over to the ``cli`` command group."""
    prepare_tmp_dir()
    cli()
|
||
|
||
# Run the CLI only when this module is executed as a script, not when imported.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ray_on_golem.network_stats.main import main | ||
|
||
# Names exported by `from <package> import *`: only the re-exported `main` command.
__all__ = ("main",)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import asyncio | ||
import logging | ||
import logging.config | ||
from contextlib import asynccontextmanager | ||
from typing import Dict | ||
|
||
import click | ||
import yaml | ||
|
||
from ray_on_golem.network_stats.services import NetworkStatsService | ||
from ray_on_golem.provider.node_provider import GolemNodeProvider | ||
from ray_on_golem.server.services import YagnaService | ||
from ray_on_golem.server.settings import LOGGING_CONFIG, YAGNA_PATH | ||
from ray_on_golem.utils import prepare_tmp_dir | ||
|
||
|
||
@click.command(
    name="network-stats",
    short_help="Run Golem Network statistics.",
    help="Run Golem Network statistics based on given cluster config file.",
    context_settings={"show_default": True},
)
@click.argument("cluster-config-file", type=click.Path(exists=True))
@click.option(
    "-d",
    "--duration",
    type=int,
    default=5,
    help="Set the duration of the statistics gathering process, in minutes.",
)
@click.option(
    "--enable-logging",
    is_flag=True,
    default=False,
    help="Enable verbose logging.",
)
def main(cluster_config_file: str, duration: int, enable_logging: bool):
    """Load a cluster config file and run the network statistics gathering.

    The CLI help text comes from the explicit ``help=`` argument above, not
    from this docstring.

    Args:
        cluster_config_file: Path to an existing YAML cluster config file.
        duration: How long to gather statistics, in minutes.
        enable_logging: When set, logging is configured from ``LOGGING_CONFIG``.
    """
    if enable_logging:
        logging.config.dictConfig(LOGGING_CONFIG)

    with open(cluster_config_file) as config_stream:
        raw_config = config_stream.read()
    cluster_config = yaml.safe_load(raw_config)

    # Mutates `cluster_config` in place; the return value is not used here.
    GolemNodeProvider._apply_config_defaults(cluster_config)

    asyncio.run(_network_stats(cluster_config, duration))
|
||
|
||
async def _network_stats(config: Dict, duration: int):
    """Run the stats service for ``duration`` minutes using the given cluster config.

    Reads ``config["provider"]["parameters"]`` and uses its
    ``enable_registry_stats`` entry to construct the service.
    """
    params = config["provider"]["parameters"]
    registry_stats_enabled = params["enable_registry_stats"]

    async with network_stats_service(registry_stats_enabled) as service:
        await service.run(params, duration)
|
||
|
||
@asynccontextmanager
async def network_stats_service(registry_stats: bool) -> NetworkStatsService:
    """Async context manager yielding an initialized ``NetworkStatsService``.

    A ``YagnaService`` is started first and its app key is used to initialize
    the stats service. Both services are shut down in reverse order when the
    ``async with`` block exits - including when the block raises or is
    cancelled, and when initializing the stats service itself fails.

    Args:
        registry_stats: Passed through to ``NetworkStatsService``.

    Yields:
        The initialized ``NetworkStatsService`` (the annotation describes the
        yielded value, not the context manager object).
    """
    # Named differently from the function to avoid shadowing it.
    stats_service: NetworkStatsService = NetworkStatsService(registry_stats)
    yagna_service = YagnaService(
        yagna_path=YAGNA_PATH,
    )

    await yagna_service.init()
    try:
        await stats_service.init(yagna_appkey=yagna_service.yagna_appkey)
        try:
            yield stats_service
        finally:
            # Runs even if the caller's block raised, so the Golem node is always closed.
            await stats_service.shutdown()
    finally:
        # Yagna was started successfully above, so it must always be stopped.
        await yagna_service.shutdown()
|
||
|
||
# When executed directly as a script: call prepare_tmp_dir() first, then run the click command.
if __name__ == "__main__":
    prepare_tmp_dir()
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ray_on_golem.network_stats.services.network_stats import NetworkStatsService |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
import asyncio | ||
import logging | ||
import re | ||
from collections import defaultdict | ||
from datetime import timedelta | ||
from typing import Dict, Optional, Sequence | ||
|
||
from golem.managers import ( | ||
AddChosenPaymentPlatform, | ||
BlacklistProviderIdPlugin, | ||
Buffer, | ||
DefaultProposalManager, | ||
NegotiatingPlugin, | ||
PayAllPaymentManager, | ||
ProposalManagerPlugin, | ||
RefreshingDemandManager, | ||
ScoringBuffer, | ||
) | ||
from golem.managers.base import ProposalNegotiator | ||
from golem.node import GolemNode | ||
from golem.payload import PayloadSyntaxParser | ||
from golem.resources import DemandData, Proposal | ||
from golem.resources.proposal.exceptions import ProposalRejected | ||
from ya_market import ApiException | ||
|
||
from ray_on_golem.server.models import NodeConfigData | ||
from ray_on_golem.server.services.golem.helpers.demand_config import DemandConfigHelper | ||
from ray_on_golem.server.services.golem.helpers.manager_stack import ManagerStackNodeConfigHelper | ||
from ray_on_golem.server.services.golem.manager_stack import ManagerStack | ||
from ray_on_golem.server.services.golem.provider_data import PROVIDERS_BLACKLIST | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ProposalCounterPlugin(ProposalManagerPlugin):
    """Pass-through proposal plugin that counts the proposals it hands out."""

    def __init__(self) -> None:
        # Number of proposals returned by get_proposal() so far.
        self._count = 0

    async def get_proposal(self) -> Proposal:
        """Await the next proposal from ``_get_proposal``, count it and return it unchanged."""
        upstream_proposal: Proposal = await self._get_proposal()
        self._count += 1
        return upstream_proposal
|
||
|
||
class StatsNegotiatingPlugin(NegotiatingPlugin):
    """Negotiating plugin that tallies failed negotiations instead of surfacing them.

    ``fails`` maps a failure description (the rejection reason or the
    stringified exception) to the number of times it occurred.
    """

    def __init__(
        self,
        demand_offer_parser: Optional[PayloadSyntaxParser] = None,
        proposal_negotiators: Optional[Sequence[ProposalNegotiator]] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(demand_offer_parser, proposal_negotiators, *args, **kwargs)
        self.fails = defaultdict(int)

    async def get_proposal(self) -> Proposal:
        """Return the first proposal that negotiates successfully.

        Proposals whose negotiation fails are counted in ``fails`` and skipped;
        the loop then moves on to the next proposal.
        """
        while True:
            candidate = await self._get_proposal()
            demand_data = await self._get_demand_data_from_proposal(candidate)

            try:
                return await self._negotiate_proposal(demand_data, candidate)
            except ProposalRejected as rejection:
                failure_key = rejection.reason
            except Exception as error:
                failure_key = str(error)

            self.fails[failure_key] += 1

    async def _send_demand_proposal(
        self, offer_proposal: Proposal, demand_data: DemandData
    ) -> Proposal:
        """Respond to ``offer_proposal`` with the demand's properties and constraints.

        Raises:
            RuntimeError: If the API call fails or times out; the original
                exception is attached as the cause.
        """
        try:
            counter_proposal = await offer_proposal.respond(
                demand_data.properties,
                demand_data.constraints,
            )
        except ApiException as api_error:
            # Every bracketed segment of the response body is replaced with "[***]".
            # NOTE(review): presumably this strips per-request details so that
            # equal errors share one `fails` key - confirm.
            masked_body = re.sub(r"\[.*?\]", "[***]", str(api_error.body))
            raise RuntimeError(
                f"Failed to send proposal response! {api_error.status}: {masked_body}"
            ) from api_error
        except asyncio.TimeoutError as timeout_error:
            raise RuntimeError(
                "Failed to send proposal response! Request timed out"
            ) from timeout_error
        else:
            return counter_proposal
|
||
|
||
class StatsPluginFactory:
    """Create the statistics plugins and print what they have gathered.

    The factory keeps a reference to every counter plugin it creates (keyed by
    tag) and to the most recently created negotiating plugin, so their numbers
    can be reported by ``print_gathered_stats``.
    """

    # Most recently created negotiating plugin; stays None until
    # create_negotiating_plugin() has been called.
    _stats_negotiating_plugin: Optional[StatsNegotiatingPlugin] = None

    def __init__(self) -> None:
        # Counter plugins keyed by their tag, in creation order.
        self._counters: Dict[str, ProposalCounterPlugin] = {}

    def print_gathered_stats(self) -> None:
        """Print per-tag proposal counts, then negotiation errors (most frequent first).

        Safe to call before any negotiating plugin was created; the error
        section is then printed empty.
        """
        print("\nProposals count:")
        for tag, counter in self._counters.items():
            print(f"{tag}: {counter._count}")

        print("\nNegotiation errors:")
        negotiating_plugin = self._stats_negotiating_plugin
        fails = negotiating_plugin.fails if negotiating_plugin is not None else {}
        for err, count in sorted(fails.items(), key=lambda item: item[1], reverse=True):
            print(f"{err}: {count}")

    def create_negotiating_plugin(self) -> StatsNegotiatingPlugin:
        """Create, remember and return a new ``StatsNegotiatingPlugin``.

        A later call replaces the remembered plugin, so only the latest one is
        reported by ``print_gathered_stats``.
        """
        self._stats_negotiating_plugin = StatsNegotiatingPlugin(
            proposal_negotiators=(AddChosenPaymentPlatform(),),
        )
        return self._stats_negotiating_plugin

    def create_counter_plugin(self, tag: str) -> ProposalCounterPlugin:
        """Create a counter plugin registered under ``tag`` and return it.

        Reusing a tag replaces the previously registered counter.
        """
        counter = ProposalCounterPlugin()
        self._counters[tag] = counter
        return counter
|
||
|
||
class NetworkStatsService:
    """Gather statistics about proposals for a given node config and print them.

    Lifecycle: ``init()`` starts a ``GolemNode``, ``run()`` builds a manager
    stack instrumented with counting plugins and consumes draft proposals for a
    fixed time, ``shutdown()`` closes the node.
    """

    def __init__(self, registry_stats: bool) -> None:
        self._registry_stats = registry_stats

        self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats)

        # Both are set by init(); _golem is reset to None by shutdown().
        self._golem: Optional[GolemNode] = None
        self._yagna_appkey: Optional[str] = None

        self._stats_plugin_factory = StatsPluginFactory()

    async def init(self, yagna_appkey: str) -> None:
        """Create and start the ``GolemNode`` using the given yagna app key."""
        logger.info("Starting NetworkStatsService...")

        self._golem = GolemNode(app_key=yagna_appkey)
        self._yagna_appkey = yagna_appkey
        await self._golem.start()

        logger.info("Starting NetworkStatsService done")

    async def shutdown(self) -> None:
        """Close the ``GolemNode``; a no-op if the node was never created or is already closed."""
        logger.info("Stopping NetworkStatsService...")

        # Guard against shutdown() before init() or a repeated shutdown().
        if self._golem is not None:
            await self._golem.aclose()
            self._golem = None

        logger.info("Stopping NetworkStatsService done")

    async def run(self, provider_parameters: Dict, duration_minutes: int) -> None:
        """Collect draft proposals for ``duration_minutes`` and print the gathered stats.

        Args:
            provider_parameters: Mapping with the ``network``, ``budget`` and
                ``node_config`` entries.
            duration_minutes: How long proposals are consumed before stopping.

        The manager stack is always stopped once it has been started, even if
        this coroutine is cancelled or the consumer task fails. The stats are
        printed only when the gathering finished without an error.
        """
        network: str = provider_parameters["network"]
        budget: float = provider_parameters["budget"]
        node_config: NodeConfigData = NodeConfigData(**provider_parameters["node_config"])

        stack = await self._create_stack(node_config, budget, network)
        await stack.start()

        try:
            print(f"Gathering stats data for {duration_minutes} minutes...")
            consume_proposals_task = asyncio.create_task(self._consume_draft_proposals(stack))
            try:
                # asyncio.wait() does not raise on timeout, it simply returns.
                await asyncio.wait(
                    [consume_proposals_task],
                    timeout=timedelta(minutes=duration_minutes).total_seconds(),
                )
            finally:
                # The consumer handles its own cancellation; awaiting it lets
                # its cleanup finish and re-raises any error it ended with.
                consume_proposals_task.cancel()
                await consume_proposals_task
        finally:
            await stack.stop()

        print("Gathering stats data done")
        self._stats_plugin_factory.print_gathered_stats()

    async def _consume_draft_proposals(self, stack: ManagerStack) -> None:
        """Pull draft proposals from the stack until cancelled, then reject all of them.

        Cancellation is treated as the normal way to stop. Rejecting runs on
        every exit path, and errors raised by individual rejects are collected
        by ``gather`` rather than propagated.
        """
        drafts = []
        try:
            while True:
                draft = await stack.proposal_manager.get_draft_proposal()
                drafts.append(draft)
        except asyncio.CancelledError:
            return
        finally:
            await asyncio.gather(
                # FIXME better reason message
                *[draft.reject(reason="No more needed") for draft in drafts],
                return_exceptions=True,
            )

    async def _create_stack(
        self, node_config: NodeConfigData, budget: float, network: str
    ) -> ManagerStack:
        """Build a ``ManagerStack`` whose proposal pipeline is instrumented with counters.

        A counter plugin is placed after each stage, so the printed stats show
        how many proposals passed each one.
        """
        stack = ManagerStack()

        payload = await self._demand_config_helper.get_payload_from_demand_config(
            node_config.demand
        )

        # Must run before the plugin list is built: the extra_proposal_plugins and
        # extra_proposal_scorers read below are presumably filled in by these
        # helpers - TODO confirm.
        ManagerStackNodeConfigHelper.apply_cost_management_avg_usage(stack, node_config)
        ManagerStackNodeConfigHelper.apply_cost_management_hard_limits(stack, node_config)

        stack.payment_manager = PayAllPaymentManager(self._golem, budget=budget, network=network)
        stack.demand_manager = RefreshingDemandManager(
            self._golem,
            stack.payment_manager.get_allocation,
            payload,
            demand_expiration_timeout=timedelta(hours=8),
        )

        plugins = [
            self._stats_plugin_factory.create_counter_plugin("Initial"),
            BlacklistProviderIdPlugin(PROVIDERS_BLACKLIST.get(network, set())),
            self._stats_plugin_factory.create_counter_plugin("Not blacklisted"),
        ]

        # Follow every extra plugin with its own counter.
        for plugin_tag, plugin in stack.extra_proposal_plugins.items():
            plugins.append(plugin)
            plugins.append(self._stats_plugin_factory.create_counter_plugin(f"Passed {plugin_tag}"))

        plugins.extend(
            [
                ScoringBuffer(
                    min_size=50,
                    max_size=1000,
                    fill_at_start=True,
                    proposal_scorers=(*stack.extra_proposal_scorers.values(),),
                    update_interval=timedelta(seconds=10),
                ),
                self._stats_plugin_factory.create_counter_plugin("Scored"),
                self._stats_plugin_factory.create_negotiating_plugin(),
                self._stats_plugin_factory.create_counter_plugin("Negotiated"),
                Buffer(min_size=1, max_size=50, fill_concurrency_size=10),
            ]
        )
        stack.proposal_manager = DefaultProposalManager(
            self._golem,
            stack.demand_manager.get_initial_proposal,
            plugins=plugins,
        )

        return stack
Oops, something went wrong.