fix and speed up juice.queries.sample
liuanji committed Aug 8, 2024
1 parent d231c84 commit c6bdde2
Showing 2 changed files with 224 additions and 51 deletions.
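For context, here is a minimal usage sketch of the query being fixed, mirroring the new `test_sample_hmm` test added by this commit; the `import pyjuice as juice` alias and the CUDA device are assumptions, not part of the diff.

import torch
import pyjuice as juice  # assumed import alias, matching the `juice.` calls in the tests

device = torch.device("cuda:0")

# Build and compile a small HMM circuit, as in test_sample_hmm below
ns = juice.structures.HMM(seq_length = 32, num_latents = 256, num_emits = 100,
                          homogeneous = True, block_size = 64)
ns.init_parameters(perturbation = 2.0)
pc = juice.compile(ns)
pc.to(device)

# Unconditional sampling requires `num_samples` (see the assert at the top of `sample`)
samples = juice.queries.sample(pc, num_samples = 16)
assert ((samples >= 0) & (samples < 100)).all()  # emitted symbols lie in [0, num_emits)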
253 changes: 205 additions & 48 deletions src/pyjuice/queries/sample.py
@@ -22,11 +22,23 @@ def _assign_cids_ind_target(ind_target, element_pointers, ind_b, num_samples):
element_pointers[bid] = ind_t + 1


@njit
def _assign_nids_ind_target(ind_target, ind_target_sid, node_pointers, ind_ch_count, ind_b, num_samples):
nid = 0
for i in range(ind_target.shape[0]):
if nid < ind_target_sid.shape[0] - 1 and i >= ind_target_sid[nid+1]:
nid += 1
bid = ind_b[nid]
ind_t = node_pointers[bid]
ind_target[i] = ind_t * num_samples + bid
node_pointers[bid] = ind_t + 1


@triton.jit
def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, node_samples, element_samples,
ind_target, ind_n, ind_b, seed, block_size: tl.constexpr, batch_size: tl.constexpr,
num_edges: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr, BLOCK_S: tl.constexpr,
BLOCK_M: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_K: tl.constexpr, TILE_SIZE_K: tl.constexpr,
BLOCK_M: tl.constexpr, M_NUM_BLKS: tl.constexpr, BLOCK_K: tl.constexpr, K_NUM_BLKS: tl.constexpr,
conditional: tl.constexpr):

pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches
@@ -38,13 +50,13 @@ def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, n
# Load node and batch ids
node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0)
batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0)
node_id = tl.load(node_samples + node_sample_id * batch_size)
node_id = tl.load(node_samples + node_sample_id * batch_size + batch_id)

# Locate node ids in `nids`
offs_nids = tl.arange(0, BLOCK_M)
local_nids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1
local_nid_offs = tl.zeros([BLOCK_S], dtype = tl.int64)
for i in range(TILE_SIZE_M):
for i in range(M_NUM_BLKS):
mask_nids = offs_nids < num_nblocks

ref_nid = tl.load(nids + offs_nids, mask = mask_nids, other = 0)
@@ -69,22 +81,22 @@ def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, n
mask_child = offs_child < num_edges

if conditional:
nmars = tl.load(node_mars + node_id, mask = mask_sample, other = 0.0)
nmars = tl.load(node_mars + node_id * batch_size + batch_id, mask = mask_sample, other = 0.0) # [Block_B]

# Main loop over blocks of child nodes
chids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1
for i in range(TILE_SIZE_K):
for i in range(K_NUM_BLKS):

# Load parameters
param_id = tl.load(pids + local_nids[None,:] * num_edges + offs_child[:,None], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0)
epars = tl.load(params + param_id + local_nid_offs[None,:], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0)
epars = tl.load(params + param_id + local_nid_offs[None,:], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0) # [BLOCK_K, BLOCK_B]

if conditional:
# In this case, we use `param * cmar / nmar` as the "parameter"
emars_id = tl.load(cids + local_nids[None,:] * num_edges + offs_child[:,None], mask = (mask_sample[None,:] & mask_child[:,None]), other = 0)
emars = tl.load(params + emars_id, mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0)
emars = tl.load(element_mars + emars_id * batch_size + batch_id, mask = (mask_sample[None,:] & mask_child[:,None]), other = 0.0)

epars = epars * tl.exp(emars - nmars[None,:])
epars = epars * tl.exp(emars - nmars[None,:]) # [BLOCK_K, BLOCK_B]

cum_probs = tl.cumsum(epars, axis = 0) # [BLOCK_K, BLOCK_S]
local_chids = tl.sum((rnd_val[None,:] >= cum_probs).to(tl.int64), axis = 0) # [BLOCK_S]
@@ -117,15 +129,15 @@ def sample_sum_layer(layer, nids, cids, pids, node_mars, element_mars, params, n
BLOCK_M = min(1024 // BLOCK_S, triton.next_power_of_2(num_nblocks))
BLOCK_K = min(1024 // BLOCK_S, triton.next_power_of_2(num_edges))

TILE_SIZE_M = triton.cdiv(num_nblocks, BLOCK_M)
TILE_SIZE_K = triton.cdiv(num_edges, BLOCK_K)
M_NUM_BLKS = triton.cdiv(num_nblocks, BLOCK_M)
K_NUM_BLKS = triton.cdiv(num_edges, BLOCK_K)

grid = (triton.cdiv(num_samples, BLOCK_S),)

sample_sum_layer_kernel[grid](
nids, cids, pids, node_mars, element_mars, params, node_samples, element_samples,
ind_target, ind_n, ind_b, seed, block_size, batch_size, num_edges, num_samples, num_nblocks,
BLOCK_S, BLOCK_M, TILE_SIZE_M, BLOCK_K, TILE_SIZE_K, conditional
BLOCK_S, BLOCK_M, M_NUM_BLKS, BLOCK_K, K_NUM_BLKS, conditional
)

return None
@@ -144,6 +156,153 @@ def push_non_neg_ones_to_front(matrix):
return s_mask.long().sum(dim = 0)


@triton.jit
def count_prod_nch_kernel(nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, partition_id,
block_size: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr,
batch_size: tl.constexpr, num_edges: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_C: tl.constexpr,
BLOCK_S: tl.constexpr, M_NUM_BLKS: tl.constexpr, C_NUM_BLKS: tl.constexpr):

pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches

# Sample offsets and mask
offs_sample = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
mask_sample = offs_sample < num_samples

# Load node and batch ids
node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0)
batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0)
ele_id = tl.load(element_samples + node_sample_id * batch_size + batch_id)

# Locate node ids in `nids`
offs_nids = tl.arange(0, BLOCK_M)
local_nids = tl.zeros([BLOCK_S], dtype = tl.int64) - 1
local_nid_offs = tl.zeros([BLOCK_S], dtype = tl.int64)
for i in range(M_NUM_BLKS):
mask_nids = offs_nids < num_nblocks

ref_nid = tl.load(nids + offs_nids, mask = mask_nids, other = 0)
is_match = (ele_id[:,None] >= ref_nid[None,:]) & (ele_id[:,None] < ref_nid[None,:] + block_size)

match_local_id = tl.sum(is_match * (offs_nids[None,:] + 1), axis = 1)
match_local_offset = tl.sum(is_match * (ele_id[:,None] - ref_nid[None,:]), axis = 1)

local_nids = tl.where(match_local_id > 0, match_local_id - 1, local_nids)
local_nid_offs = tl.where(match_local_id > 0, match_local_offset, local_nid_offs)

offs_nids += BLOCK_M

# Store `local_nids` and `local_nid_offs` for future reuse
mask_sample = mask_sample & (local_nids >= 0)
tl.store(ind_nids + offs_sample, local_nids, mask = mask_sample)
tl.store(ind_nid_offs + offs_sample, local_nid_offs, mask = mask_sample)
tl.store(ind_mask + offs_sample, partition_id, mask = mask_sample)

# Offset for children
offs_child = tl.arange(0, BLOCK_C)
mask_child = offs_child < num_edges

# Main loop over blocks of child nodes
ch_count = tl.zeros([BLOCK_S], dtype = tl.int64)
for i in range(C_NUM_BLKS):

c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0)
ch_count += tl.sum((c_ids > 0).to(tl.int64), axis = 1)

offs_child += BLOCK_C
mask_child = offs_child < num_edges

# Store `ch_count`
tl.store(ind_ch_count + offs_sample, ch_count, mask = mask_sample)


def count_prod_nch(layer, nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, block_size, partition_id):

num_samples = ind_n.size(0)
num_nblocks = nids.size(0)
batch_size = element_samples.size(1)
num_edges = cids.size(1)

BLOCK_C = min(1024, triton.next_power_of_2(num_edges))
BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples))
BLOCK_M = min(1024 // BLOCK_S, triton.next_power_of_2(num_nblocks))

M_NUM_BLKS = triton.cdiv(num_nblocks, BLOCK_M)
C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C)

grid = (triton.cdiv(num_samples, BLOCK_S),)

count_prod_nch_kernel[grid](
nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, partition_id,
block_size, num_samples, num_nblocks, batch_size, num_edges, BLOCK_M, BLOCK_C, BLOCK_S, M_NUM_BLKS, C_NUM_BLKS
)

return None


@triton.jit
def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b,
ind_nids, ind_nid_offs, ind_mask, partition_id, block_size: tl.constexpr,
num_samples: tl.constexpr, num_nblocks: tl.constexpr, batch_size: tl.constexpr, num_edges: tl.constexpr,
BLOCK_S: tl.constexpr, BLOCK_C: tl.constexpr, C_NUM_BLKS: tl.constexpr):

pid_s = tl.program_id(0) # ID of size-`BLOCK_S` batches

# Sample offsets and mask
offs_sample = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
mask_sample = offs_sample < num_samples

# Load node and batch ids
node_sample_id = tl.load(ind_n + offs_sample, mask = mask_sample, other = 0)
batch_id = tl.load(ind_b + offs_sample, mask = mask_sample, other = 0)
ele_id = tl.load(element_samples + node_sample_id * batch_size + batch_id)

# Load offsets of `nids` and the node offsets
local_nids = tl.load(ind_nids + offs_sample, mask = mask_sample, other = 0)
local_nid_offs = tl.load(ind_nid_offs + offs_sample, mask = mask_sample, other = 0)
local_partition_id = tl.load(ind_mask + offs_sample, mask = mask_sample, other = 0)

# Update sample mask
mask_sample = mask_sample & (local_partition_id == partition_id)

# Offset for children
offs_child = tl.arange(0, BLOCK_C)
mask_child = offs_child < num_edges

# Main loop over blocks of child nodes
target_sid = tl.load(ind_target_sid + offs_sample, mask = mask_sample, other = 0)
for i in range(C_NUM_BLKS):

c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0)
target_id = tl.load(ind_target + target_sid[:,None] + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)), other = 0)

tl.store(node_samples + target_id, c_ids + local_nid_offs[:,None], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)))

offs_child += BLOCK_C
mask_child = offs_child < num_edges


def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid,
ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, block_size, partition_id):

num_samples = ind_n.size(0)
num_nblocks = nids.size(0)
num_edges = cids.size(1)
batch_size = node_samples.size(1)

BLOCK_C = min(1024, triton.next_power_of_2(num_edges))
BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples))

C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C)

grid = (triton.cdiv(num_samples, BLOCK_S),)

sample_prod_layer_kernel[grid](
nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b,
ind_nids, ind_nid_offs, ind_mask, partition_id, block_size, num_samples,
num_nblocks, batch_size, num_edges, BLOCK_S, BLOCK_C, C_NUM_BLKS
)


def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bool = False):
if not conditional:
assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling."
@@ -156,18 +315,21 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo
num_nscopes = 0
num_escopes = 0
for layer_group in pc.layers(ret_layer_groups = True):
curr_scopes = 0
for layer in layer_group:
curr_scopes += len(layer.scopes)

if layer_group.is_input() or layer_group.is_sum():
for layer in layer_group:
num_nscopes += len(layer.scopes)
num_nscopes += curr_scopes
else:
assert layer_group.is_prod()
curr_escopes = 0
for layer in layer_group:
curr_escopes += len(layer.scopes)
num_escopes = max(num_escopes, curr_escopes)
num_escopes = max(num_escopes, curr_scopes)

node_samples = torch.zeros([num_nscopes * 2, num_samples], dtype = torch.long, device = pc.device)
# Stores selected node indices by the sampler
node_samples = torch.zeros([num_nscopes, num_samples], dtype = torch.long, device = pc.device)
# Stores selected element indices by the sampler
element_samples = torch.zeros([num_escopes, num_samples], dtype = torch.long, device = pc.device)
# Pointers indicating how many elements are used in each column of `element_samples`
element_pointers = np.zeros([num_samples], dtype = np.int64)

# Initialize pointers to the root node
@@ -221,43 +383,38 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo

# Gather the indices to be processed
lsid, leid = layer._layer_nid_range
mask = (element_samples >= lsid) & (element_samples < leid)
ind_n, ind_b = torch.where(mask)
global_cids = element_samples[ind_n, ind_b]
ind_n, ind_b = torch.where((element_samples >= lsid) & (element_samples < leid))

# Get child indices
# Get the number of children for the selected sample indices
ind_ch_count = torch.zeros_like(ind_n)
ind_nids = torch.zeros_like(ind_n)
ind_nid_offs = torch.zeros_like(ind_n)
ind_mask = torch.zeros_like(ind_n)
for partition_id in range(layer.num_fw_partitions):
nids = layer.partitioned_nids[partition_id]
cids = layer.partitioned_cids[partition_id]

is_match = (global_cids[:,None] >= nids[None,:]) & (global_cids[:,None] < nids[None,:] + layer.block_size)
local_nids = (is_match * torch.arange(1, nids.size(0) + 1, device = pc.device)[None,:] - 1).sum(dim = 1)
local_nid_offsets = ((global_cids[:,None] - nids[None,:]) * is_match.long()).sum(dim = 1)

target_nids = cids[local_nids,:] + local_nid_offsets[:,None]
target_cids = cids[local_nids,:]

target_idx = node_pointers[ind_b] + (torch.cumsum(mask, dim = 0)[ind_n, ind_b] - 1) * cids.size(1)

if target_idx.max() + cids.size(1) > node_samples.size(0):

node_samples_new = torch.zeros([target_idx.max() + cids.size(1), num_samples], dtype = torch.long, device = pc.device)
node_samples_new[:,:] = -1

node_samples_new[:node_samples.size(0),:] = node_samples
node_samples = node_samples_new

match_filter = is_match.any(dim = 1)
target_nids = target_nids[match_filter,:]
target_cids = target_cids[match_filter,:]
target_idx = target_idx[match_filter]
target_b = ind_b[match_filter]
count_prod_nch(layer, nids, cids, element_samples, ind_ch_count, ind_nids,
ind_nid_offs, ind_mask, ind_n, ind_b, layer.block_size, partition_id)

# Pre-compute the target indices in `node_samples`
ind_target_sid = np.zeros([ind_n.size(0)], dtype = np.int64)
ind_target_sid[1:] = ind_ch_count[:-1].cumsum(dim = 0).detach().cpu().numpy()
ind_target = np.zeros([ind_ch_count.sum()], dtype = np.int64)
_assign_nids_ind_target(ind_target, ind_target_sid,
node_pointers.detach().cpu().numpy(),
ind_ch_count.detach().cpu().numpy(),
ind_b.detach().cpu().numpy(), num_samples)
ind_target_sid = torch.from_numpy(ind_target_sid).to(pc.device)
ind_target = torch.from_numpy(ind_target).to(pc.device)

for i in range(cids.size(1)):
cmask = target_cids[:,i] != 0
node_samples[target_idx[cmask]+i, target_b[cmask]] = target_nids[cmask,i]
# Store child indices
for partition_id in range(layer.num_fw_partitions):
nids = layer.partitioned_nids[partition_id]
cids = layer.partitioned_cids[partition_id]

node_pointers = (node_samples != -1).sum(dim = 0)
sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid,
ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, layer.block_size, partition_id)

# Create tensor for the samples
data_dtype = pc.input_layer_group[0].get_data_dtype()
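A note on the main correctness fix in the sum-layer kernel above: `node_mars` and `element_mars` are row-major `[num_nodes, batch_size]` buffers, so a per-sample load must offset by both the node id and the batch id. The old kernel loaded `node_mars + node_id` (and read `params` instead of `element_mars`), dropping the batch offset; the fixed code indexes with `node_id * batch_size + batch_id`. Below is a tiny sketch of that indexing convention in plain PyTorch; the sizes are made up.

import torch

num_nodes, batch_size = 8, 4                    # hypothetical sizes
node_mars = torch.randn(num_nodes, batch_size)  # contiguous, row-major
flat = node_mars.reshape(-1)                    # what the Triton kernel sees through pointer arithmetic

node_id, batch_id = 5, 2
# Corrected load: offset by node *and* batch
assert flat[node_id * batch_size + batch_id] == node_mars[node_id, batch_id]
# The old load `flat[node_id]` lands on a different (node, batch) entry entirely
assert flat[node_id] == node_mars[node_id // batch_size, node_id % batch_size]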
22 changes: 19 additions & 3 deletions tests/queries/sample_test.py
@@ -72,8 +72,8 @@ def test_sample_correctness():

samples = juice.queries.sample(pc, num_samples = 512)

assert samples[:,0].float().mean() < 0.4
assert samples[:,1].float().mean() > 0.6
assert samples[:,0].float().mean() > 0.6
assert samples[:,1].float().mean() < 0.4


def test_sample_hclt():
@@ -118,8 +118,24 @@ def test_sample_hclt():
assert ((samples >= 0) & (samples < 256)).all()


def test_sample_hmm():

device = torch.device("cuda:0")

ns = juice.structures.HMM(seq_length = 32, num_latents = 256, num_emits = 100, homogeneous = True, block_size = 64)
ns.init_parameters(perturbation = 2.0)

pc = juice.compile(ns)
pc.to(device)

samples = juice.queries.sample(pc, num_samples = 16)

assert ((samples >= 0) & (samples < 100)).all()


if __name__ == "__main__":
torch.set_num_threads(4)
test_sample()
test_sample_correctness()
test_sample_hclt()
test_sample_hclt()
test_sample_hmm()
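A closing note on the renaming in this commit: the old `TILE_SIZE_M` / `TILE_SIZE_K` constants were loop trip counts rather than tile sizes, so the new names `M_NUM_BLKS`, `K_NUM_BLKS` (and `C_NUM_BLKS`) describe them more accurately, i.e. the number of `BLOCK_*`-sized chunks needed to cover a dimension. A quick numeric sketch of the launch arithmetic used by `count_prod_nch` above, with made-up problem sizes:

import triton

num_edges, num_samples, num_nblocks = 300, 1000, 12  # hypothetical problem sizes

BLOCK_C = min(1024, triton.next_power_of_2(num_edges))               # 512 child edges per inner block
BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples))  # 2 samples per program
BLOCK_M = min(1024 // BLOCK_S, triton.next_power_of_2(num_nblocks))  # 16 node blocks per search step

M_NUM_BLKS = triton.cdiv(num_nblocks, BLOCK_M)  # 1 pass to scan all node blocks
C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C)    # 1 pass to scan all child edges

grid = (triton.cdiv(num_samples, BLOCK_S),)     # (500,): one program per BLOCK_S samples
print(BLOCK_C, BLOCK_S, BLOCK_M, M_NUM_BLKS, C_NUM_BLKS, grid)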
