Skip to content

Commit

Permalink
avoid nans in backward pass for zero-flow inner nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Mar 17, 2024
1 parent 83694a6 commit fd41521
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
34 changes: 17 additions & 17 deletions src/pyjuice/layer/prod_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,14 +370,14 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr,

if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
diff = node_vals - nvals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
nvals == -float("inf"),
node_vals,
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
node_vals + tlmath.log1p(tl.exp(-diff)),
nvals + tlmath.log1p(tl.exp(diff))
)
)
else:
Expand Down Expand Up @@ -429,7 +429,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,
if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 0)
nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max
nvals = tl.where(evals_max != -float("inf"), tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max, -float("inf"))
else:
# Take the sum of the child nodes' values
nvals = tl.sum(evals, axis = 0)
Expand All @@ -440,14 +440,14 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr,

if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
diff = node_vals - nvals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
nvals == -float("inf"),
node_vals,
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
node_vals + tlmath.log1p(tl.exp(-diff)),
nvals + tlmath.log1p(tl.exp(diff))
)
)
else:
Expand Down Expand Up @@ -510,7 +510,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt
if prop_logsumexp:
# Take the logsumexp of the child nodes' values
evals_max = tl.max(evals, axis = 0)
nvals_sub = tl.sum(tl.exp(evals - evals_max[None,:]), axis = 2)
nvals_sub = tl.where(evals_max != -float("inf"), tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0), 0.0)
nvals = tl.where(evals_max > nvals,
tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max,
tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals
Expand All @@ -532,14 +532,14 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt

if prop_logsumexp:
# logaddexp
diff = nvals - node_vals
diff = node_vals - nvals
nvals = tl.where(
diff == 0,
nvals + 0.69314718055994530942, # log(2)
nvals == -float("inf"),
node_vals,
tl.where(
diff > 0,
nvals + tlmath.log1p(tl.exp(-diff)),
node_vals + tlmath.log1p(tl.exp(diff))
node_vals + tlmath.log1p(tl.exp(-diff)),
nvals + tlmath.log1p(tl.exp(diff))
)
)
else:
Expand Down
23 changes: 16 additions & 7 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,9 +1613,12 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele

if logspace_flows:
partial_flows_max = emars + log_n_fdm_max
acc = tl.where(partial_flows_max > acc,
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc
acc = tl.where(log_n_fdm_max == -float("inf"),
acc,
tl.where(partial_flows_max > acc,
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc
)
)
else:
acc += partial_flows * tl.exp(emars + log_n_fdm_max)
Expand Down Expand Up @@ -1757,9 +1760,12 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar

if logspace_flows:
partial_flows_max = emars + log_n_fdm_max[None,:]
acc = tl.where(partial_flows_max > acc,
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc
acc = tl.where(log_n_fdm_max[None,:] == -float("inf"),
acc,
tl.where(partial_flows_max > acc,
tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max,
tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc
)
)
else:
acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:])
Expand Down Expand Up @@ -2692,7 +2698,10 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa
if logspace_flows:
plflows = nflows[None,:] + emars - nmars[None,:]
plflows_max = tl.max(plflows, axis = 1)
pflows = tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1) * tl.exp(plflows_max)
pflows = tl.where(plflows_max != -float("inf"),
tl.exp(tl.log(tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1)) + plflows_max),
0.0
)
else:
pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1)

Expand Down

0 comments on commit fd41521

Please sign in to comment.