Commit

fix group for prod ns
liuanji committed Jan 14, 2024
1 parent 2c27175 commit d74969c
Showing 1 changed file with 152 additions and 3 deletions.
155 changes: 152 additions & 3 deletions src/pyjuice/transformations/group.py
@@ -183,14 +183,14 @@ def update_ns(ns: CircuitNodes, ns_chs: Sequence[CircuitNodes]):
         edge_ids = edge_ids.reshape(new_num_ngroups, group_mul_size, ns.num_chs)
         if torch.all(edge_ids[:,1:,:] - edge_ids[:,:-1,:]) == 1:
             # Block-sparse mode
-            edge_ids = edge_ids[:,0,:].contiguous()
+            edge_ids = edge_ids[:,0,:].contiguous() // group_mul_size
             mode = "block_sparse"
         else:
             # Sparse mode
             edge_ids = (edge_ids.reshape(ns.num_node_groups, ns.num_chs)[:,None,:] * ns.group_size + \
                 torch.arange(0, ns.group_size)[None,:,None]).flatten(0, 1)
             mode = "sparse"

         new_ns = ProdNodes(
             num_node_groups = new_num_ngroups,
             chs = ns_chs,
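
The one-line change above appears to rescale the stored child-group indices so that they refer to the merged (larger) child groups rather than the original ones. A minimal sketch of that index arithmetic with toy values, assuming the child node groups are merged by the same factor `group_mul_size` (nothing below is taken from the repository):

import torch

# Hypothetical toy setting: 4 product node groups of size 2 are merged into
# 2 groups of size 4, i.e. group_mul_size = 2. In block-sparse mode,
# edge_ids[i, j] holds the child-group index used by child slot j of group i.
group_mul_size = 2
edge_ids = torch.tensor([[0], [1], [2], [3]])               # (num_node_groups, num_chs)

edge_ids = edge_ids.reshape(2, group_mul_size, 1)           # (new_num_ngroups, group_mul_size, num_chs)
assert torch.all(edge_ids[:, 1:, :] - edge_ids[:, :-1, :] == 1)   # consecutive child groups -> block-sparse

# Keeping the original index (old behavior) would point into the finer child
# grouping; dividing by group_mul_size re-expresses it over the merged groups.
new_edge_ids = edge_ids[:, 0, :].contiguous() // group_mul_size
print(new_edge_ids.squeeze(-1))                             # tensor([0, 1]) instead of tensor([0, 2])
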
@@ -246,7 +246,12 @@ def update_ns(ns: CircuitNodes, ns_chs: Sequence[CircuitNodes]):
         old_ch_group_size = ns.chs[0].group_size
         if new_group_size > old_group_size or new_ch_group_size > old_ch_group_size:
             new_params = torch.zeros([new_edge_ids.size(1), new_group_size, new_ch_group_size], device = device)
-            if use_cuda:
+            if new_edge_ids.size(1) == new_num_ngroups * new_num_cgroups:
+                # Fully-connected parameters
+                new_params = params.reshape(
+                    new_num_ngroups, group_mul_size, new_num_cgroups, ch_group_mul_size, old_group_size, old_ch_group_size
+                ).permute(0, 2, 1, 4, 3, 5).reshape(new_params.size()).contiguous()
+            elif use_cuda:
                 edge_ids_np = edge_ids.numpy()
                 new_edge_ids_np = new_edge_ids.numpy()
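
The added branch above takes a fast path when the merged connectivity is fully dense: the old per-edge parameter blocks are rearranged with a single reshape/permute instead of the kernel- or loop-based copy. A small self-contained check of that rearrangement against an explicit per-block copy, assuming the old blocks are ordered node-group major and child-group minor (all names and sizes below are local to this sketch):

import torch

old_ng, old_cg = 4, 6         # old numbers of node groups / child groups
old_gs, old_cgs = 2, 3        # old group sizes (rows / columns of one parameter block)
mul_n, mul_c = 2, 3           # merge factors for node groups and child groups
new_ng, new_cg = old_ng // mul_n, old_cg // mul_c

# One (old_gs, old_cgs) parameter block per (node group, child group) pair.
params = torch.randn(old_ng * old_cg, old_gs, old_cgs)

# Fast path: mirror the reshape/permute in the hunk above.
fast = (params.reshape(new_ng, mul_n, new_cg, mul_c, old_gs, old_cgs)
              .permute(0, 2, 1, 4, 3, 5)
              .reshape(new_ng * new_cg, mul_n * old_gs, mul_c * old_cgs))

# Reference: copy every old block to its offset inside the merged block.
ref = torch.zeros_like(fast)
for n in range(old_ng):
    for c in range(old_cg):
        new_blk = (n // mul_n) * new_cg + (c // mul_c)
        r0, c0 = (n % mul_n) * old_gs, (c % mul_c) * old_cgs
        ref[new_blk, r0:r0 + old_gs, c0:c0 + old_cgs] = params[n * old_cg + c]

assert torch.equal(fast, ref)
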

@@ -310,3 +315,147 @@ def update_ns(ns: CircuitNodes, ns_chs: Sequence[CircuitNodes]):
         new_ns.set_source_ns(new_source_ns)

     return new_root_ns
+
+
+def bump_group_size(ns: CircuitNodes, group_size: int, use_cuda: bool = True):
+    assert group_size > ns.group_size, f"`group_size` already greater than {group_size}."
+    assert ns.num_nodes % group_size == 0, f"`num_nodes` not divisible by the target group size."
+
+    if use_cuda:
+        device = torch.device("cuda:0")
+    else:
+        device = torch.device("cpu")
+
+    group_mul_size = group_size // ns.group_size
+
+    if ns.is_input():
+        new_ns = InputNodes(
+            num_node_groups = new_num_ngroups,
+            scope = pydeepcopy(ns.scope),
+            dist = pydeepcopy(ns.dist),
+            group_size = group_size
+        )
+
+        if not ns.is_tied():
+            params = ns.get_params()
+            if params is not None:
+                new_ns.set_params(params.clone(), normalize = False)
+
+    elif ns.is_prod():
+        edge_ids = ns.edge_ids.clone()
+        edge_ids = edge_ids.reshape(new_num_ngroups, group_mul_size, ns.num_chs)
+        if torch.all(edge_ids[:,1:,:] - edge_ids[:,:-1,:]) == 1:
+            # Block-sparse mode
+            edge_ids = edge_ids[:,0,:].contiguous()
+            mode = "block_sparse"
+        else:
+            # Sparse mode
+            edge_ids = (edge_ids.reshape(ns.num_node_groups, ns.num_chs)[:,None,:] * ns.group_size + \
+                torch.arange(0, ns.group_size)[None,:,None]).flatten(0, 1)
+            mode = "sparse"
+
+        new_ns = ProdNodes(
+            num_node_groups = new_num_ngroups,
+            chs = ns_chs,
+            edge_ids = edge_ids,
+            group_size = group_size
+        )
+
+        if mode == "block_sparse":
+            assert new_ns.is_block_sparse()
+        elif mode == "sparse":
+            assert new_ns.is_sparse()
+    if ns.is_sum():
+        old_num_ngroups = ns.num_node_groups
+        old_group_size = ns.group_size
+        num_cgroups = sum([cs.num_node_groups for cs in ns.chs])
+
+        new_num_ngroups = ns.num_nodes // group_size
+
+        ns_chs = ns.chs
+        ch_group_size = ns.chs[0].group_size
+
+        edge_ids = ns.edge_ids.clone()
+        grid_edge_ids = torch.zeros([old_num_ngroups, num_cgroups], dtype = torch.bool)
+        grid_edge_ids[edge_ids[0,:],edge_ids[1,:]] = True
+
+        grid_edge_ids = grid_edge_ids.reshape(new_num_ngroups, group_mul_size, num_cgroups, 1)
+        new_edge_ids = torch.nonzero(grid_edge_ids.any(dim = 3).any(dim = 1), as_tuple = False).permute(1, 0)
+
+        new_ns = SumNodes(
+            num_node_groups = new_num_ngroups,
+            chs = ns_chs,
+            edge_ids = new_edge_ids,
+            group_size = group_size
+        )
+
+        if not ns.is_tied():
+            # Collect selected blocks
+            grid_edge_ids = grid_edge_ids.permute(0, 2, 1, 3).flatten(0, 1)
+            block_ids = new_edge_ids[0,:] * num_cgroups + new_edge_ids[1,:]
+            param_indicator = grid_edge_ids[block_ids,:,:]
+            if not torch.all(param_indicator):
+                param_indicator = param_indicator[:,:,None,:,None].repeat(1, 1, ns.group_size, 1, ns.chs[0].group_size)
+                param_indicator = param_indicator.flatten(3, 4).flatten(1, 2)
+                zero_param_mask = ~param_indicator
+
+                new_ns.set_zero_param_mask(zero_param_mask)
+
+            params = ns.get_params()
+            if params is not None:
+                new_params = torch.zeros([new_edge_ids.size(1), group_size, ch_group_size], device = device)
+                if new_edge_ids.size(1) == new_num_ngroups * num_cgroups:
+                    # Fully-connected parameters
+                    new_params = params.reshape(
+                        new_num_ngroups, group_size // old_group_size, num_cgroups, 1, old_group_size, ch_group_size
+                    ).permute(0, 2, 1, 4, 3, 5).reshape(new_params.size()).contiguous()
+                elif use_cuda:
+                    edge_ids_np = edge_ids.numpy()
+                    new_edge_ids_np = new_edge_ids.numpy()
+
+                    target_id0 = np.zeros([edge_ids.size(1)], dtype = np.int64) - 1
+                    target_id1 = np.zeros([2, edge_ids.size(1)], dtype = np.int64) - 1
+                    target_id2 = np.zeros([2, edge_ids.size(1)], dtype = np.int64) - 1
+
+                    _compute_param_target_ids_kernel(
+                        target_id0, target_id1, target_id2, edge_ids_np, new_edge_ids_np,
+                        group_mul_size, 1, old_group_size, ch_group_size
+                    )
+
+                    target_id0 = torch.from_numpy(target_id0).to(device)
+                    target_id1 = torch.from_numpy(target_id1).to(device)
+                    target_id2 = torch.from_numpy(target_id2).to(device)
+
+                    params = params.to(device)
+
+                    BLOCK_M = min(32, old_group_size)
+                    BLOCK_N = min(32, ch_group_size)
+
+                    grid = (ch_group_size // BLOCK_N, old_group_size // BLOCK_M, edge_ids.size(1))
+
+                    _copy_params_kernel[grid](
+                        new_params, params, target_id0, target_id1, target_id2,
+                        old_group_size = old_group_size,
+                        old_ch_group_size = ch_group_size,
+                        new_group_size = group_size,
+                        new_ch_group_size = ch_group_size,
+                        BLOCK_M = BLOCK_M,
+                        BLOCK_N = BLOCK_N
+                    )
+
+                else:
+                    for par_group_id in range(new_edge_ids.size(1)):
+                        nsid = new_edge_ids[0,par_group_id] * group_mul_size
+                        neid = nsid + group_mul_size
+                        csid = new_edge_ids[1,par_group_id]
+                        ceid = csid + 1
+
+                        blk_ids = torch.where((edge_ids[0,:] >= nsid) & (edge_ids[0,:] < neid) & (edge_ids[1,:] >= csid) & (edge_ids[1,:] < ceid))[0]
+                        for blk_id in blk_ids:
+                            nid0, nid1 = (edge_ids[0,blk_id] - nsid) * ns.group_size, (edge_ids[0,blk_id] - nsid + 1) * ns.group_size
+                            cid0, cid1 = (edge_ids[1,blk_id] - csid) * ns.chs[0].group_size, (edge_ids[1,blk_id] - csid + 1) * ns.chs[0].group_size
+                            new_params[par_group_id,nid0:nid1,cid0:cid1] = params[blk_id,:,:]
+
+                new_ns.set_params(new_params.cpu(), normalize = False)
+
+    return new_ns
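
In the sum-node branch above, the new edge list is obtained by collapsing the old boolean connectivity grid onto the merged node groups, while the children keep their original grouping (`ns_chs = ns.chs`). A minimal, self-contained walk-through of that step with invented values:

import torch

old_num_ngroups, num_cgroups, group_mul_size = 4, 4, 2
new_num_ngroups = old_num_ngroups // group_mul_size

# Old edges: one column per (node group, child group) pair; here node group i
# is connected to child group i only.
edge_ids = torch.tensor([[0, 1, 2, 3],
                         [0, 1, 2, 3]])

grid_edge_ids = torch.zeros(old_num_ngroups, num_cgroups, dtype = torch.bool)
grid_edge_ids[edge_ids[0, :], edge_ids[1, :]] = True

# A merged node group is connected to a child group iff any of its members
# was connected before the merge.
grid = grid_edge_ids.reshape(new_num_ngroups, group_mul_size, num_cgroups)
new_edge_ids = torch.nonzero(grid.any(dim = 1), as_tuple = False).permute(1, 0)
print(new_edge_ids)
# tensor([[0, 0, 1, 1],
#         [0, 1, 2, 3]])
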
