MPE and GeneralLL for forward pass (block sparse kernels)
liuanji committed Mar 9, 2024
1 parent 7d41932 commit edb67ec
Showing 4 changed files with 318 additions and 51 deletions.
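
For orientation: the three propagation algorithms wired through these kernels differ only in how a sum node combines its children. With m_i the child log-values (emars) and w_i the edge parameters (epars), each sum node value n is

    \mathrm{LL}:\qquad n = \log \sum_i w_i \, e^{m_i}

    \mathrm{MPE}:\qquad n = \max_i \left( \log w_i + m_i \right)

    \mathrm{GeneralLL}:\qquad n = \frac{1}{\alpha} \log \sum_i w_i^{\alpha} \, e^{\alpha m_i}

GeneralLL reduces to LL at alpha = 1 and approaches MPE as alpha grows; it is the only algorithm with an extra hyperparameter, which is why a **kwargs funnel is threaded through the call chain below.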
17 changes: 17 additions & 0 deletions src/pyjuice/layer/layer.py
@@ -7,6 +7,13 @@


class Layer():

propagation_alg_mapping = {
"LL": 0,
"MPE": 1,
"GeneralLL": 2
}

def __init__(self, nodes: Sequence[CircuitNodes], disable_block_size_check: bool = False) -> None:

if disable_block_size_check:
@@ -60,3 +67,13 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True

def provided(self, var_name):
return hasattr(self, var_name) and getattr(self, var_name) is not None

def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs):
if propagation_alg == "LL":
return {}
elif propagation_alg == "MPE":
return {}
elif propagation_alg == "GeneralLL":
return {"alpha": kwargs["alpha"]}
else:
raise ValueError(f"Unknown propagation algorithm {propagation_alg}.")
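
A minimal, self-contained sketch of how the two new members are meant to be used together (module-level stand-ins for illustration; get_propagation_alg_kwargs here plays the role of the underscore-prefixed method above):

propagation_alg_mapping = {"LL": 0, "MPE": 1, "GeneralLL": 2}

def get_propagation_alg_kwargs(propagation_alg: str, **kwargs):
    if propagation_alg in ("LL", "MPE"):
        return {}
    elif propagation_alg == "GeneralLL":
        # GeneralLL is the only algorithm with a hyperparameter
        return {"alpha": kwargs["alpha"]}
    else:
        raise ValueError(f"Unknown propagation algorithm {propagation_alg}.")

alg_id = propagation_alg_mapping["GeneralLL"]                      # 2, later handed to the kernels as a tl.constexpr
alg_kwargs = get_propagation_alg_kwargs("GeneralLL", alpha = 0.5)  # {"alpha": 0.5}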
196 changes: 148 additions & 48 deletions src/pyjuice/layer/sum_layer.py
@@ -208,7 +208,8 @@ def num_param_flows(self):
return self._layer_pfid_range[1] - self._layer_pfid_range[0]

def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor,
-                force_use_fp16: bool = False, force_use_fp32: bool = False) -> None:
+                force_use_fp16: bool = False, force_use_fp32: bool = False,
+                propagation_alg: str = "LL", **kwargs) -> None:
"""
Computes the forward pass of a sum layer.
@@ -228,7 +229,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t
self._forward(
node_mars, element_mars, params, nids, cids, pids,
partition_id = partition_id, force_use_fp16 = force_use_fp16,
-                force_use_fp32 = force_use_fp32
+                force_use_fp32 = force_use_fp32,
+                propagation_alg = propagation_alg, **kwargs
)

else:
@@ -243,7 +245,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t
node_mars, element_mars, params,
nids, cids, pids, local_ids = local_ids,
partition_id = partition_id, force_use_fp16 = force_use_fp16,
-                force_use_fp32 = force_use_fp32
+                force_use_fp32 = force_use_fp32,
+                propagation_alg = propagation_alg, **kwargs
)

return None
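
Callers pick the algorithm per forward call. Illustrative invocations (the surrounding tensor setup is elided, so these are shown as comments rather than runnable code):

# layer.forward(node_mars, element_mars, params)                           # default: "LL"
# layer.forward(node_mars, element_mars, params, propagation_alg = "MPE")
# layer.forward(node_mars, element_mars, params,
#               propagation_alg = "GeneralLL", alpha = 0.9)                # `alpha` is required here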
@@ -344,7 +347,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor,
pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None,
partition_id: int = -1, mode: Optional[str] = None,
-                 force_use_fp16: bool = False, force_use_fp32: bool = False) -> None:
+                 force_use_fp16: bool = False, force_use_fp32: bool = False,
+                 propagation_alg: str = "LL", **kwargs) -> None:
"""
Forward pass of sum layers.
@@ -380,18 +384,19 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
self._forward_block_sparse(
node_mars, element_mars, params, nids, cids, pids, local_ids,
partition_id = partition_id, force_use_fp16 = force_use_fp16,
-                force_use_fp32 = force_use_fp32
+                force_use_fp32 = force_use_fp32, propagation_alg = propagation_alg, **kwargs
)

elif mode == self.SPARSE:
self._forward_sparse(
node_mars, element_mars, params, nids, cids, pids, local_ids,
-                partition_id = partition_id
+                partition_id = partition_id, propagation_alg = propagation_alg, **kwargs
)

elif mode == self.PYTORCH:
self._forward_pytorch(
-                node_mars, element_mars, params, nids, cids, pids, local_ids
+                node_mars, element_mars, params, nids, cids, pids, local_ids,
+                propagation_alg = propagation_alg, **kwargs
)

else:
Expand All @@ -403,7 +408,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor,
def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment,
pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr,
BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr,
-            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr):
+            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr,
+            propagation_alg_id: tl.constexpr, alpha = 0.0):

pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches
pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes
@@ -453,22 +459,45 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c
epars = tl.load(epars_ptr)
emars = tl.load(emars_ptr, mask = mask_batch[None,:])

-            emars_max = tl.max(emars, axis = 0)[None,:]
-            emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0)
+            if propagation_alg_id == 1:
+                # MPE propagation method
+                lpars = tl.log(epars)
+                nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1)
+
+                acc = tl.maximum(acc, nmars)
+
-            if use_fp16 == 1:
-                # Built-in matmul kernel of triton + float16
-                epars_fp16 = (epars * (2**12)).to(tl.float16)
-                emars_fp16 = emars_sub.to(tl.float16)
-                nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12)
-            else:
-                # Built-in matmul kernel of triton + float32
-                nmars = tl.dot(epars, emars_sub)
-
-            acc = tl.where(emars_max > acc,
-                tl.log(nmars + tl.exp(acc - emars_max)) + emars_max,
-                tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
-            )
+            else:
+
+                if propagation_alg_id == 0:
+                    # LL propagation method
+                    emars_max = tl.max(emars, axis = 0)[None,:]
+                    emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0)
+
+                if propagation_alg_id == 2:
+                    # GeneralLL propagation method
+
+                    emars_max = tl.max(emars, axis = 0)[None,:]
+                    # Compute p_i^{alpha} for every i
+                    emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0)
+                    # Compute w_i^{alpha} for every i
+                    epars = tl.exp(tl.log(epars) * alpha)
+
+                    # Also scale `emars_max`
+                    emars_max *= alpha
+
+                if use_fp16 == 1:
+                    # Built-in matmul kernel of triton + float16
+                    epars_fp16 = (epars * (2**12)).to(tl.float16)
+                    emars_fp16 = emars_sub.to(tl.float16)
+                    nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12)
+                else:
+                    # Built-in matmul kernel of triton + float32
+                    nmars = tl.dot(epars, emars_sub)
+
+                acc = tl.where(emars_max > acc,
+                    tl.log(nmars + tl.exp(acc - emars_max)) + emars_max,
+                    tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
+                )

# Increment `epars_ptr`
pids_inc = tl.load(pids_inc_ptr)
@@ -480,6 +509,10 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c
emars_ptr += cids_inc[:,None] * batch_size
cids_inc_ptr += TILE_SIZE_K

if propagation_alg_id == 2:
# Compute p_i^{1/alpha}
acc *= (1.0 / alpha)

# Write back
off_nids = tl.load(nids + nblock_id)
offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:]
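
The accumulator update shared by the LL and GeneralLL paths is a streaming log-sum-exp: both tl.where branches compute log(exp(acc) + nmars * exp(emars_max)), arranged so the exponentiated difference is never positive and cannot overflow. A standalone check of that identity (a sketch, not part of the commit):

import math

def acc_update(acc, emars_max, nmars):
    # Mirrors the `tl.where` above: pick the branch whose exponent is non-positive.
    if emars_max > acc:
        return math.log(nmars + math.exp(acc - emars_max)) + emars_max
    else:
        return math.log(math.exp(emars_max - acc) * nmars + 1.0) + acc

acc, emars_max, nmars = -3.0, -1.5, 0.25
reference = math.log(math.exp(acc) + nmars * math.exp(emars_max))
assert abs(acc_update(acc, emars_max, nmars) - reference) < 1e-12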
@@ -491,7 +524,8 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c
def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment,
pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr,
BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr,
-            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr):
+            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr,
+            propagation_alg_id: tl.constexpr, alpha = 0.0):

pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches
pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes
@@ -541,22 +575,45 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids,
epars = tl.load(epars_ptr)
emars = tl.load(emars_ptr, mask = mask_batch[None,:])

-            emars_max = tl.max(emars, axis = 0)[None,:]
-            emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0)
+            if propagation_alg_id == 1:
+                # MPE propagation method
+                lpars = tl.log(epars)
+                nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1)
+
+                acc = tl.maximum(acc, nmars)
+
-            if use_fp16 == 1:
-                # Simulated matmul kernel + float16
-                epars = (epars * (2**4)).to(tl.float16)
-                emars_sub = emars_sub.to(tl.float16)
-                nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4)
-            else:
-                # Simulated matmul kernel + float32
-                nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1)
-
-            acc = tl.where(emars_max > acc,
-                tl.log(nmars + tl.exp(acc - emars_max)) + emars_max,
-                tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
-            )
+            else:
+
+                if propagation_alg_id == 0:
+                    # LL propagation method
+                    emars_max = tl.max(emars, axis = 0)[None,:]
+                    emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0)
+
+                if propagation_alg_id == 2:
+                    # GeneralLL propagation method
+
+                    emars_max = tl.max(emars, axis = 0)[None,:]
+                    # Compute p_i^{alpha} for every i
+                    emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0)
+                    # Compute w_i^{alpha} for every i
+                    epars = tl.exp(tl.log(epars) * alpha)
+
+                    # Also scale `emars_max`
+                    emars_max *= alpha
+
+                if use_fp16 == 1:
+                    # Simulated matmul kernel + float16
+                    epars = (epars * (2**4)).to(tl.float16)
+                    emars_sub = emars_sub.to(tl.float16)
+                    nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4)
+                else:
+                    # Simulated matmul kernel + float32
+                    nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1)
+
+                acc = tl.where(emars_max > acc,
+                    tl.log(nmars + tl.exp(acc - emars_max)) + emars_max,
+                    tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc
+                )

# Increment `epars_ptr`
pids_inc = tl.load(pids_inc_ptr)
@@ -568,6 +625,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids,
emars_ptr += cids_inc[:,None] * batch_size
cids_inc_ptr += TILE_SIZE_K

if propagation_alg_id == 2:
# Compute p_i^{1/alpha}
acc *= (1.0 / alpha)

# Write back
off_nids = tl.load(nids + nblock_id)
offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:]
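
Both fp16 paths pre-scale the parameters by a power of two (2**12 for the tl.dot kernel above, 2**4 for this simulated one) before casting to float16, then divide the factor back out in float32. A power-of-two scale only shifts the exponent, so it is exact, and it keeps small probability values out of float16's low-precision subnormal range. A standalone illustration of the effect (the rationale is inferred from the code, not stated in the commit):

import numpy as np

epar = np.float32(3e-6)                                # a small edge parameter
naive = np.float32(np.float16(epar))                   # subnormal in fp16: large rounding error
scaled = np.float32(np.float16(epar * 2**12)) / 2**12  # normal in fp16: precision preserved
assert abs(scaled - epar) < abs(naive - epar)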
@@ -579,7 +640,8 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids,
def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment,
pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr,
BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr,
-            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr):
+            TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr,
+            propagation_alg_id: tl.constexpr, alpha = 0.0):

pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches
pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes
@@ -629,16 +691,39 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids,
epars = tl.load(epars_ptr)
emars = tl.load(emars_ptr, mask = mask_batch[:,None])

-            emars_max = tl.max(emars, axis = 1)
-            emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0)
+            if propagation_alg_id == 1:
+                # MPE propagation method
+                lpars = tl.log(epars)
+                nmars = tl.max(lpars[:,:,None] + tl.trans(emars)[None,:,:], axis = 1)
+
+                acc = tl.maximum(acc, nmars)
+
-            # Simulated matmul kernel + float32
-            nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1)
-
-            acc = tl.where(emars_max[None,:] > acc,
-                tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:],
-                tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc
-            )
+            else:
+
+                if propagation_alg_id == 0:
+                    # LL propagation method
+                    emars_max = tl.max(emars, axis = 1)
+                    emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0)
+
+                if propagation_alg_id == 2:
+                    # GeneralLL propagation method
+
+                    emars_max = tl.max(emars, axis = 1)
+                    # Compute p_i^{alpha} for every i
+                    emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp((emars - emars_max[:,None]) * alpha), 0.0)
+                    # Compute w_i^{alpha} for every i
+                    epars = tl.exp(tl.log(epars) * alpha)
+
+                    # Also scale `emars_max`
+                    emars_max *= alpha
+
+                # Simulated matmul kernel + float32
+                nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1)
+
+                acc = tl.where(emars_max[None,:] > acc,
+                    tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:],
+                    tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc
+                )

# Increment `epars_ptr`
pids_inc = tl.load(pids_inc_ptr)
@@ -650,6 +735,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids,
emars_ptr += cids_inc[None,:] * batch_size
cids_inc_ptr += TILE_SIZE_K

if propagation_alg_id == 2:
# Compute p_i^{1/alpha}
acc *= (1.0 / alpha)

# Write back
off_nids = tl.load(nids + nblock_id)
offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:]
Expand All @@ -659,7 +748,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor,
pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None,
partition_id: int = -1, force_use_fp16: bool = False,
-                               force_use_fp32: bool = False) -> None:
+                               force_use_fp32: bool = False, propagation_alg: str = "LL", **kwargs) -> None:
"""
Forward pass of sum layers with the block-sparse processing kernel.
@@ -680,6 +769,10 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
batch_size = node_mars.size(1)
BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size)

# Propagation algorithm
propagation_alg_id = self.propagation_alg_mapping[propagation_alg]
propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs)

# Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B`
base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 128)
if base_size >= 64:
@@ -751,7 +844,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
K_NUM_TILES = K_NUM_TILES,
TILE_SIZE_M = TILE_SIZE_M,
BLOCK_SIZE_M = BLOCK_SIZE_M,
-                use_fp16 = use_fp16
+                use_fp16 = use_fp16,
+                propagation_alg_id = propagation_alg_id,
+                **propagation_alg_kwargs
)

elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8:
@@ -772,8 +867,11 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
K_NUM_TILES = K_NUM_TILES,
TILE_SIZE_M = TILE_SIZE_M,
BLOCK_SIZE_M = BLOCK_SIZE_M,
-                use_fp16 = use_fp16
+                use_fp16 = use_fp16,
+                propagation_alg_id = propagation_alg_id,
+                **propagation_alg_kwargs
)

else:
self._fw_triton_block_sparse_csmm2_kernel[grid](
node_mars,
@@ -792,7 +890,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten
K_NUM_TILES = K_NUM_TILES,
TILE_SIZE_M = TILE_SIZE_M,
BLOCK_SIZE_M = BLOCK_SIZE_M,
-                use_fp16 = use_fp16
+                use_fp16 = use_fp16,
+                propagation_alg_id = propagation_alg_id,
+                **propagation_alg_kwargs
)

return None
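
For checking the kernel semantics without a GPU, a dense PyTorch reference of all three algorithms fits in a few lines (a sketch under the notational assumptions above; it is not part of the commit and not the _forward_pytorch fallback):

import torch

def sum_layer_forward_reference(emars, epars, propagation_alg = "LL", alpha = 1.0):
    # emars: (num_chs, batch_size) child log-values; epars: (num_nodes, num_chs) edge weights
    lpars = torch.log(epars)
    if propagation_alg == "LL":
        return torch.logsumexp(lpars[:,:,None] + emars[None,:,:], dim = 1)
    elif propagation_alg == "MPE":
        return torch.max(lpars[:,:,None] + emars[None,:,:], dim = 1).values
    elif propagation_alg == "GeneralLL":
        return torch.logsumexp(alpha * (lpars[:,:,None] + emars[None,:,:]), dim = 1) / alpha
    else:
        raise ValueError(f"Unknown propagation algorithm {propagation_alg}.")

# For alpha >= 1, GeneralLL is sandwiched between MPE and LL (power-mean inequality):
emars = torch.log(torch.rand(4, 8))
epars = torch.softmax(torch.randn(3, 4), dim = 1)
ll = sum_layer_forward_reference(emars, epars, "LL")
mpe = sum_layer_forward_reference(emars, epars, "MPE")
gll = sum_layer_forward_reference(emars, epars, "GeneralLL", alpha = 8.0)
assert torch.all(gll <= ll + 1e-5) and torch.all(gll >= mpe - 1e-5)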
