Skip to content

Commit

Permalink
O(1) access for /info/task/ endpoint (#8363)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Nov 20, 2023
1 parent 1b3ec18 commit 3460b4f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 32 deletions.
56 changes: 38 additions & 18 deletions distributed/http/scheduler/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import logging
import os
import os.path
from collections.abc import Hashable
from datetime import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from tlz import first, merge
from tornado import escape
from tornado.websocket import WebSocketHandler

from dask.typing import Key
from dask.utils import format_bytes, format_time

from distributed.diagnostics.websocket import WebsocketPlugin
Expand Down Expand Up @@ -90,28 +90,49 @@ def get(self):
)


def _get_actual_scheduler_key(key: str, scheduler: Scheduler) -> Hashable:
for k in scheduler.tasks:
if str(k) == key:
return k
def _get_actual_scheduler_key(key: str, scheduler: Scheduler) -> Key:
key = escape.url_unescape(key, plus=False)

if key in scheduler.tasks:
return key # Basic str key

# Tuple, bytes, or int key
# First try safely reverting str(Key)
def lists_to_tuples(o: object) -> Any:
if isinstance(o, list):
return tuple(lists_to_tuples(i) for i in o)
else:
return o

key2 = key.replace("(", "[").replace(")", "]").replace("'", '"')
try:
key2 = lists_to_tuples(json.loads(key2))
if key2 in scheduler.tasks:
return key2
except json.JSONDecodeError:
pass

# Edge case of keys with string elements containing [ ] ( ) ' " or bytes
for key3 in scheduler.tasks:
if str(key3) == key:
return key3

raise KeyError(key)


class Task(RequestHandler):
@log_errors
def get(self, task):
task = escape.url_unescape(task)

def get(self, task: str) -> None:
try:
requested_key = _get_actual_scheduler_key(task, self.server)
key = _get_actual_scheduler_key(task, self.server)
except KeyError:
self.send_error(404)
return

self.render(
"task.html",
title="Task: " + task,
Task=requested_key,
title=f"Task: {key!r}",
Task=key,
scheduler=self.server,
**merge(
self.server.__dict__,
Expand Down Expand Up @@ -175,15 +196,14 @@ async def get(self, worker):

class TaskCallStack(RequestHandler):
@log_errors
async def get(self, key):
key = escape.url_unescape(key)

async def get(self, task: str) -> None:
try:
requested_key = _get_actual_scheduler_key(key, self.server)
key = _get_actual_scheduler_key(task, self.server)
except KeyError:
self.send_error(404)
return
call_stack = await self.server.get_call_stack(keys=[requested_key])

call_stack = await self.server.get_call_stack(keys=[key])
if not call_stack:
self.write(
"<p>Task not actively running. "
Expand All @@ -192,7 +212,7 @@ async def get(self, key):
else:
self.render(
"call-stack.html",
title="Call Stack: " + key,
title=f"Call Stack: {key!r}",
call_stack=call_stack,
**merge(self.extra, rel_path_statics),
)
Expand Down
78 changes: 67 additions & 11 deletions distributed/http/scheduler/tests/test_scheduler_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,28 +340,84 @@ async def test_sitemap(s, a, b):
assert "/statics/css/base.css" in out["paths"]


KEY_EDGE_CASES = [
1,
"1",
"x",
"a+b",
"a b",
("x", 1),
("a+b", 1),
("a b", 1),
((1, 2), ("x", "y")),
"(",
"[",
"'",
'"',
b"123",
(b"123", 1),
("[", 1),
("(", 1),
("'", 1),
('"', 1),
]


@pytest.mark.parametrize("key", KEY_EDGE_CASES)
@gen_cluster(client=True)
async def test_task_page(c, s, a, b):
future = c.submit(lambda x: x + 1, 1, workers=a.address)
x = c.submit(inc, 1)
await future
async def test_task_page(c, s, a, b, key):
http_client = AsyncHTTPClient()
skey = url_escape(str(key), plus=False)
url = f"http://localhost:{s.http_server.port}/info/task/{skey}.html"

"info/task/" + url_escape(future.key) + ".html",
response = await http_client.fetch(
"http://localhost:%d/info/task/" % s.http_server.port
+ url_escape(future.key)
+ ".html"
)
response = await http_client.fetch(url, raise_error=False)
assert response.code == 404

future = c.submit(lambda: 1, key=key, workers=a.address)
await future
response = await http_client.fetch(url)
assert response.code == 200
body = response.body.decode()

assert a.address in body
assert str(sizeof(1)) in body
assert "int" in body
assert a.address in body
assert "memory" in body


@pytest.mark.parametrize("key", KEY_EDGE_CASES)
@gen_cluster(client=True)
async def test_call_stack_page(c, s, a, b, key):
http_client = AsyncHTTPClient()
skey = url_escape(str(key), plus=False)
url = f"http://localhost:{s.http_server.port}/info/call-stack/{skey}.html"

response = await http_client.fetch(url, raise_error=False)
assert response.code == 404

ev1 = Event()
ev2 = Event()

def f(ev1, ev2):
ev1.set()
ev2.wait()

future = c.submit(f, ev1, ev2, key=key)
await ev1.wait()

response = await http_client.fetch(url)
assert response.code == 200
body = response.body.decode()
assert "test_scheduler_http.py" in body

await ev2.set()
await future
response = await http_client.fetch(url)
assert response.code == 200
body = response.body.decode()
assert "Task not actively running" in body


@gen_cluster(
client=True,
scheduler_kwargs={"dashboard": True},
Expand Down
6 changes: 3 additions & 3 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class Worker(BaseWorker, ServerNode):
transfer_outgoing_count_limit: int
threads: dict[Key, int] # {ts.key: thread ID}
active_threads_lock: threading.Lock
active_threads: dict[int, str] # {thread ID: ts.key}
active_threads: dict[int, Key] # {thread ID: ts.key}
active_keys: set[Key]
profile_keys: defaultdict[str, dict[str, Any]]
profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
Expand Down Expand Up @@ -2517,7 +2517,7 @@ async def get_profile_metadata(
)
return result

def get_call_stack(self, keys: Collection[str] | None = None) -> dict[str, Any]:
def get_call_stack(self, keys: Collection[Key] | None = None) -> dict[Key, Any]:
with self.active_threads_lock:
sys_frames = sys._current_frames()
frames = {key: sys_frames[tid] for tid, key in self.active_threads.items()}
Expand Down Expand Up @@ -2609,7 +2609,7 @@ def _get_client(self, timeout: float | None = None) -> Client:

return self._client

def get_current_task(self) -> str:
def get_current_task(self) -> Key:
"""Get the key of the task we are currently running
This only makes sense to run within a task
Expand Down

0 comments on commit 3460b4f

Please sign in to comment.