diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py
index 0e8269a6..8dd874a6 100644
--- a/src/pyjuice/layer/sum_layer.py
+++ b/src/pyjuice/layer/sum_layer.py
@@ -1495,13 +1495,13 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para
         acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32)
 
         for b in range(0, B_NUM_TILES):
-            emars = tl.load(emars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_K]
+            emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K]
 
             if allow_modify_flows == 1:
-                log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
+                log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B]
             else:
-                nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
-                nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
+                nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B]
+                nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B]
 
                 log_n_fdm = tl.log(nflows) - nmars
                 log_n_fdm_max = tl.max(log_n_fdm, axis = 0)
@@ -1574,13 +1574,13 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars
         acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32)
 
         for b in range(0, B_NUM_TILES):
-            emars = tl.load(emars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_K]
+            emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K]
 
             if allow_modify_flows == 1:
-                log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
+                log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M]
             else:
-                nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
-                nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
+                nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M]
+                nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M]
 
                 log_n_fdm = tl.log(nflows) - nmars
                 log_n_fdm_max = tl.max(log_n_fdm, axis = 1)
@@ -1901,14 +1901,14 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa
             offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B + b * BLOCK_B
             mask_batch = offs_batch < batch_size
 
-            emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [BLOCK_K, BLOCK_B]
+            emars = tl.load(emars_ptr, mask = mask_batch[None,:], other = -float("inf")) # [BLOCK_K, BLOCK_B]
 
             if allow_modify_flows == 1:
                 log_n_fdm = tl.load(nflows_ptr, mask = mask_batch) # [BLOCK_B]
                 pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1)
             else:
-                nmars = tl.load(nmars_ptr, mask = mask_batch) # [BLOCK_B]
-                nflows = tl.load(nflows_ptr, mask = mask_batch) # [BLOCK_B]
+                nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B]
+                nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B]
 
                 pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1)
             acc += pflows
@@ -1984,6 +1984,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten
         grid = (B_NUM_TILES, K_NUM_BLOCKS, layer_n_nodes)
 
         self._bk_triton_sparse_par_kernel[grid](
+            ddd = ddd,
            node_flows = node_flows,
            node_mars = node_mars,
            element_mars = element_mars,
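
Note for reviewers (not part of the patch): when a lane of `mask` is false and no `other` argument is supplied, `tl.load` leaves that lane's value undefined, so padded batch positions can feed garbage into the subsequent `tl.log` / `tl.exp` calls and reductions. The patch pads those lanes with neutral values instead: 0.0 where the lane enters a plain sum or a log (nflows, nmars), and -float("inf") where it enters a log-space term, so exp(-inf) = 0 drops the lane from the accumulation. The standalone sketch below illustrates the same masked-load pattern on a toy logsumexp; the kernel and tensor names are illustrative and do not come from sum_layer.py.

# Illustrative sketch only (not from sum_layer.py): masked tl.load with an
# explicit `other` fill value, the same pattern the patch applies.
import torch
import triton
import triton.language as tl


@triton.jit
def _masked_logsumexp_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    # Out-of-range lanes get -inf so they contribute exp(-inf) = 0 below;
    # without `other` their values would be undefined.
    x = tl.load(x_ptr + offs, mask = mask, other = -float("inf"))
    x_max = tl.max(x, axis = 0)
    lse = tl.log(tl.sum(tl.exp(x - x_max), axis = 0)) + x_max
    tl.store(out_ptr, lse)


if __name__ == "__main__":
    x = torch.randn(1000, device = "cuda")
    out = torch.empty(1, device = "cuda")
    _masked_logsumexp_kernel[(1,)](x, out, x.numel(), BLOCK = 1024)
    assert torch.allclose(out, torch.logsumexp(x, dim = 0), atol = 1e-4)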