diff --git a/src/pyjuice/layer/compilation.py b/src/pyjuice/layer/compilation.py index f8099a29..592a1599 100644 --- a/src/pyjuice/layer/compilation.py +++ b/src/pyjuice/layer/compilation.py @@ -342,15 +342,16 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids, if not source_ns.provided("_param_range"): global_pid_end = global_pid_start + ns.num_edges ns._param_range = (global_pid_start, global_pid_end) - global_pid_start = global_pid_end global_pfid_end = global_pfid_start + ns.num_edges ns._param_flow_range = (global_pfid_start, global_pfid_end) - global_pfid_start = global_pfid_end source_ns._param_range = (global_pid_start, global_pid_end) source_ns._param_flow_range = (global_pfid_start, global_pfid_end) + global_pid_start = global_pid_end + global_pfid_start = global_pfid_end + add_params_flag = True add_param_flows_flag = True else: @@ -538,6 +539,8 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids, # Store global -> local parameter id mapping in `ns` for ns_idx, param_ids in all_ns_param_ids.items(): ns = nodes[ns_idx] + if ns.is_tied(): + ns = ns.get_source_ns() # Every edge specify the start id of [ch_group_size, group_size] parameters ns._param_ids = param_ids.cpu()[0::ns.ch_group_size]