Skip to content

Commit

Permalink
propagate -inf when all inputs are -inf
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Apr 19, 2024
1 parent deef3e9 commit 0da8980
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/pyjuice/layer/sum_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids,
epars = tl.load(epars_ptr)

if propagation_alg_id == 0:
nmars = tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max
nmars = tl.where(emars_max == -float("inf"), -float("inf"), tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max)

if propagation_alg_id == 1:
nmars = tl.max(emars + tl.log(epars)[:,None], axis = 0)
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids,

# Compute sum node marginals
if propagation_alg_id == 0:
nmars = tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max
nmars = tl.where(emars_max == -float("inf"), -float("inf"), tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max)

if propagation_alg_id == 1:
nmars = tl.max(emars + tl.log(epars)[:,:,None], axis = 1)
Expand Down

0 comments on commit 0da8980

Please sign in to comment.