speed up input layer tdp computation
liuanji committed Aug 12, 2024
1 parent abc8d3b commit d5ab317
Showing 4 changed files with 68 additions and 13 deletions.
30 changes: 23 additions & 7 deletions src/pyjuice/layer/input_layer.py
@@ -13,6 +13,13 @@
 from triton.runtime.jit import JITFunction
 from copy import deepcopy
 
+# In the latest triton, math functions were shuffled around into different modules:
+# https://github.com/openai/triton/pull/3172
+if hasattr(tl.extra.cuda, "libdevice"):
+    tlmath = tl.extra.cuda.libdevice
+else:
+    tlmath = tl.math
+
 from pyjuice.nodes import InputNodes
 from pyjuice.utils.grad_fns import ReverseGrad
 from pyjuice.utils import BitSet
@@ -548,14 +555,19 @@ def add_missing_flows(self, node_flows: torch.Tensor, logspace_flows: bool = Fal
             node_offset = self._output_ind_range[0]
             layer_num_nodes = self._output_ind_range[1] - self._output_ind_range[0]
 
-            BLOCK_SIZE = 1024
-
-            grid = (triton.cdiv(layer_num_nodes, BLOCK_SIZE),)
-
             if not self.provided("_missing_flows_kernel"):
                 assert self.bk_flow_mask_fn is not None, "The target distribution doesn't have `bk_flow_mask_fn` implemented."
                 self._missing_flows_kernel = self._compile_triton_kernel(self._missing_flows_kernel_template, flow_fn = self.bk_flow_mask_fn)
 
+            if self._need_2nd_kernel_dim():
+                BLOCK_SIZE = 32
+                TILE_SIZE_K = 32
+            else:
+                BLOCK_SIZE = 1024
+                TILE_SIZE_K = 1
+
+            grid = (triton.cdiv(layer_num_nodes, BLOCK_SIZE),)
+
             self._missing_flows_kernel[grid](
                 node_flows_ptr = node_flows,
                 params_ptr = self.params,
@@ -568,7 +580,8 @@ def add_missing_flows(self, node_flows: torch.Tensor, logspace_flows: bool = Fal
                 node_offset = node_offset,
                 layer_num_nodes = layer_num_nodes,
                 logspace_flows = logspace_flows,
-                BLOCK_SIZE = BLOCK_SIZE
+                BLOCK_SIZE = BLOCK_SIZE,
+                TILE_SIZE_K = TILE_SIZE_K
             )
         else:
            raise NotImplementedError("CPU minibatch missing flow fn for input nodes is not implemented.")
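
Note on the new launch configuration: distributions whose backward pass loops over a per-node inner dimension (e.g. `Categorical`'s categories) now get 32-node blocks paired with a 32-wide category tile, while scalar-per-node distributions keep the original 1D 1024-node blocks. A minimal sketch (hypothetical helper name, not part of the commit) of the choice being made above:

def launch_config(need_2nd_dim: bool, layer_num_nodes: int):
    if need_2nd_dim:
        block_size, tile_size_k = 32, 32    # 32 nodes x 32 categories per program
    else:
        block_size, tile_size_k = 1024, 1   # 1024 nodes, no inner tiling
    # Ceiling division over nodes, as triton.cdiv computes
    grid = ((layer_num_nodes + block_size - 1) // block_size,)
    return grid, block_size, tile_size_k

# e.g. launch_config(True, 131072) == ((4096,), 32, 32)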
@@ -707,6 +720,9 @@ def _init_parameters(self, perturbation):
 
             p_start = p_end
 
+    def _need_2nd_kernel_dim(self):
+        return self.nodes[0].dist._need_2nd_kernel_dim()
+
     @staticmethod
     def _mars_kernel_template(mar_fn, params_ptr, node_mars_ptr, data_ptr, vids_ptr, s_pids_ptr, metadata_ptr, s_mids_ptr,
                               fw_local_ids_ptr, partial_eval: tl.constexpr, layer_num_nodes: tl.constexpr, batch_size: tl.constexpr,
@@ -852,7 +868,7 @@ def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr,
     @staticmethod
     def _missing_flows_kernel_template(node_flows_ptr, params_ptr, param_flows_ptr, s_pids_ptr, s_pfids_ptr,
                                        metadata_ptr, s_mids_ptr, scale, node_offset: tl.constexpr, layer_num_nodes: tl.constexpr,
-                                       logspace_flows: tl.constexpr, BLOCK_SIZE: tl.constexpr):
+                                       logspace_flows: tl.constexpr, BLOCK_SIZE: tl.constexpr, TILE_SIZE_K: tl.constexpr):
         pid = tl.program_id(axis = 0)
         block_start = pid * BLOCK_SIZE
 
@@ -1047,7 +1063,7 @@ def parse_source(src, get_signature = False):
     new_src = new_fn_header + "\n" + "\n".join(new_fn_body)
 
     # Add import commands
-    new_src = "import triton\nimport triton.language as tl\n\n" + new_src
+    new_src = "import triton\nimport triton.language as tl\nif hasattr(tl.extra.cuda, 'libdevice'):\n    tlmath = tl.extra.cuda.libdevice\nelse:\n    tlmath = tl.math\n\n" + new_src
 
     # Make a pseudo-function from the source code
     new_fn = make_function_from_src(new_src)
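
For context on why the preamble string matters: the kernel templates are compiled by rewriting their source, prepending this import block, and materializing the result as a live function, so `tlmath` must also be bound inside the generated source. A rough sketch of that last step (an assumption; the real `make_function_from_src` in pyjuice may differ):

import types

def make_function_from_src(src: str):
    namespace = {}
    exec(src, namespace)  # runs the prepended imports plus the generated definition
    fns = [v for v in namespace.values() if isinstance(v, types.FunctionType)]
    return fns[-1]  # the generated kernel template is the last function defined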
18 changes: 12 additions & 6 deletions src/pyjuice/nodes/distributions/categorical.py
@@ -94,21 +94,24 @@ def bk_flow_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr
 
     @staticmethod
     def bk_flow_mask_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr,
-                        s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE):
+                        s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE, TILE_SIZE_K):
         # Get `num_cats` from `metadata`
         s_mids = tl.load(s_mids_ptr + local_offsets, mask = mask, other = 0)
         num_cats = tl.load(metadata_ptr + s_mids, mask = mask, other = 0).to(tl.int64)
 
         max_num_cats = tl.max(num_cats, axis = 0)
+        num_iters = tlmath.ceil(max_num_cats / TILE_SIZE_K).to(tl.int64)
+
+        cat_ids = tl.arange(0, TILE_SIZE_K)
 
-        for cat_id in range(max_num_cats):
-            cat_mask = mask & missing_mask & (cat_id < num_cats)
+        for i in range(num_iters):
+            cat_mask = mask[:,None] & missing_mask[:,None] & (cat_ids[None,:] < num_cats[:,None])
 
-            p_offsets = s_pids + cat_id
+            p_offsets = s_pids[:,None] + cat_ids[None,:]
             param = tl.load(params_ptr + p_offsets, mask = cat_mask, other = 0)
 
-            pf_offsets = s_pfids + cat_id
-            tl.atomic_add(param_flows_ptr + pf_offsets, flows * param, mask = cat_mask)
+            pf_offsets = s_pfids[:,None] + cat_ids[None,:]
+            tl.atomic_add(param_flows_ptr + pf_offsets, flows[:,None] * param, mask = cat_mask)
+
+            # Move on to the next tile of categories
+            cat_ids += TILE_SIZE_K
 
     @staticmethod
     def sample_fn(samples_ptr, local_offsets, batch_offsets, vids, s_pids, params_ptr, metadata_ptr, s_mids_ptr, mask, batch_size, BLOCK_SIZE, seed):
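
What the tiled kernel computes, as a minimal PyTorch sketch (hypothetical 1D tensors; `index_add_` stands in for `tl.atomic_add`): each iteration updates a (num_nodes, TILE_SIZE_K) tile of parameter flows at once, rather than one category per iteration as before.

import torch

def tiled_flow_update(param_flows, flows, params, s_pids, s_pfids, num_cats, tile_size_k = 32):
    cat_ids = torch.arange(tile_size_k)
    num_iters = (int(num_cats.max()) + tile_size_k - 1) // tile_size_k  # ceil division
    for _ in range(num_iters):
        # (num_nodes, tile_size_k) mask: node i owns category j iff j < num_cats[i]
        cat_mask = cat_ids[None, :] < num_cats[:, None]
        p_offsets = (s_pids[:, None] + cat_ids[None, :])[cat_mask]
        pf_offsets = (s_pfids[:, None] + cat_ids[None, :])[cat_mask]
        node_ids = cat_mask.nonzero()[:, 0]
        param_flows.index_add_(0, pf_offsets, flows[node_ids] * params[p_offsets])
        cat_ids = cat_ids + tile_size_k  # advance to the next tile of categories

# Tiny usage example: two nodes with 3 and 5 categories laid out contiguously
tiled_flow_update(param_flows = torch.zeros(8), flows = torch.tensor([1.0, 2.0]),
                  params = torch.rand(8), s_pids = torch.tensor([0, 3]),
                  s_pfids = torch.tensor([0, 3]), num_cats = torch.tensor([3, 5]),
                  tile_size_k = 2)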
@@ -168,3 +171,6 @@ def em_fn(local_offsets, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_
 
     def _get_constructor(self):
         return Categorical, {"num_cats": self.num_cats}
+
+    def _need_2nd_kernel_dim(self):
+        return True
3 changes: 3 additions & 0 deletions src/pyjuice/nodes/distributions/distributions.py
@@ -160,3 +160,6 @@ def em_fn(*args, **kwargs):
 
     def _get_constructor(self):
         raise NotImplementedError()
+
+    def _need_2nd_kernel_dim(self):
+        return False
30 changes: 30 additions & 0 deletions tests/optim/top_down_prob_test.py
@@ -286,6 +286,36 @@ def test_scaled_mini_batch_em():
     assert torch.all(torch.abs(ni3_params - epars) < 1e-5)
 
 
+def test_tdp_speed():
+
+    device = torch.device("cuda:0")
+
+    root_ns = juice.structures.HMM(seq_length = 32, num_latents = 4096, num_emits = 50000, homogeneous = True)
+
+    pc = TensorCircuit(root_ns, layer_sparsity_tol = 0.1)
+    pc.to(device)
+
+    pc.init_param_flows()
+    pc._init_buffer(name = "node_flows", shape = (pc.num_nodes, 1), set_value = 0.0)
+    pc._init_buffer(name = "element_flows", shape = (pc.num_elements, 1), set_value = 0.0)
+
+    eval_top_down_probs(pc, update_pflow = True, scale = 1.0)
+
+    t0 = time.time()
+    for _ in range(100):
+        eval_top_down_probs(pc, update_pflow = True, scale = 1.0)
+    t1 = time.time()
+
+    tdp_ms = (t1 - t0) / 100 * 1000
+
+    print(f"Computing TDP on average takes {tdp_ms:.3f}ms.")
+    print("Reference computation time on RTX 4090: 31.423ms.")
+    print("--------------------------------------------------------------")
+
+
 if __name__ == "__main__":
     torch.set_num_threads(8)
     test_simple_model_tdp()
     test_scaled_mini_batch_em()
+    test_tdp_speed()
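
One caveat when reading the printed averages: `time.time()` around asynchronous CUDA launches only measures correctly if something forces synchronization. A more defensive variant of the timing loop (an assumption, not part of the commit):

import time
import torch

def time_cuda(fn, iters = 100, device = "cuda:0"):
    fn()  # warm-up: trigger any lazy kernel compilation / caching
    torch.cuda.synchronize(device)
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize(device)  # ensure all queued kernels have finished
    return (time.time() - t0) / iters * 1000  # average milliseconds per call

# e.g. time_cuda(lambda: eval_top_down_probs(pc, update_pflow = True, scale = 1.0))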
