Skip to content

Commit

Permalink
accept 2d sum params for fully connected nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 15, 2024
1 parent 745c6a3 commit b643c3f
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/pyjuice/nodes/sum_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def set_params(self, params: torch.Tensor, normalize: bool = True, pseudocount:

self._params = params.clone().view(-1, 1, 1)

elif params.dim() == 2:
ch_num_ngroups = sum([cs.num_node_groups for cs in self.chs])
assert params.size(0) == self.num_nodes
assert params.size(1) == self.ch_group_size * ch_num_ngroups

self._params = params.reshape(self.num_node_groups, self.group_size, ch_num_ngroups, self.ch_group_size).permute(0, 2, 1, 3).flatten(0, 1).contiguous()

elif params.dim() == 3:
assert self.edge_ids.size(1) == params.size(0) and self.group_size == params.size(1) and self.ch_group_size == params.size(2)

Expand Down

0 comments on commit b643c3f

Please sign in to comment.