add log-space backward option for product layers
liuanji committed Mar 15, 2024
1 parent e298620 commit c61463d
Showing 1 changed file with 129 additions and 28 deletions.
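The new `prop_logsumexp` flag switches the shared forward/backward reduction of the product layer into log space: the per-edge sum becomes a logsumexp, and accumulation into `node_vals` becomes a logaddexp. Below is a minimal PyTorch sketch, not part of the commit and with illustrative shapes, of the identity this relies on: reducing log-space values with logsumexp and logaddexp matches reducing exponentiated values with sum and add and then taking the log.

import torch

# Illustrative shapes only: [batch, nodes_in_block, num_edges] element values,
# plus previously accumulated node values, both stored in log space.
log_evals = torch.randn(16, 32, 4)
log_node_prev = torch.randn(16, 32)

# Log-space path (prop_logsumexp = True): logsumexp over edges, logaddexp to accumulate.
log_path = torch.logaddexp(torch.logsumexp(log_evals, dim=2), log_node_prev)

# Linear-space path (prop_logsumexp = False) applied to exponentiated values.
lin_path = (log_evals.exp().sum(dim=2) + log_node_prev.exp()).log()

assert torch.allclose(log_path, lin_path, atol=1e-5)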
157 changes: 129 additions & 28 deletions src/pyjuice/layer/prod_layer.py
@@ -8,6 +8,13 @@
import time
from typing import Sequence, Optional

# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
if hasattr(tl.extra.cuda, "libdevice"):
tlmath = tl.extra.cuda.libdevice
else:
tlmath = tl.math

from pyjuice.nodes import ProdNodes
from pyjuice.utils.parameter_list import FastParamList
from pyjuice.utils.kernel_launcher import FastJITFunction
@@ -249,7 +256,7 @@ def enable_partial_evaluation(self, fw_scopes: Optional[Sequence[BitSet]] = None
@FastJITFunction
def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks,
num_edges: tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr,
block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr):
block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr):
"""
This kernel implements the function with 3d tensors. However, it only works with `triton==2.0.0`.
"""
@@ -264,7 +271,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
ntile_id = pid_m % (block_size // BLOCK_M)

# For partial evaluation
if partial_eval == 1:
if partial_eval:
nblock_id = tl.load(local_ids_ptr + nblock_id)

# Batch offsets and mask
@@ -282,17 +289,36 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
offs_evals = offs_egstart + block_nids[:,None]
evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = mask_batch[:,None,None])

# Take the sum of the child nodes' log-probabilities
nvals = tl.sum(evals, axis = 2)
if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 2)
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max
else:
# Take the sum of the child nodes' values
nvals = tl.sum(evals, axis = 2)

# Node ids to `node_vals_ptr`
nblock_start = tl.load(nids_ptr + nblock_id)
offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None]

# Accumulate the `node_vals` if required
if accum == 1:
if accum:
node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0)
nvals += node_vals
if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
)
)
else:
# sum
nvals += node_vals

tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None])
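When `accum` is set, the kernel above combines the freshly reduced values with the previously stored `node_vals` through a numerically stable logaddexp, written as a three-way select so that equal inputs reduce to adding log 2. A small PyTorch sketch of the same formula, checked against `torch.logaddexp`; the helper name is ours, not the layer's.

import torch

def stable_logaddexp(a, b):
    # Same three-way select as the kernel: equal inputs add log(2); otherwise the
    # larger input absorbs a log1p(exp(-|a - b|)) correction from the smaller one.
    diff = a - b
    return torch.where(
        diff == 0,
        a + 0.69314718055994530942,  # log(2)
        torch.where(
            diff > 0,
            a + torch.log1p(torch.exp(-diff)),
            b + torch.log1p(torch.exp(diff)),
        ),
    )

a, b = torch.randn(1000), torch.randn(1000)
assert torch.allclose(stable_logaddexp(a, b), torch.logaddexp(a, b), atol=1e-5)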

@@ -306,7 +332,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
nblock_ids = offs_node // block_size

# For partial evaluation
if partial_eval == 1:
if partial_eval:
nblock_ids = tl.load(local_ids_ptr + nblock_ids, mask = mask_node)

# Batch offsets and mask
@@ -323,17 +349,36 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
offs_evals = offs_egstart + block_nids[:,None]
evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = (mask_batch[:,None,None] & mask_node[None,:,None]))

# Take the sum of the child nodes' log-probabilities
nvals = tl.sum(evals, axis = 2)
if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 2)
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max
else:
# Take the sum of the child nodes' values
nvals = tl.sum(evals, axis = 2)

# Node ids to `node_vals_ptr`
nblock_start = tl.load(nids_ptr + nblock_ids[None,:])
offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None]

# Accumulate the `node_vals` if required
if accum == 1:
if accum:
node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0)
nvals += node_vals
if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
)
)
else:
# sum
nvals += node_vals

tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None])

@@ -342,7 +387,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
@FastJITFunction
def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks,
num_edges: tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr,
block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr):
block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr):
"""
This kernel implements the function with 2d tensors. It works for all `triton` versions.
"""
@@ -355,7 +400,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
ntile_id = pid_m % (block_size // BLOCK_M)

# For partial evaluation
if partial_eval == 1:
if partial_eval:
nblock_id = tl.load(local_ids_ptr + nblock_id)

# Batch offsets and mask
@@ -376,12 +421,34 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
# Inner loop
for i in range(0, BLOCK_M):
evals = tl.load(element_vals_ptr + offs_evals, mask = mask_batch[None,:], other = 0)
nvals = tl.sum(evals, axis = 0)

if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 0)
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max
else:
# Take the sum of the child nodes' values
nvals = tl.sum(evals, axis = 0)

# Accumulate the `node_vals` if required
if accum == 1:
if accum:
node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch)
nvals += node_vals

if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
)
)
else:
# sum
nvals += node_vals

tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch)

@@ -393,7 +460,8 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
@FastJITFunction
def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks,
num_edges: tl.constexpr, batch_size, BLOCK_N: tl.constexpr, BLOCK_B: tl.constexpr,
N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr):
N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr,
prop_logsumexp: tl.constexpr):
"""
This kernel implements the function with 2d tensors. It is designed for nodes with many edges.
"""
@@ -406,7 +474,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr,
ntile_id = pid_m % block_size

# For partial evaluation
if partial_eval == 1:
if partial_eval:
nblock_id = tl.load(local_ids_ptr + nblock_id)

# Batch offsets and mask
@@ -425,11 +493,27 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr,
nblock_start = tl.load(nids_ptr + nblock_id)
offs_nvals = (nblock_start + ntile_id) * batch_size + offs_batch # [BLOCK_B]

# Prepare buffer
if prop_logsumexp:
nvals = tl.zeros([BLOCK_B], dtype = tl.float32) - float("inf")
else:
nvals = tl.zeros([BLOCK_B], dtype = tl.float32)

# Inner loop
nvals = tl.zeros([BLOCK_B], dtype = tl.float32)
for i in range(0, N_NUM_BLKS):
evals = tl.load(element_vals_ptr + offs_evals, mask = (mask_edge[:,None] & mask_batch[None,:]), other = 0)
nvals += tl.sum(evals, axis = 0)

if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 0)
nvals_sub = tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)
nvals = tl.where(evals_max > nvals,
tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max,
tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals
)
else:
# Take the sum of the child nodes' values
nvals += tl.sum(evals, axis = 0)

offs_edge += BLOCK_N
mask_edge = (offs_edge < num_edges)
@@ -439,24 +523,38 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr,
offs_evals = (offs_egstart[:,None] + ntile_id) * batch_size + offs_batch[None,:] # [BLOCK_N, BLOCK_B]

# Accumulate the `node_vals` if required
if accum == 1:
if accum:
node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch)
nvals += node_vals

if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
)
)
else:
# sum
nvals += node_vals

tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch)
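Because this kernel walks the edges in blocks of `BLOCK_N`, it cannot reduce everything in one shot; with `prop_logsumexp` it keeps `nvals` as a running log-space total and folds in each block through its block max and partial exp-sum. A minimal PyTorch sketch of that streaming combine, with illustrative block size and shapes, checked against a one-shot logsumexp.

import torch

x = torch.randn(64, 8)                     # [num_edges, batch] log-space element values
running = torch.full((8,), float("-inf"))  # running log-space accumulator per batch entry

for blk in x.split(16, dim=0):             # fold in 16 edges at a time
    blk_max = blk.max(dim=0).values
    partial = torch.exp(blk - blk_max).sum(dim=0)
    running = torch.where(
        blk_max > running,
        torch.log(partial + torch.exp(running - blk_max) + 1e-24) + blk_max,
        torch.log(torch.exp(blk_max - running) * partial + 1.0) + running,
    )

assert torch.allclose(running, torch.logsumexp(x, dim=0), atol=1e-5)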

def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor,
nids: torch.Tensor, cids: torch.Tensor, local_ids: Optional[torch.Tensor] = None,
accum: bool = False) -> None:
accum: bool = False, prop_logsumexp: bool = False) -> None:
tot_n_nodes = node_vals.size(0)
tot_n_eles = element_vals.size(0)
n_nblocks = nids.size(0) if local_ids is None else local_ids.size(0)
num_edges = cids.size(1)
batch_size = node_vals.size(1)

block_size = self.block_size
accum = 1 if accum else 0
partial_eval = 1 if local_ids is not None else 0
partial_eval = local_ids is not None

assert num_edges & (num_edges - 1) == 0, "`num_edges` must be a power of 2."

@@ -484,7 +582,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor,
N_NUM_BLKS = triton.cdiv(num_edges, BLOCK_B),
block_size = block_size,
accum = accum,
partial_eval = partial_eval
partial_eval = partial_eval,
prop_logsumexp = prop_logsumexp
)

return None
@@ -511,7 +610,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor,
BLOCK_B = BLOCK_B,
block_size = block_size,
accum = accum,
partial_eval = partial_eval
partial_eval = partial_eval,
prop_logsumexp = prop_logsumexp
)

else:
@@ -536,7 +636,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor,
BLOCK_B = BLOCK_B,
block_size = block_size,
accum = accum,
partial_eval = partial_eval
partial_eval = partial_eval,
prop_logsumexp = prop_logsumexp
)

return None
