Skip to content

Commit

Permalink
fix bug: set appropriate other values in pflow accumulation kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 9, 2024
1 parent a9b2cf9 commit e4efbe9
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
@@ -1495,13 +1495,13 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para
acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32)

for b in range(0, B_NUM_TILES):
emars = tl.load(emars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_K]
emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K]

if allow_modify_flows == 1:
log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B]
else:
nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, TILE_SIZE_B]
nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B]
nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B]
log_n_fdm = tl.log(nflows) - nmars

log_n_fdm_max = tl.max(log_n_fdm, axis = 0)
@@ -1574,13 +1574,13 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars
acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32)

for b in range(0, B_NUM_TILES):
emars = tl.load(emars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_K]
emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K]

if allow_modify_flows == 1:
log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M]
else:
nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [TILE_SIZE_B, TILE_SIZE_M]
nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M]
nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M]
log_n_fdm = tl.log(nflows) - nmars

log_n_fdm_max = tl.max(log_n_fdm, axis = 1)
@@ -1901,14 +1901,14 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa
offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B + b * BLOCK_B
mask_batch = offs_batch < batch_size

emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [BLOCK_K, BLOCK_B]
emars = tl.load(emars_ptr, mask = mask_batch[None,:], other = -float("inf")) # [BLOCK_K, BLOCK_B]

if allow_modify_flows == 1:
log_n_fdm = tl.load(nflows_ptr, mask = mask_batch) # [BLOCK_B]
pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1)
else:
nmars = tl.load(nmars_ptr, mask = mask_batch) # [BLOCK_B]
nflows = tl.load(nflows_ptr, mask = mask_batch) # [BLOCK_B]
nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B]
nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B]
pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1)

acc += pflows
@@ -1984,6 +1984,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten
grid = (B_NUM_TILES, K_NUM_BLOCKS, layer_n_nodes)

self._bk_triton_sparse_par_kernel[grid](
ddd = ddd,
node_flows = node_flows,
node_mars = node_mars,
element_mars = element_mars,

0 comments on commit e4efbe9

Please sign in to comment.