From c6bdde2f33175b194fe8bae521507357ee6abad4 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Thu, 8 Aug 2024 20:38:16 +0000 Subject: [PATCH] fix and speedup `juice.queries.sample` --- src/pyjuice/queries/sample.py | 253 +++++++++++++++++++++++++++------- tests/queries/sample_test.py | 22 ++- 2 files changed, 224 insertions(+), 51 deletions(-) diff --git a/src/pyjuice/queries/sample.py b/src/pyjuice/queries/sample.py index 697414a4..7fcade47 100644 --- a/src/pyjuice/queries/sample.py +++ b/src/pyjuice/queries/sample.py @@ -22,11 +22,23 @@ def _assign_cids_ind_target(ind_target, element_pointers, ind_b, num_samples): element_pointers[bid] = ind_t + 1 +@njit +def _assign_nids_ind_target(ind_target, ind_target_sid, node_pointers, ind_ch_count, ind_b, num_samples): + nid = 0 + for i in range(ind_target.shape[0]): + if nid < ind_target_sid.shape[0] - 1 and i >= ind_target_sid[nid+1]: + nid += 1 + bid = ind_b[nid] + ind_t = node_pointers[bid] + ind_target[i] = ind_t * num_samples + bid + node_pointers[bid] = ind_t + 1 + + @triton.jit def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, node_samples, element_samples, ind_target, ind_n, ind_b, seed, block_size: tl.constexpr, batch_size: tl.constexpr, num_edges: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr, BLOCK_S: tl.constexpr, - BLOCK_M: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_K: tl.constexpr, TILE_SIZE_K: tl.constexpr, + BLOCK_M: tl.constexpr, M_NUM_BLKS: tl.constexpr, BLOCK_K: tl.constexpr, K_NUM_BLKS: tl.constexpr, conditional: tl.constexpr): pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches @@ -38,13 +50,13 @@ def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, n # Load node and batch ids node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0) batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0) - node_id = tl.load(node_samples + node_sample_id * batch_size) + node_id = tl.load(node_samples + 
node_sample_id * batch_size + batch_id) # Locate node ids in `nids` offs_nids = tl.arange(0, BLOCK_M) local_nids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1 local_nid_offs = tl.zeros([BLOCK_S], dtype = tl.int64) - for i in range(TILE_SIZE_M): + for i in range(M_NUM_BLKS): mask_nids = offs_nids < num_nblocks ref_nid = tl.load(nids + offs_nids, mask = mask_nids, other = 0) @@ -69,22 +81,22 @@ def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, n mask_child = offs_child < num_edges if conditional: - nmars = tl.load(node_mars + node_id, mask = mask_sample, other = 0.0) + nmars = tl.load(node_mars + node_id * batch_size + batch_id, mask = mask_sample, other = 0.0) # [BLOCK_S] # Main loop over blocks of child nodes chids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1 - for i in range(TILE_SIZE_K): + for i in range(K_NUM_BLKS): # Load parameters param_id = tl.load(pids + local_nids[None,:] * num_edges + offs_child[:,None], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0) - epars = tl.load(params + param_id + local_nid_offs[None,:], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0) + epars = tl.load(params + param_id + local_nid_offs[None,:], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0) # [BLOCK_K, BLOCK_S] if conditional: # In this case, we use `param * cmar / nmar` as the "parameter" emars_id = tl.load(cids + local_nids[None,:] * num_edges + offs_child[:,None], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0) - emars = tl.load(params + emars_id, mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0) + emars = tl.load(element_mars + emars_id * batch_size + batch_id, mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0) - epars = epars * tl.exp(emars - nmars[None,:]) + epars = epars * tl.exp(emars - nmars[None,:]) # [BLOCK_K, BLOCK_S] cum_probs = tl.cumsum(epars, axis = 0) # [BLOCK_K, BLOCK_S] local_chids = tl.sum((rnd_val[None,:] >= cum_probs).to(tl.int64), axis = 0) # 
[BLOCK_S] @@ -117,15 +129,15 @@ def sample_sum_layer(layer, nids, cids, pids, node_mars, element_mars, params, n BLOCK_M = min(1024 // BLOCK_S, triton.next_power_of_2(num_nblocks)) BLOCK_K = min(1024 // BLOCK_S, triton.next_power_of_2(num_edges)) - TILE_SIZE_M = triton.cdiv(num_nblocks, BLOCK_M) - TILE_SIZE_K = triton.cdiv(num_edges, BLOCK_K) + M_NUM_BLKS = triton.cdiv(num_nblocks, BLOCK_M) + K_NUM_BLKS = triton.cdiv(num_edges, BLOCK_K) grid = (triton.cdiv(num_samples, BLOCK_S),) sample_sum_layer_kernel[grid]( nids, cids, pids, node_mars, element_mars, params, node_samples, element_samples, ind_target, ind_n, ind_b, seed, block_size, batch_size, num_edges, num_samples, num_nblocks, - BLOCK_S, BLOCK_M, TILE_SIZE_M, BLOCK_K, TILE_SIZE_K, conditional + BLOCK_S, BLOCK_M, M_NUM_BLKS, BLOCK_K, K_NUM_BLKS, conditional ) return None @@ -144,6 +156,153 @@ def push_non_neg_ones_to_front(matrix): return s_mask.long().sum(dim = 0) +@triton.jit +def count_prod_nch_kernel(nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, partition_id, + block_size: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr, + batch_size: tl.constexpr, num_edges: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_C: tl.constexpr, + BLOCK_S: tl.constexpr, M_NUM_BLKS: tl.constexpr, C_NUM_BLKS: tl.constexpr): + + pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches + + # Sample offsets and mask + offs_sample = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + mask_sample = offs_sample < num_samples + + # Load node and batch ids + node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0) + batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0) + ele_id = tl.load(element_samples + node_sample_id * batch_size + batch_id) + + # Locate node ids in `nids` + offs_nids = tl.arange(0, BLOCK_M) + local_nids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1 + local_nid_offs = tl.zeros([BLOCK_S], dtype = tl.int64) + for i in range(M_NUM_BLKS): + 
mask_nids = offs_nids < num_nblocks + + ref_nid = tl.load(nids + offs_nids, mask = mask_nids, other = 0) + is_match = (ele_id[:,None] >= ref_nid[None,:]) & (ele_id[:,None] < ref_nid[None,:] + block_size) + + match_local_id = tl.sum(is_match * (offs_nids[None,:] + 1), axis = 1) + match_local_offset = tl.sum(is_match * (ele_id[:,None] - ref_nid[None,:]), axis = 1) + + local_nids = tl.where(match_local_id > 0, match_local_id - 1, local_nids) + local_nid_offs = tl.where(match_local_id > 0, match_local_offset, local_nid_offs) + + offs_nids += BLOCK_M + + # Store `local_nids` and `local_nid_offs` for future reuse + mask_sample = mask_sample & (local_nids >= 0) + tl.store(ind_nids + offs_sample, local_nids, mask = mask_sample) + tl.store(ind_nid_offs + offs_sample, local_nid_offs, mask = mask_sample) + tl.store(ind_mask + offs_sample, partition_id, mask = mask_sample) + + # Offset for children + offs_child = tl.arange(0, BLOCK_C) + mask_child = offs_child < num_edges + + # Main loop over blocks of child nodes + ch_count = tl.zeros([BLOCK_S], dtype = tl.int64) + for i in range(C_NUM_BLKS): + + c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0) + ch_count += tl.sum((c_ids > 0).to(tl.int64), axis = 1) + + offs_child += BLOCK_C + mask_child = offs_child < num_edges + + # Store `ch_count` + tl.store(ind_ch_count + offs_sample, ch_count, mask = mask_sample) + + +def count_prod_nch(layer, nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, block_size, partition_id): + + num_samples = ind_n.size(0) + num_nblocks = nids.size(0) + batch_size = element_samples.size(1) + num_edges = cids.size(1) + + BLOCK_C = min(1024, triton.next_power_of_2(num_edges)) + BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples)) + BLOCK_M = min(1024 // BLOCK_S, triton.next_power_of_2(num_nblocks)) + + M_NUM_BLKS = triton.cdiv(num_nblocks, BLOCK_M) + C_NUM_BLKS = 
triton.cdiv(num_edges, BLOCK_C) + + grid = (triton.cdiv(num_samples, BLOCK_S),) + + count_prod_nch_kernel[grid]( + nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, partition_id, + block_size, num_samples, num_nblocks, batch_size, num_edges, BLOCK_M, BLOCK_C, BLOCK_S, M_NUM_BLKS, C_NUM_BLKS + ) + + return None + + +@triton.jit +def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b, + ind_nids, ind_nid_offs, ind_mask, partition_id, block_size: tl.constexpr, + num_samples: tl.constexpr, num_nblocks: tl.constexpr, batch_size: tl.constexpr, num_edges: tl.constexpr, + BLOCK_S: tl.constexpr, BLOCK_C: tl.constexpr, C_NUM_BLKS: tl.constexpr): + + pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches + + # Sample offsets and mask + offs_sample = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + mask_sample = offs_sample < num_samples + + # Load node and batch ids + node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0) + batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0) + ele_id = tl.load(element_samples + node_sample_id * batch_size + batch_id) + + # Load offsets of `nids` and the node offsets + local_nids = tl.load(ind_nids + offs_sample, mask = mask_sample, other = 0) + local_nid_offs = tl.load(ind_nid_offs + offs_sample, mask = mask_sample, other = 0) + local_partition_id = tl.load(ind_mask + offs_sample, mask = mask_sample, other = 0) + + # Update sample mask + mask_sample = mask_sample & (local_partition_id == partition_id) + + # Offset for children + offs_child = tl.arange(0, BLOCK_C) + mask_child = offs_child < num_edges + + # Main loop over blocks of child nodes + target_sid = tl.load(ind_target_sid + offs_sample, mask = mask_sample, other = 0) + for i in range(C_NUM_BLKS): + + c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0) + target_id = 
tl.load(ind_target + target_sid[:,None] + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)), other = 0) + + tl.store(node_samples + target_id, c_ids + local_nid_offs[:,None], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0))) + + offs_child += BLOCK_C + mask_child = offs_child < num_edges + + +def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid, + ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, block_size, partition_id): + + num_samples = ind_n.size(0) + num_nblocks = nids.size(0) + num_edges = cids.size(1) + batch_size = node_samples.size(1) + + BLOCK_C = min(1024, triton.next_power_of_2(num_edges)) + BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples)) + + C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C) + + grid = (triton.cdiv(num_samples, BLOCK_S),) + + sample_prod_layer_kernel[grid]( + nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b, + ind_nids, ind_nid_offs, ind_mask, partition_id, block_size, num_samples, + num_nblocks, batch_size, num_edges, BLOCK_S, BLOCK_C, C_NUM_BLKS + ) + + def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bool = False): if not conditional: assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling." 
@@ -156,18 +315,21 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo num_nscopes = 0 num_escopes = 0 for layer_group in pc.layers(ret_layer_groups = True): + curr_scopes = 0 + for layer in layer_group: + curr_scopes += len(layer.scopes) + if layer_group.is_input() or layer_group.is_sum(): - for layer in layer_group: - num_nscopes += len(layer.scopes) + num_nscopes += curr_scopes else: assert layer_group.is_prod() - curr_escopes = 0 - for layer in layer_group: - curr_escopes += len(layer.scopes) - num_escopes = max(num_escopes, curr_escopes) + num_escopes = max(num_escopes, curr_scopes) - node_samples = torch.zeros([num_nscopes * 2, num_samples], dtype = torch.long, device = pc.device) + # Stores selected node indices by the sampler + node_samples = torch.zeros([num_nscopes, num_samples], dtype = torch.long, device = pc.device) + # Stores selected element indices by the sampler element_samples = torch.zeros([num_escopes, num_samples], dtype = torch.long, device = pc.device) + # Pointers indicating how many elements are used in each column of `element_samples` element_pointers = np.zeros([num_samples], dtype = np.int64) # Initialize pointers to the root node @@ -221,43 +383,38 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo # Gather the indices to be processed lsid, leid = layer._layer_nid_range - mask = (element_samples >= lsid) & (element_samples < leid) - ind_n, ind_b = torch.where(mask) - global_cids = element_samples[ind_n, ind_b] + ind_n, ind_b = torch.where((element_samples >= lsid) & (element_samples < leid)) - # Get child indices + # Get the number of children for the selected sample indices + ind_ch_count = torch.zeros_like(ind_n) + ind_nids = torch.zeros_like(ind_n) + ind_nid_offs = torch.zeros_like(ind_n) + ind_mask = torch.zeros_like(ind_n) for partition_id in range(layer.num_fw_partitions): nids = layer.partitioned_nids[partition_id] cids = layer.partitioned_cids[partition_id] - 
is_match = (global_cids[:,None] >= nids[None,:]) & (global_cids[:,None] < nids[None,:] + layer.block_size) - local_nids = (is_match * torch.arange(1, nids.size(0) + 1, device = pc.device)[None,:] - 1).sum(dim = 1) - local_nid_offsets = ((global_cids[:,None] - nids[None,:]) * is_match.long()).sum(dim = 1) - - target_nids = cids[local_nids,:] + local_nid_offsets[:,None] - target_cids = cids[local_nids,:] - - target_idx = node_pointers[ind_b] + (torch.cumsum(mask, dim = 0)[ind_n, ind_b] - 1) * cids.size(1) - - if target_idx.max() + cids.size(1) > node_samples.size(0): - - node_samples_new = torch.zeros([target_idx.max() + cids.size(1), num_samples], dtype = torch.long, device = pc.device) - node_samples_new[:,:] = -1 - - node_samples_new[:node_samples.size(0),:] = node_samples - node_samples = node_samples_new - - match_filter = is_match.any(dim = 1) - target_nids = target_nids[match_filter,:] - target_cids = target_cids[match_filter,:] - target_idx = target_idx[match_filter] - target_b = ind_b[match_filter] + count_prod_nch(layer, nids, cids, element_samples, ind_ch_count, ind_nids, + ind_nid_offs, ind_mask, ind_n, ind_b, layer.block_size, partition_id) + + # Pre-compute the target indices in `node_samples` + ind_target_sid = np.zeros([ind_n.size(0)], dtype = np.int64) + ind_target_sid[1:] = ind_ch_count[:-1].cumsum(dim = 0).detach().cpu().numpy() + ind_target = np.zeros([ind_ch_count.sum()], dtype = np.int64) + _assign_nids_ind_target(ind_target, ind_target_sid, + node_pointers.detach().cpu().numpy(), + ind_ch_count.detach().cpu().numpy(), + ind_b.detach().cpu().numpy(), num_samples) + ind_target_sid = torch.from_numpy(ind_target_sid).to(pc.device) + ind_target = torch.from_numpy(ind_target).to(pc.device) - for i in range(cids.size(1)): - cmask = target_cids[:,i] != 0 - node_samples[target_idx[cmask]+i, target_b[cmask]] = target_nids[cmask,i] + # Store child indices + for partition_id in range(layer.num_fw_partitions): + nids = layer.partitioned_nids[partition_id] + 
cids = layer.partitioned_cids[partition_id] - node_pointers = (node_samples != -1).sum(dim = 0) + sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid, + ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, layer.block_size, partition_id) # Create tensor for the samples data_dtype = pc.input_layer_group[0].get_data_dtype() diff --git a/tests/queries/sample_test.py b/tests/queries/sample_test.py index 74552017..d47f884f 100644 --- a/tests/queries/sample_test.py +++ b/tests/queries/sample_test.py @@ -72,8 +72,8 @@ def test_sample_correctness(): samples = juice.queries.sample(pc, num_samples = 512) - assert samples[:,0].float().mean() < 0.4 - assert samples[:,1].float().mean() > 0.6 + assert samples[:,0].float().mean() > 0.6 + assert samples[:,1].float().mean() < 0.4 def test_sample_hclt(): @@ -118,8 +118,24 @@ def test_sample_hclt(): assert ((samples >= 0) & (samples < 256)).all() +def test_sample_hmm(): + + device = torch.device("cuda:0") + + ns = juice.structures.HMM(seq_length = 32, num_latents = 256, num_emits = 100, homogeneous = True, block_size = 64) + ns.init_parameters(perturbation = 2.0) + + pc = juice.compile(ns) + pc.to(device) + + samples = juice.queries.sample(pc, num_samples = 16) + + assert ((samples >= 0) & (samples < 100)).all() + + if __name__ == "__main__": torch.set_num_threads(4) test_sample() test_sample_correctness() - test_sample_hclt() \ No newline at end of file + test_sample_hclt() + test_sample_hmm() \ No newline at end of file