diff --git a/src/pyjuice/nodes/sum_nodes.py b/src/pyjuice/nodes/sum_nodes.py index e058219c..6dcec7ed 100644 --- a/src/pyjuice/nodes/sum_nodes.py +++ b/src/pyjuice/nodes/sum_nodes.py @@ -95,6 +95,13 @@ def set_params(self, params: torch.Tensor, normalize: bool = True, pseudocount: self._params = params.clone().view(-1, 1, 1) + elif params.dim() == 2: + ch_num_ngroups = sum([cs.num_node_groups for cs in self.chs]) + assert params.size(0) == self.num_nodes + assert params.size(1) == self.ch_group_size * ch_num_ngroups + + self._params = params.reshape(self.num_node_groups, self.group_size, ch_num_ngroups, self.ch_group_size).permute(0, 2, 1, 3).flatten(0, 1).contiguous() + elif params.dim() == 3: assert self.edge_ids.size(1) == params.size(0) and self.group_size == params.size(1) and self.ch_group_size == params.size(2)