diff --git a/src/pyjuice/queries/sample.py b/src/pyjuice/queries/sample.py index 0662956..3bb7cda 100644 --- a/src/pyjuice/queries/sample.py +++ b/src/pyjuice/queries/sample.py @@ -11,7 +11,6 @@ from pyjuice.nodes import CircuitNodes from pyjuice.model import TensorCircuit -from pyjuice.utils.kernel_launcher import FastJITFunction @njit @@ -23,68 +22,19 @@ def _assign_cids_ind_target(ind_target, element_pointers, ind_b, num_samples): element_pointers[bid] = ind_t + 1 -@triton.jit -def _assign_nids_ind_target_kernel(ind_target_ptr, ind_ch_count_ptr, node_pointers_ptr, ind_b_ptr, - num_samples, num_nodes, BLOCK_SIZE: tl.constexpr, NUM_BLKS: tl.constexpr): - bid = tl.program_id(0) # The batch ID for this node block - - target_val_sid = tl.load(node_pointers_ptr + bid) - - offsets = tl.arange(0, BLOCK_SIZE) - offset_first = 0 - - for i in range(NUM_BLKS): - mask = (offsets < num_nodes) - - inds_b = tl.load(ind_b_ptr + offsets, mask = mask) - mask_b = (inds_b == bid) - - count_c = tl.load(ind_ch_count_ptr + offsets, mask = mask & mask_b, other = 0) - - cumcount_c = tl.cumsum(count_c, axis = 0) - count_c + target_val_sid - - tl.store(ind_target_ptr + offsets, cumcount_c * num_samples + bid, mask = mask & mask_b) - - last_onehot = ((offsets + 1) == tl.max((offsets + 1) * mask_b.to(tl.int64))).to(tl.int64) - target_val_sid = tl.max(cumcount_c) + tl.sum(count_c * last_onehot) - - offsets += BLOCK_SIZE - - -def _assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples): - """ - A GPU implementation of the following: - - @njit - def _assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples): - for nid in range(ind_target.shape[0]): - bid = ind_b[nid] - ind_t = node_pointers[bid] - ind_target[i] = ind_t * num_samples + bid - node_pointers[bid] = ind_t + ind_ch_count[nid] - """ - - num_nodes = ind_b.size(0) - - BLOCK_SIZE = min(512, triton.next_power_of_2(num_nodes)) - NUM_BLKS = triton.cdiv(num_nodes, BLOCK_SIZE) - - grid = 
@njit
def _assign_nids_ind_target(ind_target, ind_target_sid, node_pointers, ind_ch_count, ind_b, num_samples):
    """
    Fill `ind_target` with one target offset per (node, child) slot.

    Slot `i` belongs to node `nid` when
    `ind_target_sid[nid] <= i < ind_target_sid[nid+1]` (the last node owns every
    remaining slot). For each slot, the next free position `ind_t` of the node's
    batch `bid = ind_b[nid]` is read from `node_pointers`, the offset
    `ind_t * num_samples + bid` is written to `ind_target[i]`, and
    `node_pointers[bid]` is advanced by one.

    Parameters:
        ind_target:     1-D integer array, written in place (one entry per slot).
        ind_target_sid: 1-D non-decreasing integer array with the first slot of every node.
        node_pointers:  1-D integer array indexed by batch id, updated in place.
        ind_ch_count:   unused here; kept so that the call signature stays unchanged.
        ind_b:          1-D integer array with the batch id of every node.
        num_samples:    stride multiplied with the per-batch position.

    NOTE(review): the offsets look like flat indices into a row-major
    `[*, num_samples]` buffer - confirm against the consumer of `ind_target`.
    """
    num_nodes = ind_target_sid.shape[0]

    nid = 0
    for i in range(ind_target.shape[0]):
        # Move to the node that owns slot `i`. A `while` (rather than a single `if`)
        # is needed to step over nodes that own zero slots, i.e., consecutive equal
        # entries in `ind_target_sid`; otherwise slot `i` would be attributed to
        # the empty node and pick up the wrong batch id.
        while nid < num_nodes - 1 and i >= ind_target_sid[nid+1]:
            nid += 1

        bid = ind_b[nid]
        ind_t = node_pointers[bid]
        ind_target[i] = ind_t * num_samples + bid
        node_pointers[bid] = ind_t + 1
def push_non_neg_ones_to_front(matrix):
    """
    Compact every column of a 2-D `matrix` in place so that all entries different
    from -1 come first (keeping their relative order) and the rest of the column
    is filled with -1.

    Parameters:
        matrix: 2-D tensor that uses -1 as the "empty" marker; modified in place.

    Returns:
        A 1-D `torch.long` tensor of size `matrix.size(1)` holding the number of
        non -1 entries in every column.
    """
    s_mask = (matrix != -1)
    counts = s_mask.long().sum(dim = 0)

    # Destination mask: the first `counts[c]` rows of column `c`
    row_ids = torch.arange(matrix.size(0), device = matrix.device)
    d_mask = row_ids[:,None] < counts[None,:]

    result = torch.full_like(matrix, -1)

    # Boolean-mask indexing enumerates elements in row-major order. Indexing the
    # transposed views makes this enumeration go column by column of `matrix`,
    # so the k-th valid entry of a column lands in row k of the *same* column.
    # Indexing `matrix`/`result` directly would pair sources and destinations
    # across different columns whenever the columns have different -1 patterns.
    result.t()[d_mask.t()] = matrix.t()[s_mask.t()]

    matrix[:] = result

    return counts
tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr, batch_size: tl.constexpr, num_edges: tl.constexpr, BLOCK_S: tl.constexpr, BLOCK_C: tl.constexpr, C_NUM_BLKS: tl.constexpr): @@ -383,12 +269,11 @@ def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_targ mask_child = offs_child < num_edges # Main loop over blocks of child nodes - target_id_base = tl.load(ind_target + offs_sample, mask = mask_sample, other = 0) + target_sid = tl.load(ind_target_sid + offs_sample, mask = mask_sample, other = 0) for i in range(C_NUM_BLKS): c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0) - - target_id = target_id_base[:,None] + offs_child[None,:] * num_samples + target_id = tl.load(ind_target + target_sid[:,None] + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)), other = 0) tl.store(node_samples + target_id, c_ids + local_nid_offs[:,None], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0))) @@ -396,7 +281,7 @@ def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_targ mask_child = offs_child < num_edges -def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, +def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, block_size, partition_id): num_samples = ind_n.size(0) @@ -404,15 +289,15 @@ def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_targ num_edges = cids.size(1) batch_size = node_samples.size(1) - BLOCK_C = min(128, triton.next_power_of_2(num_edges)) - BLOCK_S = min(128 // BLOCK_C, triton.next_power_of_2(num_samples)) + BLOCK_C = min(1024, triton.next_power_of_2(num_edges)) + BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples)) C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C) grid = (triton.cdiv(num_samples, BLOCK_S),) 
sample_prod_layer_kernel[grid]( - nids, cids, node_samples, element_samples, ind_target, ind_n, ind_b, + nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, partition_id, block_size, num_samples, num_nblocks, batch_size, num_edges, BLOCK_S, BLOCK_C, C_NUM_BLKS ) @@ -461,6 +346,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo # Iterate over sum layers in the current layer group for layer in layer_group: + # Gather the indices to be processed lsid, leid = layer._layer_nid_range ind_n, ind_b = torch.where((node_samples >= lsid) & (node_samples < leid)) @@ -479,9 +365,6 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo nids = layer.partitioned_nids[partition_id] cids = layer.partitioned_cids[partition_id] pids = layer.partitioned_pids[partition_id] - - if ind_n.size(0) == 0: - import pdb; pdb.set_trace() sample_sum_layer(layer, nids, cids, pids, pc.node_mars, pc.element_mars, pc.params, node_samples, element_samples, ind_target, ind_n, ind_b, @@ -515,27 +398,34 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo ind_nid_offs, ind_mask, ind_n, ind_b, layer.block_size, partition_id) # Pre-compute the target indices in `node_samples` - ind_target = torch.zeros_like(ind_n) - _assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples) + ind_target_sid = np.zeros([ind_n.size(0)], dtype = np.int64) + ind_target_sid[1:] = ind_ch_count[:-1].cumsum(dim = 0).detach().cpu().numpy() + ind_target = np.zeros([ind_ch_count.sum()], dtype = np.int64) + _assign_nids_ind_target(ind_target, ind_target_sid, + node_pointers.detach().cpu().numpy(), + ind_ch_count.detach().cpu().numpy(), + ind_b.detach().cpu().numpy(), num_samples) + ind_target_sid = torch.from_numpy(ind_target_sid).to(pc.device) + ind_target = torch.from_numpy(ind_target).to(pc.device) # Store child indices for partition_id in 
range(layer.num_fw_partitions): nids = layer.partitioned_nids[partition_id] cids = layer.partitioned_cids[partition_id] - sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, + sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, layer.block_size, partition_id) - if _sample_input_ns: - # Create tensor for the samples - data_dtype = pc.input_layer_group[0].get_data_dtype() - samples = torch.zeros([pc.num_vars, num_samples], dtype = data_dtype, device = pc.device) + # Create tensor for the samples + data_dtype = pc.input_layer_group[0].get_data_dtype() + samples = torch.zeros([pc.num_vars, num_samples], dtype = data_dtype, device = pc.device) - pc._init_buffer(name = "node_flows", shape = (pc.num_nodes, num_samples), set_value = 0.0) - ind_n, ind_b = torch.where(node_samples != -1) - ind_node = node_samples[ind_n, ind_b] - pc.node_flows[ind_node, ind_b] = 1.0 + pc._init_buffer(name = "node_flows", shape = (pc.num_nodes, num_samples), set_value = 0.0) + ind_n, ind_b = torch.where(node_samples != -1) + ind_node = node_samples[ind_n, ind_b] + pc.node_flows[ind_node, ind_b] = 1.0 + if _sample_input_ns: for layer in pc.input_layer_group: seed = random.randint(0, 2**31) layer.sample(samples, pc.node_flows, seed = seed) @@ -543,4 +433,4 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo return samples.permute(1, 0).contiguous() else: # In this case, we do not explicitly sample input nodes - return node_samples + return node_samples \ No newline at end of file