Skip to content

Commit

Permalink
fix compilation logic for tied nodes when the source does not appear …
Browse files Browse the repository at this point in the history
…first
  • Loading branch information
liuanji committed Jan 14, 2024
1 parent bb9aca6 commit 481fc2b
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/pyjuice/layer/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,35 +348,41 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
ns._param_flow_range = (global_pfid_start, global_pfid_end)
global_pfid_start = global_pfid_end

source_ns._param_range = (global_pid_start, global_pid_end)
source_ns._param_flow_range = (global_pfid_start, global_pfid_end)

add_params_flag = True
add_param_flows_flag = True
else:
ns._param_range = deepcopy(source_ns._param_range)

add_params_flag = False
add_param_flows_flag = False

if source_ns not in node2tiednodes:
node2tiednodes[source_ns] = [[source_ns], [source_ns._param_flow_range]]
node2tiedcounts[source_ns] = [1]
elif source_ns not in node2tiedcounts:
node2tiedcounts[source_ns] = [0 for _ in range(len(node2tiednodes[source_ns][0]))]

if all([dup_count >= max_tied_ns_per_parflow_group for dup_count in node2tiedcounts[source_ns]]):
global_pfid_end = global_pfid_start + ns.num_edges
ns._param_flow_range = (global_pfid_start, global_pfid_end)
global_pfid_start = global_pfid_end
node2tiednodes[source_ns][1].append(ns._param_flow_range)
if not ns.provided("_param_flow_range"):
if all([dup_count >= max_tied_ns_per_parflow_group for dup_count in node2tiedcounts[source_ns]]):
global_pfid_end = global_pfid_start + ns.num_edges
ns._param_flow_range = (global_pfid_start, global_pfid_end)
global_pfid_start = global_pfid_end
node2tiednodes[source_ns][1].append(ns._param_flow_range)

node2tiednodes[source_ns][0].append(ns)
node2tiedcounts[source_ns].append(1)
node2tiednodes[source_ns][0].append(ns)
node2tiedcounts[source_ns].append(1)

add_param_flows_flag = True
else:
target_id = min(range(len(node2tiedcounts[source_ns])), key = lambda i: node2tiedcounts[source_ns][i])
ns._param_flow_range = deepcopy(node2tiednodes[source_ns][1][target_id])
add_param_flows_flag = True
else:
target_id = min(range(len(node2tiedcounts[source_ns])), key = lambda i: node2tiedcounts[source_ns][i])
ns._param_flow_range = deepcopy(node2tiednodes[source_ns][1][target_id])

node2tiedcounts[source_ns][target_id] += 1
node2tiedcounts[source_ns][target_id] += 1

add_param_flows_flag = False
add_param_flows_flag = False

# Global pid and pfid start index for `ns`
ns_pid_start = source_ns._param_range[0]
Expand Down

0 comments on commit 481fc2b

Please sign in to comment.