Skip to content

Commit

Permalink
fixed allocation creation (#1134)
Browse files Browse the repository at this point in the history
* fixed allocation creation

* identity

* unit test fixes

---------

Co-authored-by: shadeofblue <[email protected]>
Co-authored-by: shadeofblue <[email protected]>
  • Loading branch information
3 people authored Jun 7, 2023
1 parent 37a9967 commit a63cce3
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 152 deletions.
164 changes: 43 additions & 121 deletions tests/test_payment_platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,167 +4,89 @@

from ya_payment import RequestorApi

from yapapi import NoPaymentAccountError
from yapapi.engine import DEFAULT_DRIVER, DEFAULT_NETWORK
from yapapi.golem import Golem
from yapapi.rest.payment import Account, Payment
from yapapi.engine import (
DEFAULT_DRIVER,
DEFAULT_NETWORK,
MAINNET_TOKEN_NAME,
TESTNET_TOKEN_NAME,
)
from yapapi.golem import Golem, _Engine


@pytest.fixture(autouse=True)
def _set_app_key(monkeypatch):
monkeypatch.setenv("YAGNA_APPKEY", "mock-appkey")


def _mock_accounts_iterator(*account_specs):
"""Create an iterator over mock `Account` objects.
`account_specs` should contain pairs `(driver, network)`, where `driver` and `network`
are strings, or triples `(driver, network, params)` with `driver` and `network` as before
and `params` a dictionary containing additional keyword arguments for `Account()`.
"""

async def _mock(*_args):
for spec in account_specs:
params = {
"platform": "mock-platform",
"address": "mock-address",
"driver": spec[0],
"network": spec[1],
"token": "mock-token",
"send": True,
"receive": True,
}
if len(spec) == 3:
params.update(**spec[2])
yield Account(**params)

return _mock


class _StopExecutor(Exception):
"""An exception raised to stop the test when reaching an expected checkpoint in executor."""


@pytest.fixture()
def _mock_decorate_demand(monkeypatch):
"""Make `Payment.decorate_demand()` stop the test."""
def _mock_engine_id(monkeypatch):
"""Mock Engine `id`."""

async def _id(_):
return

monkeypatch.setattr(
Payment,
"decorate_demand",
mock.Mock(side_effect=_StopExecutor("decorate_demand() called")),
_Engine,
"_id",
_id,
)


@pytest.fixture()
def _mock_create_allocation(monkeypatch):
"""Make `RequestorApi.create_allocation()` stop the test."""
monkeypatch.setattr(
RequestorApi,
"create_allocation",
mock.Mock(side_effect=_StopExecutor("create_allocation() called")),
)

create_allocation_mock = mock.Mock(side_effect=_StopExecutor("create_allocation() called"))

@pytest.mark.asyncio
async def test_no_accounts_raises(monkeypatch):
"""Test that exception is raised if `Payment.accounts()` returns empty list."""
monkeypatch.setattr(RequestorApi, "create_allocation", create_allocation_mock)

monkeypatch.setattr(Payment, "accounts", _mock_accounts_iterator())

with pytest.raises(NoPaymentAccountError):
async with Golem(budget=10.0):
pass
return create_allocation_mock


@pytest.mark.asyncio
async def test_no_matching_account_raises(monkeypatch):
"""Test that exception is raised if `Payment.accounts()` returns no matching accounts."""
async def test_default(_mock_engine_id, _mock_create_allocation):
"""Test the allocation defaults."""

monkeypatch.setattr(
Payment,
"accounts",
_mock_accounts_iterator(
("other-driver", "other-network"),
("matching-driver", "other-network"),
("other-driver", "matching-network"),
),
)

with pytest.raises(NoPaymentAccountError) as exc_info:
async with Golem(
budget=10.0, payment_driver="matching-driver", payment_network="matching-network"
):
with pytest.raises(_StopExecutor):
async with Golem(budget=10.0):
pass

exc = exc_info.value
assert exc.required_driver == "matching-driver"
assert exc.required_network == "matching-network"


@pytest.mark.asyncio
async def test_matching_account_creates_allocation(monkeypatch, _mock_decorate_demand):
"""Test that matching accounts are correctly selected and allocations are created for them."""

monkeypatch.setattr(
Payment,
"accounts",
_mock_accounts_iterator(
("other-driver", "other-network"),
("matching-driver", "matching-network", {"platform": "platform-1"}),
("matching-driver", "other-network"),
("other-driver", "matching-network"),
("matching-driver", "matching-network", {"platform": "platform-2"}),
),
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"{DEFAULT_DRIVER}-{DEFAULT_NETWORK}-{TESTNET_TOKEN_NAME}"
)

create_allocation_args = []

async def mock_create_allocation(_self, model):
create_allocation_args.append(model)
return mock.Mock()

async def mock_release_allocation(*args, **kwargs):
pass

monkeypatch.setattr(RequestorApi, "create_allocation", mock_create_allocation)
monkeypatch.setattr(RequestorApi, "release_allocation", mock_release_allocation)
@pytest.mark.asyncio
async def test_mainnet(_mock_engine_id, _mock_create_allocation):
"""Test the allocation for a mainnet account."""

with pytest.raises(_StopExecutor):
async with Golem(
budget=10.0,
payment_driver="matching-driver",
payment_network="matching-network",
):
async with Golem(budget=10.0, payment_driver="somedriver", payment_network="mainnet"):
pass

assert len(create_allocation_args) == 2
assert create_allocation_args[0].payment_platform == "platform-1"
assert create_allocation_args[1].payment_platform == "platform-2"
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"somedriver-mainnet-{MAINNET_TOKEN_NAME}"
)


@pytest.mark.asyncio
async def test_driver_network_case_insensitive(monkeypatch, _mock_create_allocation):
"""Test that matching driver and network names is not case sensitive."""

monkeypatch.setattr(Payment, "accounts", _mock_accounts_iterator(("dRIVER", "NetWORK")))
async def test_testnet(_mock_engine_id, _mock_create_allocation):
"""Test the allocation for a mainnet account."""

with pytest.raises(_StopExecutor):
async with Golem(
budget=10.0,
payment_driver="dRiVeR",
payment_network="NeTwOrK",
):
async with Golem(budget=10.0, payment_driver="somedriver", payment_network="othernet"):
pass


@pytest.mark.asyncio
async def test_default_driver_network(monkeypatch, _mock_create_allocation):
"""Test that defaults are used if driver and network are not specified."""

monkeypatch.setattr(
Payment, "accounts", _mock_accounts_iterator((DEFAULT_DRIVER, DEFAULT_NETWORK))
assert _mock_create_allocation.called
assert (
_mock_create_allocation.mock_calls[0][1][0].payment_platform
== f"somedriver-othernet-{TESTNET_TOKEN_NAME}"
)

with pytest.raises(_StopExecutor):
async with Golem(budget=10.0):
pass
61 changes: 30 additions & 31 deletions yapapi/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime, timezone
from decimal import Decimal
import itertools
import json
import logging
import os
import sys
Expand Down Expand Up @@ -57,6 +58,10 @@

MAX_CONCURRENTLY_PROCESSED_DEBIT_NOTES: Final[int] = 10

MAINNET_NETWORKS: Set[str] = {"mainnet", "polygon"}
MAINNET_TOKEN_NAME: str = "glm"
TESTNET_TOKEN_NAME: str = "tglm"

logger = logging.getLogger("yapapi.executor")


Expand Down Expand Up @@ -103,6 +108,7 @@ def __init__(
subnet_tag: Optional[str] = None,
payment_driver: Optional[str] = None,
payment_network: Optional[str] = None,
payment_token: Optional[str] = None,
stream_output: bool = False,
app_key: Optional[str] = None,
):
Expand Down Expand Up @@ -135,6 +141,13 @@ def __init__(
self._subnet: Optional[str] = subnet_tag or DEFAULT_SUBNET
self._payment_driver: str = payment_driver.lower() if payment_driver else DEFAULT_DRIVER
self._payment_network: str = payment_network.lower() if payment_network else DEFAULT_NETWORK
self._payment_token: str = (
payment_token.lower()
if payment_token
else MAINNET_TOKEN_NAME
if self._payment_network in MAINNET_NETWORKS
else TESTNET_TOKEN_NAME
)
self._stream_output = stream_output

# a set of `Job` instances used to track jobs - computations or services - started
Expand Down Expand Up @@ -332,40 +345,26 @@ async def _shutdown(self, *exc_info):
except Exception:
logger.debug("Got error when waiting for services to finish", exc_info=True)

async def _create_allocations(self) -> rest.payment.MarketDecoration:
async def _id(self) -> str:
async with self._root_api_session.get(f"{self._api_config.root_url}/me") as resp:
return json.loads(await resp.text()).get("identity")

async def _create_allocations(self) -> rest.payment.MarketDecoration:
if not self._budget_allocations:
async for account in self._payment_api.accounts():
driver = account.driver.lower()
network = account.network.lower()
if (driver, network) != (self._payment_driver, self._payment_network):
logger.debug(
"Not using payment platform `%s`, platform's driver/network "
"`%s`/`%s` is different than requested driver/network `%s`/`%s`",
account.platform,
driver,
network,
self._payment_driver,
self._payment_network,
platform = f"{self._payment_driver}-{self._payment_network}-{self._payment_token}"
address = await self._id()
allocation = cast(
rest.payment.Allocation,
await self._stack.enter_async_context(
self._payment_api.new_allocation(
self._budget_amount,
payment_platform=platform,
payment_address=address,
)
continue
logger.debug("Creating allocation using payment platform `%s`", account.platform)
allocation = cast(
rest.payment.Allocation,
await self._stack.enter_async_context(
self._payment_api.new_allocation(
self._budget_amount,
payment_platform=account.platform,
payment_address=account.address,
# TODO what do to with this?
# expires=self._expires + CFG_INVOICE_TIMEOUT,
)
),
)
self._budget_allocations.append(allocation)

if not self._budget_allocations:
raise NoPaymentAccountError(self._payment_driver, self._payment_network)
),
)
logger.debug("Creating allocation using payment platform `%s`", platform)
self._budget_allocations.append(allocation)

allocation_ids = [allocation.id for allocation in self._budget_allocations]
return await self._payment_api.decorate_demand(allocation_ids)
Expand Down

0 comments on commit a63cce3

Please sign in to comment.