Skip to content

Commit

Permalink
make layer stack user input more tolerant (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinGeens authored Oct 28, 2024
1 parent 9ef6fda commit c546816
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions stream/stages/generation/layer_stacks_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@


class LayerStacksGenerationStage(Stage):
layer_stacks: list[tuple[int, ...]] | None

def __init__(
self,
list_of_callables: list[StageCallable],
Expand Down Expand Up @@ -50,6 +52,9 @@ def run(self):
self.layer_stacks = self.get_layer_stacks_fused_multiple_fixed()
else:
self.layer_stacks = self.get_layer_stacks_fused_single()
else:
self.layer_stacks = self.fill_layer_stacks_to_completion()

elif self.mode == "lbl":
self.layer_stacks = self.get_layer_stacks_lbl()
else:
Expand All @@ -69,19 +74,33 @@ def run(self):

def only_keep_computation_node_ids(self):
"""! Update the layer stacks to only keep ids of ComputationNodes"""
updated_layer_stacks = []
assert self.layer_stacks is not None
updated_layer_stacks: list[tuple[int, ...]] = []
for stack in self.layer_stacks:
update_stack = []
update_stack: list[tuple[int, ...]] = []
for layer_id in stack:
n = next(n for n in self.workload.node_list if n.id == layer_id)
if isinstance(n, ComputationNode):
update_stack.append(layer_id)
try:
# Ignore node ids that do not exist
n = next(n for n in self.workload.node_list if n.id == layer_id)
if isinstance(n, ComputationNode):
update_stack.append(layer_id)
except StopIteration:
pass
updated_layer_stacks.append(tuple(update_stack))
self.layer_stacks = updated_layer_stacks

def get_layer_stacks_lbl(self):
return [(id,) for id in sorted([n.id for n in self.workload.node_list if isinstance(n, ComputationNode)])]

def fill_layer_stacks_to_completion(self):
assert self.layer_stacks is not None
stacks: list[tuple[int, ...]] = self.layer_stacks

for node in self.workload.node_list:
if not any(node.id in stack for stack in stacks):
stacks += [(node.id,)]
return stacks

def get_layer_stacks_fused(self):
cumsum = 0
stacks = []
Expand Down

0 comments on commit c546816

Please sign in to comment.