Skip to content

Commit

Permalink
fix source ns _param_range assignment error
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 14, 2024
1 parent d74969c commit 356508d
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/pyjuice/layer/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,16 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
if not source_ns.provided("_param_range"):
global_pid_end = global_pid_start + ns.num_edges
ns._param_range = (global_pid_start, global_pid_end)
global_pid_start = global_pid_end

global_pfid_end = global_pfid_start + ns.num_edges
ns._param_flow_range = (global_pfid_start, global_pfid_end)
global_pfid_start = global_pfid_end

source_ns._param_range = (global_pid_start, global_pid_end)
source_ns._param_flow_range = (global_pfid_start, global_pfid_end)

global_pid_start = global_pid_end
global_pfid_start = global_pfid_end

add_params_flag = True
add_param_flows_flag = True
else:
Expand Down Expand Up @@ -538,6 +539,8 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
# Store global -> local parameter id mapping in `ns`
for ns_idx, param_ids in all_ns_param_ids.items():
ns = nodes[ns_idx]
if ns.is_tied():
ns = ns.get_source_ns()
# Every edge specify the start id of [ch_group_size, group_size] parameters
ns._param_ids = param_ids.cpu()[0::ns.ch_group_size]

Expand Down

0 comments on commit 356508d

Please sign in to comment.