Skip to content

Commit

Permalink
Fixes in executor/__init__.py to pass unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
azawlocki committed May 18, 2021
1 parent 142d3a3 commit 8c13ca0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 38 deletions.
18 changes: 9 additions & 9 deletions tests/executor/test_payment_platforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ async def test_no_accounts_raises(monkeypatch):

monkeypatch.setattr(Payment, "accounts", _mock_accounts_iterator())

async with Executor(package=mock.Mock(), budget=10.0) as executor:
with pytest.raises(NoPaymentAccountError):
with pytest.raises(NoPaymentAccountError):
async with Executor(package=mock.Mock(), budget=10.0) as executor:
async for _ in executor.submit(worker=mock.Mock(), data=mock.Mock()):
pass

Expand All @@ -90,15 +90,15 @@ async def test_no_matching_account_raises(monkeypatch):
),
)

async with Executor(
package=mock.Mock(), budget=10.0, driver="matching-driver", network="matching-network"
) as executor:
with pytest.raises(NoPaymentAccountError) as exc_info:
with pytest.raises(NoPaymentAccountError) as exc_info:
async with Executor(
package=mock.Mock(), budget=10.0, driver="matching-driver", network="matching-network"
) as executor:
async for _ in executor.submit(worker=mock.Mock(), data=mock.Mock()):
pass
exc = exc_info.value
assert exc.required_driver == "matching-driver"
assert exc.required_network == "matching-network"
exc = exc_info.value
assert exc.required_driver == "matching-driver"
assert exc.required_network == "matching-network"


@pytest.mark.asyncio
Expand Down
10 changes: 5 additions & 5 deletions tests/executor/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from unittest.mock import Mock

from yapapi.executor import Executor
from yapapi.executor import Golem
from yapapi.executor.strategy import (
DecreaseScoreForUnconfirmedAgreement,
LeastExpensiveLinearPayuMS,
Expand Down Expand Up @@ -130,8 +130,8 @@ async def test_default_strategy_type(monkeypatch):

monkeypatch.setattr(yapapi.rest, "Configuration", Mock)

executor = Executor(package=Mock(), budget=1.0)
default_strategy = executor.strategy
golem = Golem(budget=1.0)
default_strategy = golem.strategy
assert isinstance(default_strategy, DecreaseScoreForUnconfirmedAgreement)
assert isinstance(default_strategy.base_strategy, LeastExpensiveLinearPayuMS)

Expand All @@ -143,8 +143,8 @@ async def test_user_strategy_not_modified(monkeypatch):
monkeypatch.setattr(yapapi.rest, "Configuration", Mock)

user_strategy = Mock()
executor = Executor(package=Mock(), budget=1.0, strategy=user_strategy)
assert executor.strategy == user_strategy
golem = Golem(budget=1.0, strategy=user_strategy)
assert golem.strategy == user_strategy


class TestLeastExpensiveLinearPayuMS:
Expand Down
63 changes: 39 additions & 24 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ def __init__(

# Add buffering to the provided event emitter to make sure
# that emitting events will not block
self._wrapped_consumer = AsyncWrapper(event_consumer)
# TODO: make AsyncWrapper an AsyncContextManager and start it in
# in __aenter__(); if it's started here then there's no guarantee that
# it will be cancelled properly
self._wrapped_consumer: Optional[AsyncWrapper] = None
self._event_consumer = event_consumer

self._stream_output = stream_output

Expand All @@ -185,8 +189,12 @@ def __init__(
self._invoices: Dict[str, rest.payment.Invoice] = dict()
self._payment_closing: bool = False

self._services: Set[asyncio.Task] = set()
# a set of `Job` instances used to track jobs - computations or services - started
# it can be used to wait until all jobs are finished
self._jobs: Set[Job] = set()
self._process_invoices_job: Optional[asyncio.Task] = None

self._services: Set[asyncio.Task] = set()
self._stack = AsyncExitStack()

async def create_demand_builder(
Expand Down Expand Up @@ -225,34 +233,39 @@ def strategy(self) -> MarketStrategy:
return self._strategy

def emit(self, *args, **kwargs) -> None:
self._wrapped_consumer.async_call(*args, **kwargs)
if self._wrapped_consumer:
self._wrapped_consumer.async_call(*args, **kwargs)

async def __aenter__(self) -> "Golem":
stack = self._stack
try:
stack = self._stack

market_client = await stack.enter_async_context(self._api_config.market())
self._market_api = rest.Market(market_client)
self._wrapped_consumer = AsyncWrapper(self._event_consumer)

activity_client = await stack.enter_async_context(self._api_config.activity())
self._activity_api = rest.Activity(activity_client)
market_client = await stack.enter_async_context(self._api_config.market())
self._market_api = rest.Market(market_client)

payment_client = await stack.enter_async_context(self._api_config.payment())
self._payment_api = rest.Payment(payment_client)
activity_client = await stack.enter_async_context(self._api_config.activity())
self._activity_api = rest.Activity(activity_client)

# a set of `Job` instances used to track jobs - computations or services - started
# it can be used to wait until all jobs are finished
self._jobs: Set[Job] = set()
payment_client = await stack.enter_async_context(self._api_config.payment())
self._payment_api = rest.Payment(payment_client)

self.payment_decoration = Golem.PaymentDecoration(await self._create_allocations())
self.payment_decoration = Golem.PaymentDecoration(await self._create_allocations())

loop = asyncio.get_event_loop()
self._process_invoices_job = loop.create_task(self.process_invoices())
self._services.add(self._process_invoices_job)
self._services.add(loop.create_task(self.process_debit_notes()))
# TODO: make the method starting the process_invoices() task an async context manager
# to simplify code in __aexit__()
loop = asyncio.get_event_loop()
self._process_invoices_job = loop.create_task(self.process_invoices())
self._services.add(self._process_invoices_job)
self._services.add(loop.create_task(self.process_debit_notes()))

self._storage_manager = await self._stack.enter_async_context(gftp.provider())
self._storage_manager = await self._stack.enter_async_context(gftp.provider())

return self
return self
except:
await self.__aexit__(*sys.exc_info())
raise

async def __aexit__(self, exc_type, exc_val, exc_tb):
# Importing this at the beginning would cause circular dependencies
Expand All @@ -268,7 +281,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
if task is not self._process_invoices_job:
task.cancel()

if not any(True for job in self._jobs if job.agreements_pool.confirmed > 0):
if self._process_invoices_job and not any(
True for job in self._jobs if job.agreements_pool.confirmed > 0
):
logger.debug("No need to wait for invoices.")
self._process_invoices_job.cancel()

Expand All @@ -283,7 +298,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
# TODO: add message
logger.debug("TODO", exc_info=True)

if self._agreements_to_pay:
if self._agreements_to_pay and self._process_invoices_job:
logger.info(
"%s still unpaid, waiting for invoices...",
pluralize(len(self._agreements_to_pay), "agreement"),
Expand All @@ -301,7 +316,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
except Exception:
self.emit(events.ShutdownFinished(exc_info=sys.exc_info()))
finally:
await self._wrapped_consumer.stop()
if self._wrapped_consumer:
await self._wrapped_consumer.stop()

async def _create_allocations(self) -> rest.payment.MarketDecoration:

Expand Down Expand Up @@ -868,7 +884,6 @@ def network(self) -> str:
async def __aenter__(self) -> "Executor":
if self.__standalone:
await self._stack.enter_async_context(self._engine)

self._expires = datetime.now(timezone.utc) + self._timeout
return self

Expand Down

0 comments on commit 8c13ca0

Please sign in to comment.