From 481fc2b00fbc0c56f86f7eae497a10f9bc82ab21 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 14 Jan 2024 16:53:29 +0800 Subject: [PATCH] fix compilation logic for tied nodes when the source does not appear first --- src/pyjuice/layer/compilation.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/pyjuice/layer/compilation.py b/src/pyjuice/layer/compilation.py index a3ff8160..f8099a29 100644 --- a/src/pyjuice/layer/compilation.py +++ b/src/pyjuice/layer/compilation.py @@ -348,11 +348,16 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids, ns._param_flow_range = (global_pfid_start, global_pfid_end) global_pfid_start = global_pfid_end + source_ns._param_range = (global_pid_start, global_pid_end) + source_ns._param_flow_range = (global_pfid_start, global_pfid_end) + add_params_flag = True + add_param_flows_flag = True else: ns._param_range = deepcopy(source_ns._param_range) add_params_flag = False + add_param_flows_flag = False if source_ns not in node2tiednodes: node2tiednodes[source_ns] = [[source_ns], [source_ns._param_flow_range]] @@ -360,23 +365,24 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids, elif source_ns not in node2tiedcounts: node2tiedcounts[source_ns] = [0 for _ in range(len(node2tiednodes[source_ns][0]))] - if all([dup_count >= max_tied_ns_per_parflow_group for dup_count in node2tiedcounts[source_ns]]): - global_pfid_end = global_pfid_start + ns.num_edges - ns._param_flow_range = (global_pfid_start, global_pfid_end) - global_pfid_start = global_pfid_end - node2tiednodes[source_ns][1].append(ns._param_flow_range) + if not ns.provided("_param_flow_range"): + if all([dup_count >= max_tied_ns_per_parflow_group for dup_count in node2tiedcounts[source_ns]]): + global_pfid_end = global_pfid_start + ns.num_edges + ns._param_flow_range = (global_pfid_start, global_pfid_end) + global_pfid_start = global_pfid_end + node2tiednodes[source_ns][1].append(ns._param_flow_range) - node2tiednodes[source_ns][0].append(ns) - node2tiedcounts[source_ns].append(1) + node2tiednodes[source_ns][0].append(ns) + node2tiedcounts[source_ns].append(1) - add_param_flows_flag = True - else: - target_id = min(range(len(node2tiedcounts[source_ns])), key = lambda i: node2tiedcounts[source_ns][i]) - ns._param_flow_range = deepcopy(node2tiednodes[source_ns][1][target_id]) + add_param_flows_flag = True + else: + target_id = min(range(len(node2tiedcounts[source_ns])), key = lambda i: node2tiedcounts[source_ns][i]) + ns._param_flow_range = deepcopy(node2tiednodes[source_ns][1][target_id]) - node2tiedcounts[source_ns][target_id] += 1 + node2tiedcounts[source_ns][target_id] += 1 - add_param_flows_flag = False + add_param_flows_flag = False # Global pid and pfid start index for `ns` ns_pid_start = source_ns._param_range[0]