-
Notifications
You must be signed in to change notification settings - Fork 22
/
test_resubscription.py
191 lines (145 loc) · 6.27 KB
/
test_resubscription.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Test if subscription expiration is handled correctly by Executor"""
from datetime import timedelta
import logging
import os
from pathlib import Path
import time
from typing import Dict, Set, Type
from unittest.mock import Mock
import colors
import pytest
from goth.assertions import EventStream
from goth.assertions.monitor import EventMonitor
from goth.assertions.operators import eventually
from goth.configuration import load_yaml
from goth.runner import Runner
from goth.runner.log import configure_logging
from goth.runner.probe import RequestorProbe
from yapapi import Executor, Task
from yapapi.executor.events import (
Event,
ComputationStarted,
ComputationFinished,
SubscriptionCreated,
)
import yapapi.rest.market
from yapapi.log import enable_default_logger
from yapapi.package import vm
import ya_market.api.requestor_api
from ya_market import ApiException
# Module-wide logger, registered under the "goth.test" name.
logger = logging.getLogger(name="goth.test")

# Number of seconds after which a subscription is treated as expired by the
# patched `RequestorApi.collect_offers()` defined in this module.
SUBSCRIPTION_EXPIRATION_TIME = 5
class RequestorApi(ya_market.api.requestor_api.RequestorApi):
    """A replacement for market API that simulates early subscription expiration.

    A call to `collect_offers(sub_id)` will raise `ApiException` (HTTP 404) indicating
    subscription expiration once at least `SUBSCRIPTION_EXPIRATION_TIME` seconds have
    elapsed after the given subscription has been created through this instance.
    Subscriptions this instance did not create are passed straight to the real API.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # subscription ID -> `time.monotonic()` reading taken when it was created
        self.subscriptions: Dict[str, float] = {}

    def subscribe_demand(self, demand, **kwargs):
        """Override `RequestorApi.subscribe_demand()` to register subscription create time.

        Returns a coroutine that awaits the base method's result, records the creation
        time of the returned subscription ID, and resolves to that ID.
        """
        sub_id_coro = super().subscribe_demand(demand, **kwargs)

        async def coro():
            sub_id = await sub_id_coro
            # Monotonic clock: elapsed-time checks are immune to wall-clock adjustments.
            self.subscriptions[sub_id] = time.monotonic()
            return sub_id

        return coro()

    def collect_offers(self, subscription_id, **kwargs):
        """Override `RequestorApi.collect_offers()`.

        Return a coroutine raising `ApiException(404)` if at least
        `SUBSCRIPTION_EXPIRATION_TIME` seconds elapsed since the subscription identified
        by `subscription_id` has been created; otherwise delegate to the base method.
        Unknown subscription IDs (not created through this instance) are delegated too,
        instead of failing with a `KeyError`.
        """
        created_at = self.subscriptions.get(subscription_id)
        if created_at is None or time.monotonic() - created_at < SUBSCRIPTION_EXPIRATION_TIME:
            return super().collect_offers(subscription_id, **kwargs)

        logger.info("Subscription expired")

        async def coro():
            # Mimic the market's "subscription expired" error response.
            raise ApiException(
                http_resp=Mock(
                    status=404,
                    reason="Not Found",
                    data=f"{{'message': 'Subscription [{subscription_id}] expired.'}}",
                )
            )

        return coro()
@pytest.fixture(autouse=True)
def patch_collect_offers(monkeypatch):
    """Swap yapapi's market `RequestorApi` for the expiring `RequestorApi` of this module.

    Being `autouse`, this applies to every test in the module; the swap is done through
    `monkeypatch`, which restores the original attribute when the test ends.
    """
    market_module = yapapi.rest.market
    monkeypatch.setattr(market_module, "RequestorApi", RequestorApi)
async def unsubscribe_demand(sub_id: str) -> None:
    """Auxiliary function that calls `unsubscribeDemand` operation for given `sub_id`.

    A market API client is built from the default `yapapi.rest.Configuration` and used
    as an async context manager, so it is closed once the call completes (or fails).
    """
    config = yapapi.rest.Configuration()
    async with config.market() as market_client:
        requestor_api = yapapi.rest.market.RequestorApi(market_client)
        await requestor_api.unsubscribe_demand(sub_id)
async def assert_demand_resubscribed(events: "EventStream[Event]"):
    """A temporal assertion that the requestor will have to satisfy.

    After `ComputationStarted`, expects `SubscriptionCreated` events carrying at least
    three distinct subscription IDs (presumably the initial subscription plus
    resubscriptions forced by the simulated expiration). Then it explicitly
    unsubscribes the latest demand and expects another subscription with a new ID.
    Finally it waits for `ComputationFinished`.
    """
    subscription_ids: Set[str] = set()

    async def wait_for_event(event_type: Type[Event], timeout: float):
        """Wait up to `timeout` seconds for an event of `event_type`; fail on timeout."""
        e = await eventually(events, lambda ev: isinstance(ev, event_type), timeout)
        assert e, f"Timed out waiting for {event_type}"
        logger.info(colors.cyan(str(e)))
        return e

    await wait_for_event(ComputationStarted, 10)

    # Make sure new subscriptions are created at least three times
    while len(subscription_ids) < 3:
        e = await wait_for_event(SubscriptionCreated, SUBSCRIPTION_EXPIRATION_TIME + 10)
        assert e.sub_id not in subscription_ids
        subscription_ids.add(e.sub_id)

    # Unsubscribe and make sure a *new* subscription is created
    await unsubscribe_demand(e.sub_id)
    logger.info("Demand unsubscribed")
    e = await wait_for_event(SubscriptionCreated, 5)
    assert e.sub_id not in subscription_ids, f"Subscription {e.sub_id} was reused"

    # Enough checking, wait until the computation finishes
    await wait_for_event(ComputationFinished, 20)
@pytest.mark.asyncio
async def test_demand_resubscription(log_dir: Path, monkeypatch) -> None:
    """Test that checks that a demand is re-submitted after its previous subscription expires."""

    configure_logging(log_dir)

    # Override the default test configuration to create only one provider node
    nodes = [
        {"name": "requestor", "type": "Requestor"},
        {"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True},
    ]
    goth_config = load_yaml(
        Path(__file__).parent / "assets" / "goth-config.yml", [("nodes", nodes)]
    )

    vm_package = await vm.repo(
        image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae",
        min_mem_gib=0.5,
        min_storage_gib=2.0,
    )

    runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config)

    async with runner(goth_config.containers):

        requestor = runner.get_probes(probe_type=RequestorProbe)[0]
        env = {**os.environ}
        requestor.set_agent_env_vars(env)

        # Setup the environment for the requestor
        for key, val in env.items():
            monkeypatch.setenv(key, val)

        monitor = EventMonitor()
        monitor.add_assertion(assert_demand_resubscribed)
        monitor.start()

        # The requestor

        enable_default_logger()

        async def worker(work_ctx, tasks):
            async for task in tasks:
                work_ctx.run("/bin/sleep", "5")
                yield work_ctx.commit()
                task.accept_result()

        try:
            async with Executor(
                budget=10.0,
                package=vm_package,
                max_workers=1,
                timeout=timedelta(seconds=30),
                event_consumer=monitor.add_event_sync,
            ) as executor:

                task: Task  # mypy needs this for some reason
                async for task in executor.submit(worker, [Task(data=n) for n in range(20)]):
                    logger.info("Task %d computed", task.data)
        finally:
            # Always shut the started monitor down, even if the executor raised.
            await monitor.stop()

        # Fail the test if any temporal assertion registered on the monitor failed.
        for a in monitor.failed:
            raise a.result()