From fd415218971e7e8f515a78fa32c93c6130d4272c Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 18 Mar 2024 05:19:56 +0800 Subject: [PATCH] avoid nans in backward pass for zero-flow inner nodes --- src/pyjuice/layer/prod_layer.py | 34 ++++++++++++++++----------------- src/pyjuice/layer/sum_layer.py | 23 +++++++++++++++------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 5a8e106c..c47fa24a 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -370,14 +370,14 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: @@ -429,7 +429,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max + nvals = tl.where(evals_max != -float("inf"), tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max, -float("inf")) else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 0) @@ -440,14 +440,14 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: @@ -510,7 +510,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals_sub = tl.sum(tl.exp(evals - evals_max[None,:]), axis = 2) + nvals_sub = tl.where(evals_max != -float("inf"), tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0), 0.0) nvals = tl.where(evals_max > nvals, tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max, tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals @@ -532,14 +532,14 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index d3ce78db..286e193e 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1613,9 +1613,12 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele if logspace_flows: partial_flows_max = emars + log_n_fdm_max - acc = tl.where(partial_flows_max > acc, - tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, - tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + acc = tl.where(log_n_fdm_max == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) ) else: acc += partial_flows * tl.exp(emars + log_n_fdm_max) @@ -1757,9 +1760,12 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar if logspace_flows: partial_flows_max = emars + log_n_fdm_max[None,:] - acc = tl.where(partial_flows_max > acc, - tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, - tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + acc = tl.where(log_n_fdm_max[None,:] == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) ) else: acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) @@ -2692,7 +2698,10 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa if logspace_flows: plflows = nflows[None,:] + emars - nmars[None,:] plflows_max = tl.max(plflows, axis = 1) - pflows = tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1) * tl.exp(plflows_max) + pflows = tl.where(plflows_max != -float("inf"), + tl.exp(tl.log(tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1)) + plflows_max), + 0.0 + ) else: pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1)