Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed allocation creation #1134

Merged
merged 9 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 43 additions & 121 deletions tests/test_payment_platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,167 +4,89 @@

from ya_payment import RequestorApi

from yapapi import NoPaymentAccountError
from yapapi.engine import DEFAULT_DRIVER, DEFAULT_NETWORK
from yapapi.golem import Golem
from yapapi.rest.payment import Account, Payment
from yapapi.engine import (
DEFAULT_DRIVER,
DEFAULT_NETWORK,
MAINNET_TOKEN_NAME,
TESTNET_TOKEN_NAME,
)
from yapapi.golem import Golem, _Engine


@pytest.fixture(autouse=True)
def _set_app_key(monkeypatch):
monkeypatch.setenv("YAGNA_APPKEY", "mock-appkey")


def _mock_accounts_iterator(*account_specs):
"""Create an iterator over mock `Account` objects.

`account_specs` should contain pairs `(driver, network)`, where `driver` and `network`
are strings, or triples `(driver, network, params)` with `driver` and `network` as before
and `params` a dictionary containing additional keyword arguments for `Account()`.
"""

async def _mock(*_args):
for spec in account_specs:
params = {
"platform": "mock-platform",
"address": "mock-address",
"driver": spec[0],
"network": spec[1],
"token": "mock-token",
"send": True,
"receive": True,
}
if len(spec) == 3:
params.update(**spec[2])
yield Account(**params)

return _mock


class _StopExecutor(Exception):
"""An exception raised to stop the test when reaching an expected checkpoint in executor."""


@pytest.fixture()
def _mock_decorate_demand(monkeypatch):
"""Make `Payment.decorate_demand()` stop the test."""
def _mock_engine_id(monkeypatch):
"""Mock Engine `id`."""

async def _id(_):
return

monkeypatch.setattr(
Payment,
"decorate_demand",
mock.Mock(side_effect=_StopExecutor("decorate_demand() called")),
_Engine,
"_id",
_id,
)


@pytest.fixture()
def _mock_create_allocation(monkeypatch):
"""Make `RequestorApi.create_allocation()` stop the test."""
monkeypatch.setattr(
RequestorApi,
"create_allocation",
mock.Mock(side_effect=_StopExecutor("create_allocation() called")),
)

create_allocation_mock = mock.Mock(side_effect=_StopExecutor("create_allocation() called"))

@pytest.mark.asyncio
async def test_no_accounts_raises(monkeypatch):
"""Test that exception is raised if `Payment.accounts()` returns empty list."""
monkeypatch.setattr(RequestorApi, "create_allocation", create_allocation_mock)

monkeypatch.setattr(Payment, "accounts", _mock_accounts_iterator())

with pytest.raises(NoPaymentAccountError):
async with Golem(budget=10.0):
pass
return create_allocation_mock


@pytest.mark.asyncio
async def test_no_matching_account_raises(monkeypatch):
"""Test that exception is raised if `Payment.accounts()` returns no matching accounts."""
async def test_default(_mock_engine_id, _mock_create_allocation):
"""Test the allocation defaults."""

monkeypatch.setattr(
Payment,
"accounts",
_mock_accounts_iterator(
("other-driver", "other-network"),
("matching-driver", "other-network"),
("other-driver", "matching-network"),
),
)

with pytest.raises(NoPaymentAccountError) as exc_info:
async with Golem(
budget=10.0, payment_driver="matching-driver", payment_network="matching-network"
):
with pytest.raises(_StopExecutor):
async with Golem(budget=10.0):
pass

exc = exc_info.value
assert exc.required_driver == "matching-driver"
assert exc.required_network == "matching-network"


@pytest.mark.asyncio
async def test_matching_account_creates_allocation(monkeypatch, _mock_decorate_demand):
"""Test that matching accounts are correctly selected and allocations are created for them."""

monkeypatch.setattr(
Payment,
"accounts",
_mock_accounts_iterator(
("other-driver", "other-network"),
("matching-driver", "matching-network", {"platform": "platform-1"}),
("matching-driver", "other-network"),
("other-driver", "matching-network"),
("matching-driver", "matching-network", {"platform": "platform-2"}),
),
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"{DEFAULT_DRIVER}-{DEFAULT_NETWORK}-{TESTNET_TOKEN_NAME}"
)

create_allocation_args = []

async def mock_create_allocation(_self, model):
create_allocation_args.append(model)
return mock.Mock()

async def mock_release_allocation(*args, **kwargs):
pass

monkeypatch.setattr(RequestorApi, "create_allocation", mock_create_allocation)
monkeypatch.setattr(RequestorApi, "release_allocation", mock_release_allocation)
@pytest.mark.asyncio
async def test_mainnet(_mock_engine_id, _mock_create_allocation):
"""Test the allocation for a mainnet account."""

with pytest.raises(_StopExecutor):
async with Golem(
budget=10.0,
payment_driver="matching-driver",
payment_network="matching-network",
):
async with Golem(budget=10.0, payment_driver="somedriver", payment_network="mainnet"):
pass

assert len(create_allocation_args) == 2
assert create_allocation_args[0].payment_platform == "platform-1"
assert create_allocation_args[1].payment_platform == "platform-2"
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"somedriver-mainnet-{MAINNET_TOKEN_NAME}"
)


@pytest.mark.asyncio
async def test_driver_network_case_insensitive(monkeypatch, _mock_create_allocation):
"""Test that matching driver and network names is not case sensitive."""

monkeypatch.setattr(Payment, "accounts", _mock_accounts_iterator(("dRIVER", "NetWORK")))
async def test_testnet(_mock_engine_id, _mock_create_allocation):
"""Test the allocation for a mainnet account."""

with pytest.raises(_StopExecutor):
async with Golem(
budget=10.0,
payment_driver="dRiVeR",
payment_network="NeTwOrK",
):
async with Golem(budget=10.0, payment_driver="somedriver", payment_network="othernet"):
pass


@pytest.mark.asyncio
async def test_default_driver_network(monkeypatch, _mock_create_allocation):
"""Test that defaults are used if driver and network are not specified."""

monkeypatch.setattr(
Payment, "accounts", _mock_accounts_iterator((DEFAULT_DRIVER, DEFAULT_NETWORK))
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"somedriver-othernet-{TESTNET_TOKEN_NAME}"
)

with pytest.raises(_StopExecutor):
async with Golem(budget=10.0):
pass
61 changes: 30 additions & 31 deletions yapapi/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime, timezone
from decimal import Decimal
import itertools
import json
import logging
import os
import sys
Expand Down Expand Up @@ -57,6 +58,10 @@

MAX_CONCURRENTLY_PROCESSED_DEBIT_NOTES: Final[int] = 10

MAINNET_NETWORKS: Set[str] = {"mainnet", "polygon"}
MAINNET_TOKEN_NAME: str = "glm"
TESTNET_TOKEN_NAME: str = "tglm"

logger = logging.getLogger("yapapi.executor")


Expand Down Expand Up @@ -103,6 +108,7 @@ def __init__(
subnet_tag: Optional[str] = None,
payment_driver: Optional[str] = None,
payment_network: Optional[str] = None,
payment_token: Optional[str] = None,
stream_output: bool = False,
app_key: Optional[str] = None,
):
Expand Down Expand Up @@ -135,6 +141,13 @@ def __init__(
self._subnet: Optional[str] = subnet_tag or DEFAULT_SUBNET
self._payment_driver: str = payment_driver.lower() if payment_driver else DEFAULT_DRIVER
self._payment_network: str = payment_network.lower() if payment_network else DEFAULT_NETWORK
self._payment_token: str = (
payment_token.lower()
if payment_token
else MAINNET_TOKEN_NAME
if self._payment_network in MAINNET_NETWORKS
else TESTNET_TOKEN_NAME
)
self._stream_output = stream_output

# a set of `Job` instances used to track jobs - computations or services - started
Expand Down Expand Up @@ -332,40 +345,26 @@ async def _shutdown(self, *exc_info):
except Exception:
logger.debug("Got error when waiting for services to finish", exc_info=True)

async def _create_allocations(self) -> rest.payment.MarketDecoration:
async def _id(self) -> str:
async with self._root_api_session.get(f"{self._api_config.root_url}/me") as resp:
return json.loads(await resp.text()).get("identity")

async def _create_allocations(self) -> rest.payment.MarketDecoration:
if not self._budget_allocations:
async for account in self._payment_api.accounts():
driver = account.driver.lower()
network = account.network.lower()
if (driver, network) != (self._payment_driver, self._payment_network):
logger.debug(
"Not using payment platform `%s`, platform's driver/network "
"`%s`/`%s` is different than requested driver/network `%s`/`%s`",
account.platform,
driver,
network,
self._payment_driver,
self._payment_network,
platform = f"{self._payment_driver}-{self._payment_network}-{self._payment_token}"
address = await self._id()
allocation = cast(
rest.payment.Allocation,
await self._stack.enter_async_context(
self._payment_api.new_allocation(
self._budget_amount,
payment_platform=platform,
payment_address=address,
)
continue
logger.debug("Creating allocation using payment platform `%s`", account.platform)
allocation = cast(
rest.payment.Allocation,
await self._stack.enter_async_context(
self._payment_api.new_allocation(
self._budget_amount,
payment_platform=account.platform,
payment_address=account.address,
# TODO what do to with this?
# expires=self._expires + CFG_INVOICE_TIMEOUT,
)
),
)
self._budget_allocations.append(allocation)

if not self._budget_allocations:
raise NoPaymentAccountError(self._payment_driver, self._payment_network)
),
)
logger.debug("Creating allocation using payment platform `%s`", platform)
self._budget_allocations.append(allocation)

allocation_ids = [allocation.id for allocation in self._budget_allocations]
return await self._payment_api.decorate_demand(allocation_ids)
Expand Down