From d1e7c02881cadeaa7ff1cc97f402552c7ff18102 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:53:44 +0800 Subject: [PATCH] receive input `logspace_flows` for `SumLayer` --- src/pyjuice/layer/sum_layer.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index afc015b6..d3ce78db 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -255,7 +255,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, param_flows: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Computes the forward pass of a sum layer: ``` @@ -276,6 +277,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, `params`: [num_params, B] or [num_params] """ + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." + # Disallow modifications of `node_flows` in case of partial evaluation if self.provided("bk_partition_local_ids") and allow_modify_flows: allow_modify_flows = False @@ -308,6 +311,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -328,6 +332,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -346,6 +351,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -1211,7 +1217,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, cs_block_size: int = 0, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, allow_modify_flows: bool = False, - propagation_alg: str = "LL", **kwargs) -> None: + propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Back pass of sum layers. @@ -1256,7 +1263,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs ) elif mode == self.SPARSE: @@ -1264,7 +1271,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs ) elif mode == self.PYTORCH: @@ -1444,7 +1451,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Back pass of sum layers with block-sparse processing kernel. @@ -1467,7 +1475,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) # Flows w.r.t. parameters @@ -1476,7 +1485,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) return None @@ -1554,7 +1564,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele eflows = tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) - if prop_logsumexp: + if logspace_flows: # logaddexp diff = acc - eflows acc = tl.where( @@ -1700,7 +1710,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1) eflows = tl.trans(eflows) - if prop_logsumexp: + if logspace_flows: # logaddexp diff = acc - eflows acc = tl.where( @@ -1778,7 +1788,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -2154,7 +2163,6 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor """ assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size @@ -2496,7 +2504,6 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -2732,7 +2739,6 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten """ assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size