Skip to content

Commit

Permalink
Cosmetic review of story() (#6442)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jun 15, 2022
1 parent bc90846 commit 2778cf5
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 29 deletions.
33 changes: 22 additions & 11 deletions distributed/_stories.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,55 @@
from __future__ import annotations

from typing import Iterable


def scheduler_story(keys: set, transition_log: Iterable) -> list:
def scheduler_story(
keys_or_stimuli: set[str], transition_log: Iterable[tuple]
) -> list[tuple]:
"""Creates a story from the scheduler transition log given a set of keys
describing tasks or stimuli.
Parameters
----------
keys : set
A set of task `keys` or `stimulus_id`'s
keys_or_stimuli : set[str]
Task keys or stimulus_id's
log : iterable
The scheduler transition log
Returns
-------
story : list
story : list[tuple]
"""
return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])]
return [
t
for t in transition_log
if t[0] in keys_or_stimuli or keys_or_stimuli.intersection(t[3])
]


def worker_story(keys: set, log: Iterable) -> list:
def worker_story(keys_or_stimuli: set[str], log: Iterable[tuple]) -> list:
"""Creates a story from the worker log given a set of keys
describing tasks or stimuli.
Parameters
----------
keys : set
A set of task `keys` or `stimulus_id`'s
keys_or_stimuli : set[str]
Task keys or stimulus_id's
log : iterable
The worker log
Returns
-------
story : list
story : list[str]
"""
return [
msg
for msg in log
if any(key in msg for key in keys)
if any(key in msg for key in keys_or_stimuli)
or any(
key in c for key in keys for c in msg if isinstance(c, (tuple, list, set))
key in c
for key in keys_or_stimuli
for c in msg
if isinstance(c, (tuple, list, set))
)
]
15 changes: 9 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4272,11 +4272,13 @@ def collections_to_dsk(collections, *args, **kwargs):
"""Convert many collections into a single dask graph, after optimization"""
return collections_to_dsk(collections, *args, **kwargs)

async def _story(self, keys=(), on_error="raise"):
async def _story(self, *keys_or_stimuli: str, on_error="raise"):
assert on_error in ("raise", "ignore")

try:
flat_stories = await self.scheduler.get_story(keys=keys)
flat_stories = await self.scheduler.get_story(
keys_or_stimuli=keys_or_stimuli
)
flat_stories = [("scheduler", *msg) for msg in flat_stories]
except Exception:
if on_error == "raise":
Expand All @@ -4287,15 +4289,16 @@ async def _story(self, keys=(), on_error="raise"):
raise ValueError(f"on_error not in {'raise', 'ignore'}")

responses = await self.scheduler.broadcast(
msg={"op": "get_story", "keys": keys}, on_error=on_error
msg={"op": "get_story", "keys_or_stimuli": keys_or_stimuli},
on_error=on_error,
)
for worker, stories in responses.items():
flat_stories.extend((worker, *msg) for msg in stories)
return flat_stories

def story(self, *keys_or_stimulus_ids, on_error="raise"):
"""Returns a cluster-wide story for the given keys or simtulus_id's"""
return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error)
def story(self, *keys_or_stimuli, on_error="raise"):
"""Returns a cluster-wide story for the given keys or stimulus_id's"""
return self.sync(self._story, *keys_or_stimuli, on_error=on_error)

def get_task_stream(
self,
Expand Down
15 changes: 9 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6539,13 +6539,16 @@ def transitions(self, recommendations: dict, stimulus_id: str):
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)

def story(self, *keys):
"""Get all transitions that touch one of the input keys"""
keys = {key.key if isinstance(key, TaskState) else key for key in keys}
return scheduler_story(keys, self.transition_log)
def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]:
"""Get all transitions that touch one of the input keys or stimulus_id's"""
keys_or_stimuli = {
key.key if isinstance(key, TaskState) else key
for key in keys_or_tasks_or_stimuli
}
return scheduler_story(keys_or_stimuli, self.transition_log)

async def get_story(self, keys=()):
return self.story(*keys)
async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:
return self.story(*keys_or_stimuli)

transition_story = story

Expand Down
4 changes: 2 additions & 2 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,8 +1843,8 @@ def stateof(self, key: str) -> dict[str, Any]:
"data": key in self.data,
}

async def get_story(self, keys=None):
return self.story(*keys)
async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:
return self.state.story(*keys_or_stimuli)

##########################
# Dependencies gathering #
Expand Down
12 changes: 8 additions & 4 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2774,10 +2774,14 @@ def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
# Diagnostics #
###############

def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
"""Return all transitions involving one or more tasks"""
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return worker_story(keys, self.log)
def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]:
"""Return all records from the transitions log involving one or more tasks or
stimulus_id's
"""
keys_or_stimuli = {
e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_stimuli
}
return worker_story(keys_or_stimuli, self.log)

def stimulus_story(
self, *keys_or_tasks: str | TaskState
Expand Down

0 comments on commit 2778cf5

Please sign in to comment.