Skip to content

Commit

Permalink
reuse parameter-flow space for tied nodes in different layers
Browse files — browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 10, 2024
1 parent 85320d2 commit d01830d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
18 changes: 11 additions & 7 deletions src/pyjuice/layer/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,

# This is the main loop: iterate over `ns` in the layer
ngroup_start = 0 # The start index of the node groups in the current `ns`
node2tiedcounts = dict() # Locally accumulate the occupation count
for ns_idx, ns in enumerate(nodes):

if not ns.is_tied():
Expand Down Expand Up @@ -354,23 +355,26 @@ def sum_layer_forward_compilation(nodes, fw_partition_max_chs, n_partition_ids,
add_params_flag = False

if source_ns not in node2tiednodes:
node2tiednodes[source_ns] = [[source_ns], 1, source_ns._param_flow_range]
node2tiednodes[source_ns] = [[source_ns], [source_ns._param_flow_range]]
node2tiedcounts[source_ns] = [1]
elif source_ns not in node2tiedcounts:
node2tiedcounts[source_ns] = [0 for _ in range(len(node2tiednodes[source_ns][0]))]

dup_count = node2tiednodes[source_ns][1]
if dup_count >= max_tied_ns_per_parflow_group:
if all([dup_count >= max_tied_ns_per_parflow_group for dup_count in node2tiedcounts[source_ns]]):
global_pfid_end = global_pfid_start + ns.num_edges
ns._param_flow_range = (global_pfid_start, global_pfid_end)
global_pfid_start = global_pfid_end
node2tiednodes[source_ns][2] = ns._param_flow_range
node2tiednodes[source_ns][1].append(ns._param_flow_range)

node2tiednodes[source_ns][0].append(ns)
node2tiednodes[source_ns][1] = 1
node2tiedcounts[source_ns].append(1)

add_param_flows_flag = True
else:
ns._param_flow_range = deepcopy(node2tiednodes[source_ns][2])
target_id = min(range(len(node2tiedcounts[source_ns])), key = lambda i: node2tiedcounts[source_ns][i])
ns._param_flow_range = deepcopy(node2tiednodes[source_ns][1][target_id])

node2tiednodes[source_ns][1] += 1
node2tiedcounts[source_ns][target_id] += 1

add_param_flows_flag = False

Expand Down
4 changes: 2 additions & 2 deletions tests/model/homogeneous_hmm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def homogeneous_hmm_test():
assert torch.all(pc.inner_layer_groups[5][0].partitioned_nids[0] == torch.tensor([13, 14]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_cids[0] == torch.tensor([[1, 2], [1, 2]]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pids[0] == torch.tensor([[1, 2], [3, 4]]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0] == torch.tensor([[4, 5], [6, 7]]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0] == torch.tensor([[0, 1], [2, 3]]))

assert torch.all(pc.inner_layer_groups[5][0].partitioned_chids[0] == torch.tensor([1, 2]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_parids[0] == torch.tensor([[13, 14], [13, 14]]))
Expand Down Expand Up @@ -189,7 +189,7 @@ def homogeneous_hmm_test():
ni0_flows = element_flows[1:3,:]
ni1_flows = element_flows[1:3,:]

assert torch.all(torch.abs(param_flows1.reshape(-1) - (param_flows[0:4] + param_flows[4:8])) < 1e-4)
assert torch.all(torch.abs(param_flows1.reshape(-1) - param_flows[0:4]) < 1e-4)

## Parameter learning & flow aggregation tests ##

Expand Down
10 changes: 5 additions & 5 deletions tests/model/parameter_tying_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def simple_structure_test_group1():

assert torch.all(pc.inner_layer_groups[5][0].partitioned_pids[0] == torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]]))

assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0] == torch.tensor([[16, 17, 18, 19], [20, 21, 22, 23]]))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0] == torch.tensor([[8, 9, 10, 11], [12, 13, 14, 15]]))

assert torch.all(pc.inner_layer_groups[5][0].partitioned_chids[0] == torch.arange(1, 5))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_parids[0] == torch.tensor([19, 20]).reshape(1, 2))
Expand Down Expand Up @@ -282,7 +282,7 @@ def simple_structure_test_group1():
assert torch.all(torch.abs(node_flows[7:9,:] - ni3_flows) < 1e-4)

assert torch.all(torch.abs(param_flows0.reshape(-1) - (param_flows[0:4] + param_flows[4:8])) < 1e-4)
assert torch.all(torch.abs(param_flows1.reshape(-1) - (param_flows[8:16] + param_flows[16:24])) < 1e-4)
assert torch.all(torch.abs(param_flows1.reshape(-1) - param_flows[8:16]) < 1e-4)

## Parameter learning & flow aggregation tests ##

Expand Down Expand Up @@ -433,8 +433,8 @@ def simple_structure_test_group16():
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pids[0][0,:] == torch.arange(1280, 2304, 16))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pids[0][1,:] == torch.arange(2304, 3328, 16))

assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0][0,:] == torch.arange(4096, 5120, 16))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0][1,:] == torch.arange(5120, 6144, 16))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0][0,:] == torch.arange(2048, 3072, 16))
assert torch.all(pc.inner_layer_groups[5][0].partitioned_pfids[0][1,:] == torch.arange(3072, 4096, 16))

assert torch.all(pc.inner_layer_groups[5][0].partitioned_chids[0] == torch.arange(1, 5) * 16)
assert torch.all(pc.inner_layer_groups[5][0].partitioned_parids[0] == torch.tensor([19, 20]).reshape(1, 2) * 16)
Expand Down Expand Up @@ -582,7 +582,7 @@ def simple_structure_test_group16():
ref_param_flows0 = (param_flows[0:1024] + param_flows[1024:2048]).reshape(2, 2, 16, 16).permute(0, 3, 1, 2).reshape(-1)
assert torch.all(torch.abs(param_flows0.reshape(-1) - ref_param_flows0) < 1e-2)

ref_param_flows1 = (param_flows[2048:4096] + param_flows[4096:6144]).reshape(2, 4, 16, 16).permute(0, 3, 1, 2).reshape(-1)
ref_param_flows1 = param_flows[2048:4096].reshape(2, 4, 16, 16).permute(0, 3, 1, 2).reshape(-1)
assert torch.all(torch.abs(param_flows1.reshape(-1) - ref_param_flows1) < 1e-2)

## Parameter learning & flow aggregation tests ##
Expand Down

0 comments on commit d01830d

Please sign in to comment.