Skip to content

Commit

Permalink
receive input logspace_flows for SumLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Mar 17, 2024
1 parent 6de9045 commit d1e7c02
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t
def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
node_mars: torch.Tensor, element_mars: torch.Tensor,
params: torch.Tensor, param_flows: Optional[torch.Tensor] = None,
allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None:
allow_modify_flows: bool = False, propagation_alg: str = "LL",
logspace_flows: bool = False, **kwargs) -> None:
"""
Computes the forward pass of a sum layer:
```
Expand All @@ -276,6 +277,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
`params`: [num_params, B] or [num_params]
"""

assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`."

# Disallow modifications of `node_flows` in case of partial evaluation
if self.provided("bk_partition_local_ids") and allow_modify_flows:
allow_modify_flows = False
Expand Down Expand Up @@ -308,6 +311,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
partition_id = partition_id,
allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg,
logspace_flows = logspace_flows,
**kwargs
)

Expand All @@ -328,6 +332,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
partition_id = partition_id,
allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg,
logspace_flows = logspace_flows,
**kwargs
)

Expand All @@ -346,6 +351,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
partition_id = partition_id,
allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg,
logspace_flows = logspace_flows,
**kwargs
)

Expand Down Expand Up @@ -1211,7 +1217,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
cs_block_size: int = 0, local_ids: Optional[torch.Tensor] = None,
partition_id: int = -1, mode: Optional[str] = None,
allow_modify_flows: bool = False,
propagation_alg: str = "LL", **kwargs) -> None:
propagation_alg: str = "LL",
logspace_flows: bool = False, **kwargs) -> None:
"""
Back pass of sum layers.
Expand Down Expand Up @@ -1256,15 +1263,15 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor,
node_flows, element_flows, params, node_mars, element_mars, param_flows,
nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids,
partition_id = partition_id, allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg, **kwargs
propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs
)

elif mode == self.SPARSE:
self._backward_sparse(
node_flows, element_flows, params, node_mars, element_mars, param_flows,
nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids,
partition_id = partition_id, allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg, **kwargs
propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs
)

elif mode == self.PYTORCH:
Expand Down Expand Up @@ -1444,7 +1451,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch.
nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor],
chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor],
cs_block_size: int, local_ids: Optional[torch.Tensor] = None,
partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None:
partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL",
logspace_flows: bool = False, **kwargs) -> None:
"""
Back pass of sum layers with block-sparse processing kernel.
Expand All @@ -1467,7 +1475,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch.
chids = chids, parids = parids, parpids = parpids,
cs_block_size = cs_block_size, local_ids = local_ids,
partition_id = partition_id, allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg, **kwargs
propagation_alg = propagation_alg,
logspace_flows = logspace_flows, **kwargs
)

# Flows w.r.t. parameters
Expand All @@ -1476,7 +1485,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch.
node_flows, params, node_mars, element_mars, param_flows,
nids = nids, cids = cids, pids = pids, pfids = pfids,
allow_modify_flows = allow_modify_flows,
propagation_alg = propagation_alg, **kwargs
propagation_alg = propagation_alg,
logspace_flows = logspace_flows, **kwargs
)

return None
Expand Down Expand Up @@ -1554,7 +1564,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele

eflows = tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0)

if prop_logsumexp:
if logspace_flows:
# logaddexp
diff = acc - eflows
acc = tl.where(
Expand Down Expand Up @@ -1700,7 +1710,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar
eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1)
eflows = tl.trans(eflows)

if prop_logsumexp:
if logspace_flows:
# logaddexp
diff = acc - eflows
acc = tl.where(
Expand Down Expand Up @@ -1778,7 +1788,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo
propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None:

assert params.dim() == 1, "Expecting a 1D `params`."
assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`."

num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0)
layer_n_nodes = num_nblocks * cs_block_size
Expand Down Expand Up @@ -2154,7 +2163,6 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor
"""

assert params.dim() == 1, "Expecting a 1D `params`."
assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`."

num_nblocks = nids.size(0)
layer_n_nodes = num_nblocks * self.block_size
Expand Down Expand Up @@ -2496,7 +2504,6 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to
logspace_flows: bool = False, **kwargs) -> None:

assert params.dim() == 1, "Expecting a 1D `params`."
assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`."

num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0)
layer_n_nodes = num_nblocks * cs_block_size
Expand Down Expand Up @@ -2732,7 +2739,6 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten
"""

assert params.dim() == 1, "Expecting a 1D `params`."
assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`."

num_nblocks = nids.size(0)
layer_n_nodes = num_nblocks * self.block_size
Expand Down

0 comments on commit d1e7c02

Please sign in to comment.