Skip to content

Commit

Permalink
receive logspace_flows in ProdLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Mar 17, 2024
1 parent 0db6431 commit 6de9045
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/pyjuice/layer/prod_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_back

return None

def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwargs) -> None:
def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, logspace_flows: bool = False, **kwargs) -> None:
"""
Computes the backward pass of a product layer:
```
Expand All @@ -222,15 +222,17 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwar
parids = self.partitioned_parids[partition_id]
local_ids = self.bk_partition_local_ids[partition_id]

self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True)
self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True,
prop_logsumexp = logspace_flows)

else:
# Evaluate the whole layer
for partition_id in range(self.num_bk_partitions):
u_cids = self.partitioned_u_cids[partition_id]
parids = self.partitioned_parids[partition_id]

self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True)
self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True,
prop_logsumexp = logspace_flows)

return None

Expand Down Expand Up @@ -427,7 +429,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 0)
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max
else:
# Take the sum of the child nodes' values
nvals = tl.sum(evals, axis = 0)
Expand Down

0 comments on commit 6de9045

Please sign in to comment.