Clean up calls to captured_logger (#7521)
crusaderky authored Feb 2, 2023
1 parent d74f500 commit 70abff0
Showing 8 changed files with 33 additions and 40 deletions.
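
The change is mechanical: captured_logger, the log-capturing helper in distributed.utils_test, accepts either a logging.Logger instance or a logger name, so wrapping the name in logging.getLogger() is redundant and the now-unused logging imports can be dropped. Below is a minimal sketch of how such a dual-signature helper can work; it is an illustration only, and the exact signature (the level and propagate defaults) is assumed rather than quoted from distributed/utils_test.py.

    import io
    import logging
    from contextlib import contextmanager

    @contextmanager
    def captured_logger(logger, level=logging.INFO, propagate=False):
        # Accept either a Logger or a logger name such as "distributed.scheduler";
        # resolving the name here is what lets callers drop logging.getLogger().
        if isinstance(logger, str):
            logger = logging.getLogger(logger)
        orig_level = logger.level
        orig_handlers = logger.handlers[:]
        orig_propagate = logger.propagate
        sio = io.StringIO()
        logger.handlers[:] = [logging.StreamHandler(sio)]
        logger.setLevel(level)
        logger.propagate = propagate
        try:
            yield sio  # tests call .getvalue() on this to inspect captured output
        finally:
            logger.handlers[:] = orig_handlers
            logger.setLevel(orig_level)
            logger.propagate = orig_propagate

With a helper like this, captured_logger("distributed.scheduler") and captured_logger(logging.getLogger("distributed.scheduler")) capture the same output, which is the simplification applied throughout the files below.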
4 changes: 1 addition & 3 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
@@ -1,7 +1,5 @@
from __future__ import annotations

-import logging
-
import pytest

from distributed import Scheduler, SchedulerPlugin, Worker, get_worker
@@ -234,7 +232,7 @@ async def close(self):
await s.register_scheduler_plugin(OK())
await s.register_scheduler_plugin(Bad())

-with captured_logger(logging.getLogger("distributed.scheduler")) as logger:
+with captured_logger("distributed.scheduler") as logger:
await s.close()

out, err = capsys.readouterr()
26 changes: 13 additions & 13 deletions distributed/tests/test_client.py
@@ -644,8 +644,8 @@ async def test_gather_skip(c, s, a):
x = c.submit(div, 1, 0, priority=10)
y = c.submit(slowinc, 1, delay=0.5)

-with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
-with captured_logger(logging.getLogger("distributed.client")) as client:
+with captured_logger("distributed.scheduler") as sched:
+with captured_logger("distributed.client") as client:
L = await c.gather([x, y], errors="skip")
assert L == [2]

@@ -3479,7 +3479,7 @@ async def test_get_foo_lost_keys(c, s, u, v, w):
)
async def test_bad_tasks_fail(c, s, a, b):
f = c.submit(sys.exit, 0)
-with captured_logger(logging.getLogger("distributed.scheduler")) as logger:
+with captured_logger("distributed.scheduler") as logger:
with pytest.raises(KilledWorker) as info:
await f

@@ -5015,7 +5015,7 @@ async def test_fire_and_forget_err(c, s, a, b):


def test_quiet_client_close(loop):
-with captured_logger(logging.getLogger("distributed")) as logger:
+with captured_logger("distributed") as logger:
with Client(
loop=loop,
processes=False,
@@ -5041,7 +5041,7 @@ def test_quiet_client_close(loop):

@pytest.mark.slow
def test_quiet_client_close_when_cluster_is_closed_before_client(loop):
-with captured_logger(logging.getLogger("tornado.application")) as logger:
+with captured_logger("tornado.application") as logger:
cluster = LocalCluster(loop=loop, n_workers=1, dashboard_address=":0")
client = Client(cluster, loop=loop)
cluster.close()
@@ -5584,7 +5584,7 @@ async def test_profile_keys(c, s, a, b):

assert p["count"] == xp["count"] + yp["count"]

-with captured_logger(logging.getLogger("distributed")) as logger:
+with captured_logger("distributed") as logger:
prof = await c.profile("does-not-exist")
assert prof == profile.create()
out = logger.getvalue()
@@ -5839,7 +5839,7 @@ def test_client_doesnt_close_given_loop(loop_in_thread, s, a, b):
@gen_cluster(client=True, nthreads=[])
async def test_quiet_scheduler_loss(c, s):
c._periodic_callbacks["scheduler-info"].interval = 10
-with captured_logger(logging.getLogger("distributed.client")) as logger:
+with captured_logger("distributed.client") as logger:
await s.close()
text = logger.getvalue()
assert "BrokenPipeError" not in text
@@ -6335,7 +6335,7 @@ async def test_shutdown_is_quiet_with_cluster():
async with LocalCluster(
n_workers=1, asynchronous=True, processes=False, dashboard_address=":0"
) as cluster:
-with captured_logger(logging.getLogger("distributed.client")) as logger:
+with captured_logger("distributed.client") as logger:
timeout = 0.1
async with Client(cluster, asynchronous=True, timeout=timeout) as c:
await c.shutdown()
@@ -6349,7 +6349,7 @@ async def test_client_is_quiet_cluster_close():
async with LocalCluster(
n_workers=1, asynchronous=True, processes=False, dashboard_address=":0"
) as cluster:
-with captured_logger(logging.getLogger("distributed.client")) as logger:
+with captured_logger("distributed.client") as logger:
timeout = 0.1
async with Client(cluster, asynchronous=True, timeout=timeout) as c:
await cluster.close()
@@ -6830,7 +6830,7 @@ def handler(event):
while len(s.event_subscriber["test-topic"]) != 2:
await asyncio.sleep(0.01)

-with captured_logger(logging.getLogger("distributed.client")) as logger:
+with captured_logger("distributed.client") as logger:
await c.log_event("test-topic", {})

while len(received_events) < 2:
@@ -7472,7 +7472,7 @@ def no_message():
# missing "message" key should log TypeError
get_worker().log_event("warn", {})

-with captured_logger(logging.getLogger("distributed.client")) as log:
+with captured_logger("distributed.client") as log:
await c.submit(no_message)
assert "TypeError" in log.getvalue()

@@ -7600,7 +7600,7 @@ def print_otherfile():
# this should log a TypeError in the client
get_worker().log_event("print", {"args": ("hello",), "file": "bad value"})

-with captured_logger(logging.getLogger("distributed.client")) as log:
+with captured_logger("distributed.client") as log:
await c.submit(print_otherfile)
assert "TypeError" in log.getvalue()

@@ -7610,7 +7610,7 @@ async def test_print_manual_bad_args(c, s, a, b, capsys):
def foo():
get_worker().log_event("print", {"args": "not a tuple"})

-with captured_logger(logging.getLogger("distributed.client")) as log:
+with captured_logger("distributed.client") as log:
await c.submit(foo)
assert "TypeError" in log.getvalue()

5 changes: 2 additions & 3 deletions distributed/tests/test_preload.py
@@ -1,6 +1,5 @@
from __future__ import annotations

-import logging
import os
import re
import shutil
@@ -298,8 +297,8 @@ def dask_teardown(worker):
raise Exception(456)
"""

-with captured_logger(logging.getLogger("distributed.scheduler")) as s_logger:
-with captured_logger(logging.getLogger("distributed.worker")) as w_logger:
+with captured_logger("distributed.scheduler") as s_logger:
+with captured_logger("distributed.worker") as w_logger:
async with Scheduler(dashboard_address=":0", preload=text) as s:
async with Worker(s.address, preload=[text]) as w:
pass
4 changes: 2 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -2834,7 +2834,7 @@ async def test_too_many_groups(c, s, a, b):

@gen_test()
async def test_multiple_listeners():
-with captured_logger(logging.getLogger("distributed.scheduler")) as log:
+with captured_logger("distributed.scheduler") as log:
async with Scheduler(dashboard_address=":0", protocol=["inproc", "tcp"]) as s:
async with Worker(s.listeners[0].contact_address) as a:
async with Worker(s.listeners[1].contact_address) as b:
@@ -2861,7 +2861,7 @@ async def test_worker_name_collision(s, a):
# test that a name collision for workers produces the expected response
# and leaves the data structures of Scheduler in a good state
# is not updated by the second worker
-with captured_logger(logging.getLogger("distributed.scheduler")) as log:
+with captured_logger("distributed.scheduler") as log:
with raises_with_cause(
RuntimeError, None, ValueError, f"name taken, {a.name!r}"
):
6 changes: 2 additions & 4 deletions distributed/tests/test_sizeof.py
@@ -1,7 +1,5 @@
from __future__ import annotations

-import logging
-
import pytest

from dask.sizeof import sizeof
@@ -23,13 +21,13 @@ def __sizeof__(self):
foo = BadlySized()

# Defaults to 0.95 MiB by default
-with captured_logger(logging.getLogger("distributed.sizeof")) as logs:
+with captured_logger("distributed.sizeof") as logs:
assert safe_sizeof(foo) == 1e6

assert "Sizeof calculation failed. Defaulting to 0.95 MiB" in logs.getvalue()

# Can provide custom `default_size`
-with captured_logger(logging.getLogger("distributed.sizeof")) as logs:
+with captured_logger("distributed.sizeof") as logs:
default_size = 2 * (1024**2) # 2 MiB
assert safe_sizeof(foo, default_size=default_size) == default_size

19 changes: 9 additions & 10 deletions distributed/tests/test_spill.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import array
-import logging
import os
import random
import uuid
@@ -150,7 +149,7 @@ def test_spillbuffer_maxlim(tmp_path_factory):
# size of e < target but e+c > target, this will trigger movement of c to slow
# but the max spill limit prevents it. Resulting in e remaining in fast

-with captured_logger(logging.getLogger("distributed.spill")) as logs_e:
+with captured_logger("distributed.spill") as logs_e:
buf["e"] = e

assert "disk reached capacity" in logs_e.getvalue()
@@ -159,7 +158,7 @@ def test_spillbuffer_maxlim(tmp_path_factory):
# size of d > target, d should go to slow but slow reached the max_spill limit then
# d will end up on fast with c (which can't be move to slow because it won't fit
# either)
-with captured_logger(logging.getLogger("distributed.spill")) as logs_d:
+with captured_logger("distributed.spill") as logs_d:
buf["d"] = d

assert "disk reached capacity" in logs_d.getvalue()
@@ -176,7 +175,7 @@ def test_spillbuffer_maxlim(tmp_path_factory):
unlimited_buf["a_large"] = a_large
assert psize(unlimited_buf_dir, a_large=a_large)[1] > 600

-with captured_logger(logging.getLogger("distributed.spill")) as logs_alarge:
+with captured_logger("distributed.spill") as logs_alarge:
buf["a"] = a_large

assert "disk reached capacity" in logs_alarge.getvalue()
@@ -186,7 +185,7 @@ def test_spillbuffer_maxlim(tmp_path_factory):
# max_spill

d_large = "d" * 501
-with captured_logger(logging.getLogger("distributed.spill")) as logs_dlarge:
+with captured_logger("distributed.spill") as logs_dlarge:
buf["d"] = d_large

assert "disk reached capacity" in logs_dlarge.getvalue()
@@ -216,7 +215,7 @@ def test_spillbuffer_fail_to_serialize(tmp_path):

# Exception caught in the worker
with pytest.raises(TypeError, match="Could not serialize"):
-with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key:
+with captured_logger("distributed.spill") as logs_bad_key:
buf["a"] = a

# spill.py must remain silent because we're already logging in worker.py
@@ -229,7 +228,7 @@ def test_spillbuffer_fail_to_serialize(tmp_path):
assert_buf(buf, tmp_path, {"b": b}, {})

c = "c" * 100
-with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key_mem:
+with captured_logger("distributed.spill") as logs_bad_key_mem:
# This will go to fast and try to kick b out,
# but keep b in fast since it's not pickable
buf["c"] = c
@@ -262,7 +261,7 @@ def test_spillbuffer_oserror(tmp_path):
os.chmod(tmp_path, 0o555)

# Add key > than target
-with captured_logger(logging.getLogger("distributed.spill")) as logs_oserror_slow:
+with captured_logger("distributed.spill") as logs_oserror_slow:
buf["c"] = c

assert "Spill to disk failed" in logs_oserror_slow.getvalue()
@@ -273,7 +272,7 @@ def test_spillbuffer_oserror(tmp_path):

# add key to fast which is smaller than target but when added it triggers spill,
# which triggers OSError
-with captured_logger(logging.getLogger("distributed.spill")) as logs_oserror_evict:
+with captured_logger("distributed.spill") as logs_oserror_evict:
buf["d"] = d

assert "Spill to disk failed" in logs_oserror_evict.getvalue()
@@ -298,7 +297,7 @@ def test_spillbuffer_evict(tmp_path):
assert_buf(buf, tmp_path, {"bad": bad}, {"a": a})

# unsuccessful eviction
-with captured_logger(logging.getLogger("distributed.spill")) as logs_evict_key:
+with captured_logger("distributed.spill") as logs_evict_key:
weight = buf.evict()
assert weight == -1

3 changes: 1 addition & 2 deletions distributed/tests/test_variable.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
-import logging
import random
from datetime import timedelta
from time import sleep
@@ -64,7 +63,7 @@ def foo():
async def test_delete_unset_variable(c, s, a, b):
x = Variable()
assert x.client is c
-with captured_logger(logging.getLogger("distributed.utils")) as logger:
+with captured_logger("distributed.utils") as logger:
x.delete()
await c.close()
text = logger.getvalue()
6 changes: 3 additions & 3 deletions distributed/tests/test_worker_memory.py
@@ -304,7 +304,7 @@ async def test_fail_to_pickle_spill(c, s, a):
"""
a.monitor.get_process_memory = lambda: 701 if a.data.fast else 0

-with captured_logger(logging.getLogger("distributed.spill")) as logs:
+with captured_logger("distributed.spill") as logs:
bad = c.submit(FailToPickle, key="bad")
await wait(bad)

@@ -586,7 +586,7 @@ def f(ev):
while a.state.executing_count != 1:
await asyncio.sleep(0.01)

-with captured_logger(logging.getLogger("distributed.worker.memory")) as logger:
+with captured_logger("distributed.worker.memory") as logger:
# Task that is queued on the worker when the worker pauses
y = c.submit(inc, 1, key="y")
while "y" not in a.state.tasks:
@@ -770,7 +770,7 @@ def __sizeof__(self):
return 8_100_000_000

# Capture output of log_errors()
-with captured_logger(logging.getLogger("distributed.utils")) as logger:
+with captured_logger("distributed.utils") as logger:
x = c.submit(C)
await wait(x)

