-
Notifications
You must be signed in to change notification settings - Fork 2
/
conftest.py
117 lines (83 loc) · 2.58 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import queue
import typing
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from threading import Thread
from unittest.mock import patch
import pytest
from auth import auth_backend
from celeryapp import app
from daras_ai_v2.base import BasePage
from daras_ai_v2.send_email import pytest_outbox
def flaky(fn=None, *, max_tries=5):
    """Decorator that retries a function when it raises.

    The wrapped callable is invoked up to ``max_tries`` times. The first
    successful call's return value is returned immediately. Exceptions from
    every attempt except the last are discarded; the exception raised by the
    final attempt propagates to the caller. Only ``Exception`` subclasses are
    retried, so ``KeyboardInterrupt`` / ``SystemExit`` pass straight through.

    Usable both bare and parameterized::

        @flaky
        def test_a(): ...

        @flaky(max_tries=3)
        def test_b(): ...

    Args:
        fn: The callable to wrap. Omitted when the decorator is used in its
            parameterized form.
        max_tries: Maximum number of attempts (keyword-only, default 5).

    Returns:
        The wrapped callable, or a decorator when ``fn`` is omitted.

    Raises:
        ValueError: If ``max_tries`` is smaller than 1.
    """
    if max_tries < 1:
        # With zero attempts the wrapper would silently return None without
        # ever calling the function, which would hide a misconfiguration.
        raise ValueError("max_tries must be at least 1")

    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_tries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    # Swallow intermediate failures; surface only the last.
                    if attempt == max_tries:
                        raise

        return wrapper

    if fn is None:
        return decorate
    return decorate(fn)
@pytest.fixture
def db_fixtures(transactional_db):
    """Load ``fixture.json`` into the test database via Django's ``loaddata``.

    Requests the ``transactional_db`` fixture (presumably the one provided by
    pytest-django -- confirm) so that database access is enabled before the
    data is loaded. The fixture itself provides no value to the test.
    """
    # Imported lazily, as in the original, so Django is only touched when a
    # test actually requests this fixture.
    from django.core import management

    print("Loading fixtures from fixture.json")
    management.call_command("loaddata", "fixture.json")
@pytest.fixture
def force_authentication():
    """Run the test inside ``auth_backend.force_authentication()``.

    Yields whatever that context manager produces (bound as the forced user)
    and leaves the context when the test finishes.
    """
    with auth_backend.force_authentication() as forced_user:
        yield forced_user
# Celery setting: with task_always_eager enabled, tasks are executed inline by
# the caller instead of being dispatched to a worker (per the Celery
# configuration docs) -- presumably so tests need no running broker/worker.
app.conf.task_always_eager = True

# In-memory stand-in for realtime channels: one queue.Queue per channel name,
# created lazily on first access. _mock_realtime_push() writes to these queues
# and _mock_realtime_subscribe() reads from them.
redis_qs = defaultdict(queue.Queue)
@pytest.fixture
def mock_celery_tasks():
    """Swap the runner tasks and realtime subscription for in-file mocks.

    While the test runs, ``celeryapp.tasks.runner_task``,
    ``celeryapp.tasks.post_runner_tasks`` and ``gooey_gui.realtime_subscribe``
    are patched with ``_mock_runner_task``, ``_mock_post_runner_tasks`` and
    ``_mock_realtime_subscribe`` respectively. The patches are undone when the
    test finishes.
    """
    runner_patch = patch("celeryapp.tasks.runner_task", _mock_runner_task)
    post_runner_patch = patch(
        "celeryapp.tasks.post_runner_tasks", _mock_post_runner_tasks
    )
    subscribe_patch = patch(
        "gooey_gui.realtime_subscribe", _mock_realtime_subscribe
    )

    # Entered in the same order as before; exited in reverse on teardown.
    with runner_patch, post_runner_patch, subscribe_patch:
        yield
@app.task
def _mock_runner_task(
    *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs
):
    """Test replacement for the runner task: copy the parent's state and publish it.

    Looks up the run object for ``(run_id, uid)`` through
    ``page_cls.get_sr_from_ids`` (presumably a saved-run record -- confirm
    against ``BasePage``), overwrites its state with ``parent.to_dict()``,
    saves it, and pushes the resulting ``to_dict()`` onto the page's realtime
    channel via ``_mock_realtime_push``.

    Any additional keyword arguments are accepted and ignored.
    """
    run = page_cls.get_sr_from_ids(run_id, uid)

    parent_state = run.parent.to_dict()
    run.set(parent_state)
    run.save()

    _mock_realtime_push(
        page_cls.realtime_channel_name(run_id, uid),
        run.to_dict(),
    )
@app.task
def _mock_post_runner_tasks(*args, **kwargs):
    """No-op stand-in for ``post_runner_tasks``; accepts and ignores all arguments."""
    return None
def _mock_realtime_push(channel, value):
    """Enqueue ``value`` on the in-memory queue for ``channel``.

    The queue is created on first use because ``redis_qs`` is a
    ``defaultdict`` of ``queue.Queue``.
    """
    channel_queue = redis_qs[channel]
    channel_queue.put(value)
@contextmanager
def _mock_realtime_subscribe(channel: str):
    """Context manager yielding an endless iterator over ``channel``'s messages.

    Each step of the iterator performs a blocking ``get()`` on the in-memory
    queue for ``channel`` in ``redis_qs``, so it waits until
    ``_mock_realtime_push`` supplies a value. The iterator never terminates on
    its own.
    """

    def messages():
        # The queue is looked up on every iteration, exactly as before, so it
        # is created lazily on the first read if nothing was pushed yet.
        while True:
            next_value = redis_qs[channel].get()
            yield next_value

    yield messages()
@pytest.fixture
def threadpool_subtest(subtests, max_workers: int = 128):
    """Provide a ``submit`` function that runs callables as threaded subtests.

    The test body calls ``submit(fn, *args, msg=None, **kwargs)`` any number
    of times. Nothing runs at submit time: each call only records a thread.
    After the test body finishes, the recorded threads are started and joined
    in batches of at most ``max_workers``.

    Each callable runs inside ``subtests.test(msg=...)`` (``subtests`` is
    presumably the pytest-subtests fixture -- confirm). When ``msg`` is not
    given, it is built by joining the string forms of the positional and
    keyword argument values with ``"--"``.
    """
    pending = []

    def submit(fn, *args, msg=None, **kwargs):
        label = msg or "--".join(str(v) for v in (*args, *kwargs.values()))

        @wraps(fn)
        def runner(*call_args, **call_kwargs):
            with subtests.test(msg=label):
                return fn(*call_args, **call_kwargs)

        pending.append(Thread(target=runner, args=args, kwargs=kwargs))

    yield submit

    # Teardown: run the queued work one batch at a time so that no more than
    # max_workers threads are alive at once.
    for start in range(0, len(pending), max_workers):
        batch = pending[start : start + max_workers]
        for worker in batch:
            worker.start()
        for worker in batch:
            worker.join()
@pytest.fixture(autouse=True)
def clear_pytest_outbox():
    """Empty ``pytest_outbox`` before every test (autouse).

    The outbox is cleared during fixture setup, so each test starts with no
    entries left over from an earlier test.
    """
    pytest_outbox.clear()