-
Notifications
You must be signed in to change notification settings - Fork 0
/
batch.py
149 lines (114 loc) · 3.98 KB
/
batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations
import asyncio
from abc import ABC
from abc import abstractmethod
from asyncio import Future
from asyncio import Task
from asyncio import TimerHandle
from functools import cache
from typing import Awaitable
from typing import ClassVar
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import TypeVar
Tk = TypeVar('Tk')  # type of the key a caller asks a batch to resolve
Tv = TypeVar('Tv')  # type of the value a batch produces for a key


class Batch(Generic[Tk, Tv], ABC):
    """Coalesce single-key lookups into one bulk ``resolve_futures`` call.

    Callers await ``gen(key)`` / ``genv(keys)``.  Each request parks a future
    in the requesting class's pending dict and, if none is pending, arms a
    zero-delay timer.  When the timer fires, ``schedule_batches`` starts one
    ``resolve`` task per batch class that has pending keys.

    Subclasses implement ``resolve_futures``.
    """

    # Running resolve tasks.  Holding them here keeps a reference to each
    # task until its done-callback removes it.
    tasks: ClassVar[set[Task]] = set()
    # Handle of the armed flush timer, or None when no flush is scheduled.
    # Always read and written through ``Batch`` itself, so it is shared by
    # every batch class.
    timer_handle: ClassVar[Optional[TimerHandle]] = None

    @classmethod
    @cache
    def get_futures(cls) -> dict[Tk, Future[Tv]]:
        """Return the dict of pending futures for this class.

        ``cache`` keys on ``cls``, so every class gets its own dict and the
        same dict object is returned on each call.
        """
        return {}

    @classmethod
    async def resolve(cls, futures: dict[Tk, Future[Tv]]) -> None:
        """Resolve one flushed batch and complete its futures.

        Args:
            futures: Snapshot of the pending futures to complete, by key.

        Raises:
            ValueError: If ``resolve_futures`` returns a key set different
                from the requested one.
            Exception: Whatever ``resolve_futures`` raises.

        Any failure is also delivered to every still-pending future before
        it is re-raised, so awaiters see the error instead of waiting
        forever.
        """
        batch_keys = list(futures.keys())
        print(f'>>> flush {cls.__name__}({len(batch_keys)}) {batch_keys}')
        try:
            future_results = await cls.resolve_futures(batch_keys)
            if future_results.keys() != futures.keys():
                raise ValueError('Batch resolved an incomplete set of future keys')
        except asyncio.CancelledError:
            # The resolve task itself was cancelled: release the awaiters.
            for future in futures.values():
                if not future.done():
                    future.cancel()
            raise
        except Exception as exc:
            for future in futures.values():
                if not future.done():
                    future.set_exception(exc)
            raise
        for key, result in future_results.items():
            future = futures[key]
            # A future may already be cancelled by its awaiter; calling
            # set_result on it would raise InvalidStateError.
            if not future.done():
                future.set_result(result)

    @classmethod
    def _iter_batch_types(cls) -> 'Iterable[type[Batch]]':
        """Yield every subclass of ``cls``, direct or indirect.

        With multiple inheritance a class can be yielded more than once;
        that is harmless to ``schedule_batches`` because its pending dict is
        already empty on the second visit.
        """
        for sub in cls.__subclasses__():
            yield sub
            yield from sub._iter_batch_types()

    @classmethod
    def _retire(cls, task: Task) -> None:
        """Done-callback: forget a finished resolve task."""
        cls.tasks.discard(task)
        if not task.cancelled():
            # ``resolve`` already forwarded any error to the futures;
            # retrieving it here avoids an "exception was never retrieved"
            # report for the task itself.
            task.exception()

    @staticmethod
    def schedule_batches() -> None:
        """Timer callback: start a resolve task for each non-empty batch."""
        # Disarm first so a failure below cannot leave the timer marked as
        # pending, which would stop any later flush from being scheduled.
        Batch.timer_handle = None
        loop = asyncio.get_event_loop()
        for batch in Batch._iter_batch_types():
            if futures := batch.get_futures():
                # Hand the task a snapshot, then empty the live dict so keys
                # requested from now on land in the next flush.
                task = loop.create_task(batch.resolve(futures.copy()))
                batch.tasks.add(task)
                task.add_done_callback(batch._retire)
                futures.clear()

    # Internal interface
    @staticmethod
    @abstractmethod
    async def resolve_futures(batch: Iterable[Tk]) -> Mapping[Tk, Tv]:
        """Resolve all keys of one flush in a single call.

        Args:
            batch: The keys to resolve.

        Returns:
            A mapping whose key set equals ``batch`` exactly; ``resolve``
            rejects anything else.
        """
        raise NotImplementedError

    @classmethod
    def schedule(cls, key: Tk) -> Awaitable[Tv]:
        """Register ``key`` for the next flush and return its future.

        Repeated requests for a key that is still pending share one future.
        """
        loop = asyncio.get_event_loop()
        if not Batch.timer_handle:
            # Zero delay: the flush runs as a later loop callback, so keys
            # scheduled before it fires end up in the same flush.
            Batch.timer_handle = loop.call_later(0, Batch.schedule_batches)
        futures = cls.get_futures()
        future = futures.get(key)
        # Replace a future an earlier awaiter already cancelled rather than
        # handing the dead future to a new caller.
        if future is None or future.done():
            future = futures[key] = loop.create_future()
        return future

    # External interface
    @classmethod
    async def gen(cls, key: Tk) -> Tv:
        """Return the resolved value for a single ``key``."""
        return await cls.schedule(key)

    @classmethod
    async def genv(cls, keys: Iterable[Tk]) -> Iterable[Tv]:
        """Return the resolved values for ``keys``, in the order given."""
        return await asyncio.gather(*[cls.gen(key) for key in keys])
class DoubleBatch(Batch[int, int]):
    """Demo batch that maps every key to the key added to itself."""

    @staticmethod
    async def resolve_futures(batch: Iterable[int]) -> Mapping[int, int]:
        """Return ``{key: key + key}`` for each key, after a one-second pause."""
        # NOTE(review): the sleep presumably stands in for a slow bulk
        # lookup in this demo - confirm before reusing.
        await asyncio.sleep(1)
        doubled = {}
        for key in batch:
            doubled[key] = key + key
        return doubled
class SquareBatch(Batch[int, int]):
    """Demo batch that maps every key to its square.

    The base class is parameterised as ``Batch[int, int]``, matching
    ``DoubleBatch``, so the key and value types of ``gen``/``genv`` are
    known to type checkers instead of being left unbound.
    """

    @staticmethod
    async def resolve_futures(batch: Iterable[int]) -> Mapping[int, int]:
        """Return ``{key: key * key}`` for each key, after a one-second pause."""
        # NOTE(review): the sleep presumably stands in for a slow bulk
        # lookup in this demo - confirm before reusing.
        await asyncio.sleep(1)
        return {x: x * x for x in batch}
async def double_square(x: int) -> int:
    """Double ``x`` through ``DoubleBatch``, then square that through ``SquareBatch``."""
    return await SquareBatch.gen(await DoubleBatch.gen(x))
async def square_double(x: int) -> int:
    """Square ``x`` through ``SquareBatch``, then double that through ``DoubleBatch``."""
    return await DoubleBatch.gen(await SquareBatch.gen(x))
async def triple_double(x: int) -> int:
    """Pass ``x`` through ``DoubleBatch`` three times in sequence."""
    value = x
    # Each round awaits the previous result, so the three lookups fall into
    # three separate flushes.
    for _ in range(3):
        value = await DoubleBatch.gen(value)
    return value
async def double_square_square_double(x: int) -> int:
    """Apply ``double_square`` to ``x``, then ``square_double`` to the result."""
    return await square_double(await double_square(x))
async def root():
    """Run the demo pipelines concurrently, print the results and check them."""
    jobs = [
        *(square_double(n) for n in (10, 20, 30)),
        *(double_square(n) for n in (8, 9, 10)),
        DoubleBatch.genv([1, 2, 3, 4, 5, 6]),
        SquareBatch.genv([-1, -2, -3, -4, -5, -6]),
        *(triple_double(n) for n in (100, 200, 300)),
        *(double_square_square_double(n) for n in (123, 456, 789)),
    ]
    results = await asyncio.gather(*jobs)
    print(results)

    # gather returns results in argument order, so this list mirrors `jobs`.
    expected = [
        200, 800, 1800,
        256, 324, 400,
        [2, 4, 6, 8, 10, 12],
        [1, 4, 9, 16, 25, 36],
        800, 1600, 2400,
        7324372512, 1383596163072, 12401036654112,
    ]
    assert results == expected
# Script entry point: run the demo coroutine to completion on a fresh
# event loop when this file is executed directly.
if __name__ == '__main__':
    asyncio.run(root())