Skip to content

Commit

Permalink
fix categorical missing flow kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Aug 12, 2024
1 parent 00a7ab7 commit f78ec57
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/pyjuice/nodes/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def bk_flow_mask_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, param
pf_offsets = s_pfids[:,None] + cat_ids[None,:]
tl.atomic_add(param_flows_ptr + pf_offsets, flows[:,None] * param, mask = cat_mask)

cat_ids += TILE_SIZE_K

@staticmethod
def sample_fn(samples_ptr, local_offsets, batch_offsets, vids, s_pids, params_ptr, metadata_ptr, s_mids_ptr, mask, batch_size, BLOCK_SIZE, seed):
# Get `num_cats` from `metadata`
Expand Down

0 comments on commit f78ec57

Please sign in to comment.