logspace backward with pytorch kernels
liuanji committed Mar 15, 2024
1 parent 917843d commit 59b89bc
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/pyjuice/layer/sum_layer.py
@@ -2829,7 +2829,8 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten

def _backward_pytorch(self, node_flows, element_flows, params, node_mars,
element_mars, param_flows, nids, cids, pids, pfids,
-chids, parids, parpids, cs_block_size):
+chids, parids, parpids, cs_block_size, propagation_alg: str = "LL",
+logspace_flows: bool = False):
"""
Back pass of sum layers with native pytorch.
@@ -2845,26 +2846,28 @@ def _backward_pytorch(self, node_flows, element_flows, params, node_mars,
`parpids`: [ng, c]
"""

+assert propagation_alg == "LL"
+
# Flows w.r.t. input elements (product nodes)
if chids is not None:
self._backward_pytorch_ele_kernel(
node_flows, element_flows, params, node_mars, element_mars,
-param_flows, chids, parids, parpids, cs_block_size
+param_flows, chids, parids, parpids, cs_block_size, logspace_flows
)

# Flows w.r.t. parameters
if param_flows is not None and nids is not None:
self._backward_pytorch_par_kernel(
node_flows, params, node_mars, element_mars, param_flows,
-nids, cids, pids, pfids, self.block_size
+nids, cids, pids, pfids, self.block_size, logspace_flows
)

@torch.compile
def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
params: torch.Tensor, node_mars: torch.Tensor,
element_mars: torch.Tensor, param_flows: Optional[torch.Tensor],
chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor,
-cs_block_size: int):
+cs_block_size: int, logspace_flows: bool):

num_nblocks = chids.size(0)
num_eblocks = parids.size(1)
@@ -2878,15 +2881,20 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows:
num_nblocks * cs_block_size, num_eblocks * self.block_size
)

-element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \
-    (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1)
+if logspace_flows:
+    element_flows[chids] = (node_flows[parids] + params[parpids].log().unsqueeze(-1) + \
+        element_mars[chids].unsqueeze(1) - node_mars[parids]).logsumexp(dim = 1)
+else:
+    element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \
+        (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1)

return None
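For reference, the two branches above compute the same flow in different representations: in linear space the child flow is the sum over parents of node_flow * param * exp(element_mar - node_mar), while the log-space branch stores its logarithm, which is exactly the logsumexp of the log-summands. Below is a minimal standalone sketch that checks this equivalence; the shapes (B parents, C child elements, batch T) and variable names are made up for illustration and do not follow the layer's blocked chids/parids layout.

```python
# Toy check (not library code): log-space element flows agree with linear-space ones.
import torch

torch.manual_seed(0)
B, C, T = 4, 3, 5                       # hypothetical: parents, child elements, batch
node_flows = torch.rand(B, T) + 0.1     # linear-space flows at the parent sum nodes
params = torch.rand(B, C) + 0.1         # edge parameters
node_mars = torch.randn(B, T)           # log-marginals of the parent sum nodes
element_mars = torch.randn(C, T)        # log-marginals of the child product nodes

# Linear-space flows (the `else` branch of the kernel, with toy indexing)
linear = (node_flows.unsqueeze(1) * params.unsqueeze(-1) *
          (element_mars.unsqueeze(0) - node_mars.unsqueeze(1)).exp()).sum(dim = 0)

# Log-space flows (the `logspace_flows` branch): same terms, summed via logsumexp
log_flows = node_flows.log()
logspace = (log_flows.unsqueeze(1) + params.log().unsqueeze(-1) +
            element_mars.unsqueeze(0) - node_mars.unsqueeze(1)).logsumexp(dim = 0)

assert torch.allclose(logspace.exp(), linear, atol = 1e-4)
```

The log-space form costs an extra logsumexp but keeps very small flows representable, which is presumably the point of the `logspace_flows` option.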

@torch.compile
def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor,
element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor,
-cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int):
+cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int,
+logspace_flows: bool):

num_nblocks = nids.size(0)
num_edges = cids.size(1)
@@ -2898,7 +2906,10 @@ def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.T
pfids = (pfids[:,None,:].repeat(1, self.block_size, 1) + \
torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges)

-parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)
+if logspace_flows:
+    parflows = (node_flows[nids].exp().unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)
+else:
+    parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2)

for i in range(num_nblocks):
sid, eid = ns_block_size * i, ns_block_size * (i + 1)
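Note that parameter flows stay in linear space in both branches: when `logspace_flows` is set, the kernel exponentiates `node_flows` first and then performs the usual accumulation, since the per-edge flows are summed over the batch (and later added into `param_flows`). A small sketch of that relationship, again with made-up toy shapes rather than the layer's real blocked layout:

```python
# Toy check (not library code): with log-space node flows, exponentiating first
# reproduces the linear-space parameter flows.
import torch

torch.manual_seed(0)
N, E, T = 2, 3, 5                       # hypothetical: sum nodes, edges per node, batch
node_flows = torch.rand(N, T) + 0.1     # linear-space flows at the sum nodes
params = torch.rand(N, E) + 0.1         # edge parameters
node_mars = torch.randn(N, T)           # log-marginals of the sum nodes
element_mars = torch.randn(N, E, T)     # log-marginals of each node's child elements

# Linear-space parameter flows (the `else` branch, with toy indexing)
linear = (node_flows.unsqueeze(1) * params.unsqueeze(-1) *
          (element_mars - node_mars.unsqueeze(1)).exp()).sum(dim = 2)

# Log-space branch: node_flows holds logs, so exponentiate before accumulating
log_flows = node_flows.log()
logspace = (log_flows.exp().unsqueeze(1) * params.unsqueeze(-1) *
            (element_mars - node_mars.unsqueeze(1)).exp()).sum(dim = 2)

assert torch.allclose(logspace, linear, atol = 1e-4)
```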
