diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py
index 61b4333e..afc015b6 100644
--- a/src/pyjuice/layer/sum_layer.py
+++ b/src/pyjuice/layer/sum_layer.py
@@ -2829,7 +2829,8 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten
 
     def _backward_pytorch(self, node_flows, element_flows, params, node_mars,
                           element_mars, param_flows, nids, cids, pids, pfids,
-                          chids, parids, parpids, cs_block_size):
+                          chids, parids, parpids, cs_block_size, propagation_alg: str = "LL",
+                          logspace_flows: bool = False):
         """
         Back pass of sum layers with native pytorch.
 
@@ -2845,18 +2846,20 @@ def _backward_pytorch(self, node_flows, element_flows, params, node_mars,
             `parpids`: [ng, c]
         """
 
+        assert propagation_alg == "LL"
+
         # Flows w.r.t. input elements (product nodes)
         if chids is not None:
             self._backward_pytorch_ele_kernel(
                 node_flows, element_flows, params, node_mars, element_mars,
-                param_flows, chids, parids, parpids, cs_block_size
+                param_flows, chids, parids, parpids, cs_block_size, logspace_flows
             )
 
         # Flows w.r.t. parameters
        if param_flows is not None and nids is not None:
             self._backward_pytorch_par_kernel(
                 node_flows, params, node_mars, element_mars, param_flows,
-                nids, cids, pids, pfids, self.block_size
+                nids, cids, pids, pfids, self.block_size, logspace_flows
             )
 
     @torch.compile
@@ -2864,7 +2867,7 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows:
                                      params: torch.Tensor, node_mars: torch.Tensor,
                                      element_mars: torch.Tensor, param_flows: Optional[torch.Tensor],
                                      chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor,
-                                     cs_block_size: int):
+                                     cs_block_size: int, logspace_flows: bool):
 
         num_nblocks = chids.size(0)
         num_eblocks = parids.size(1)
@@ -2878,15 +2881,20 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows:
             num_nblocks * cs_block_size, num_eblocks * self.block_size
         )
 
-        element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \
-            (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1)
+        if logspace_flows:
+            element_flows[chids] = (node_flows[parids] + params[parpids].log().unsqueeze(-1) + \
+                element_mars[chids].unsqueeze(1) - node_mars[parids]).logsumexp(dim = 1)
+        else:
+            element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \
+                (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1)
 
         return None
 
     @torch.compile
     def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.Tensor,
                                      node_mars: torch.Tensor, element_mars: torch.Tensor,
                                      param_flows: torch.Tensor, nids: torch.Tensor,
-                                     cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int):
+                                     cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int,
+                                     logspace_flows: bool):
 
         num_nblocks = nids.size(0)
         num_edges = cids.size(1)
@@ -2898,7 +2906,10 @@ def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.T
         pfids = (pfids[:,None,:].repeat(1, self.block_size, 1) + \
             torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges)
 
-        parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)
+        if logspace_flows:
+            parflows = (node_flows[nids].exp().unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)
+        else:
+            parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)
 
         for i in range(num_nblocks):
             sid, eid = ns_block_size * i, ns_block_size * (i + 1)
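
A minimal standalone sketch (not part of the patch) of the identity behind the new `logspace_flows` branch in `_backward_pytorch_ele_kernel`: when flows are kept in log space, the weighted sum over parent sum nodes becomes a `logsumexp` over log-flows plus log-parameters plus the log-marginal difference. The tensor names and shapes below are illustrative only and do not follow pyjuice's real block layout.

import torch

num_parents, batch_size = 4, 3
node_flows   = torch.rand(num_parents, batch_size)   # linear-space flows of the parent sum nodes
params       = torch.rand(num_parents)                # edge parameters
node_mars    = torch.randn(num_parents, batch_size)   # log-marginals of the parent sum nodes
element_mars = torch.randn(batch_size)                # log-marginal of the child product node

# Linear-space element flow (the pre-existing formula)
flow_lin = (node_flows * params.unsqueeze(-1) *
            (element_mars.unsqueeze(0) - node_mars).exp()).sum(dim = 0)

# Log-space element flow (the new branch): log-flows + log-params + log-marginal difference,
# reduced with logsumexp instead of a weighted sum
flow_log = (node_flows.log() + params.log().unsqueeze(-1) +
            element_mars.unsqueeze(0) - node_mars).logsumexp(dim = 0)

# The log-space result is the log of the linear-space result
assert torch.allclose(flow_log, flow_lin.log(), atol = 1e-5)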