From edb67ecc3f77593db13899091786fbfad312a248 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 00:44:41 +0800 Subject: [PATCH 01/53] `MPE` and `GeneralLL` for forward pass (block sparse kernels) --- src/pyjuice/layer/layer.py | 17 +++ src/pyjuice/layer/sum_layer.py | 196 ++++++++++++++++++++------- src/pyjuice/model/tensorcircuit.py | 32 ++++- tests/layer/propagation_algs_test.py | 124 +++++++++++++++++ 4 files changed, 318 insertions(+), 51 deletions(-) create mode 100644 tests/layer/propagation_algs_test.py diff --git a/src/pyjuice/layer/layer.py b/src/pyjuice/layer/layer.py index a8585fbf..00677336 100644 --- a/src/pyjuice/layer/layer.py +++ b/src/pyjuice/layer/layer.py @@ -7,6 +7,13 @@ class Layer(): + + propagation_alg_mapping = { + "LL": 0, + "MPE": 1, + "GeneralLL": 2 + } + def __init__(self, nodes: Sequence[CircuitNodes], disable_block_size_check: bool = False) -> None: if disable_block_size_check: @@ -60,3 +67,13 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True def provided(self, var_name): return hasattr(self, var_name) and getattr(self, var_name) is not None + + def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + return {} + elif propagation_alg == "MPE": + return {} + elif propagation_alg == "GeneralLL": + return {"alpha": kwargs["alpha"]} + else: + raise ValueError(f"Unknown propagation algorithm {propagation_alg}.") diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index f28597ed..7bd58282 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -208,7 +208,8 @@ def num_param_flows(self): return self._layer_pfid_range[1] - self._layer_pfid_range[0] def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Computes the forward pass of a sum layer. @@ -228,7 +229,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t self._forward( node_mars, element_mars, params, nids, cids, pids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) else: @@ -243,7 +245,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t node_mars, element_mars, params, nids, cids, pids, local_ids = local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) return None @@ -344,7 +347,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers. @@ -380,18 +384,19 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, self._forward_block_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, propagation_alg = propagation_alg, **kwargs ) elif mode == self.SPARSE: self._forward_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, - partition_id = partition_id + partition_id = partition_id, propagation_alg = propagation_alg, **kwargs ) elif mode == self.PYTORCH: self._forward_pytorch( - node_mars, element_mars, params, nids, cids, pids, local_ids + node_mars, element_mars, params, nids, cids, pids, local_ids, + propagation_alg = propagation_alg, **kwargs ) else: @@ -403,7 +408,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -453,22 +459,45 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Built-in matmul kernel of triton + float16 - epars_fp16 = (epars * (2**12)).to(tl.float16) - emars_fp16 = emars_sub.to(tl.float16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) else: - # Built-in matmul kernel of triton + float32 - nmars = tl.dot(epars, emars_sub) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Built-in matmul kernel of triton + float16 + epars_fp16 = (epars * (2**12)).to(tl.float16) + emars_fp16 = emars_sub.to(tl.float16) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) + else: + # Built-in matmul kernel of triton + float32 + nmars = tl.dot(epars, emars_sub) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -480,6 +509,10 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -491,7 +524,8 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -541,22 +575,45 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Simulated matmul kernel + float16 - epars = (epars * (2**4)).to(tl.float16) - emars_sub = emars_sub.to(tl.float16) - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) else: - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Simulated matmul kernel + float16 + epars = (epars * (2**4)).to(tl.float16) + emars_sub = emars_sub.to(tl.float16) + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) + else: + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -568,6 +625,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -579,7 +640,8 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -629,16 +691,39 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[:,None]) - emars_max = tl.max(emars, axis = 1) - emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + tl.trans(emars)[None,:,:], axis = 1) - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + acc = tl.maximum(acc, nmars) - acc = tl.where(emars_max[None,:] > acc, - tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], - tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc - ) + else: + + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 1) + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 1) + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp((emars - emars_max[:,None]) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + + acc = tl.where(emars_max[None,:] > acc, + tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], + tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -650,6 +735,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[None,:] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -659,7 +748,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, force_use_fp16: bool = False, - force_use_fp32: bool = False) -> None: + force_use_fp32: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers with the block-sparse processing kernel. @@ -680,6 +769,10 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 128) if base_size >= 64: @@ -751,7 +844,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8: @@ -772,8 +867,11 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) + else: self._fw_triton_block_sparse_csmm2_kernel[grid]( node_mars, @@ -792,7 +890,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 2fa826e3..183a6a8f 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -101,6 +101,10 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5, # CudaGraph options self._recorded_cuda_graphs = dict() + # Mode for forward and backward pass + self.default_propagation_alg = "LL" # Could be "LL", "MPE", or "GeneralLL" + self.propagation_alg_kwargs = dict() + def to(self, device): super(TensorCircuit, self).to(device) @@ -115,10 +119,26 @@ def to(self, device): self.par_update_kwargs = par_update_to_device(self.par_update_kwargs, device) return self + + def set_propagation_alg(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + self.default_propagation_alg = "LL" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "MPE": + self.default_propagation_alg = "MPE" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "GeneralLL": + assert "alpha" in kwargs, "Argument `alpha` should be provided for the `GeneralLL` propagation algorithm." + self.default_propagation_alg = "GeneralLL" + self.propagation_alg_kwargs.clear() + self.propagation_alg_kwargs["alpha"] = kwargs["alpha"] + else: + raise NotImplementedError(f"Unknown propagation algorithm {propagation_alg}.") def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None, cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False, - apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, **kwargs): + apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: Optional[str] = None, **kwargs): """ Forward evaluation of the PC. @@ -135,6 +155,11 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla assert inputs.dim() == 2 and inputs.size(1) == self.num_vars inputs = inputs.permute(1, 0) + + # Set propagation algorithm + if propagation_alg is None: + propagation_alg = self.default_propagation_alg + kwargs.update(self.propagation_alg_kwargs) ## Initialize buffers for forward pass ## @@ -173,8 +198,9 @@ def _run_inner_layers(): elif layer_group.is_sum(): # Sum layer layer_group(self.node_mars, self.element_mars, self.params, - force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32) + force_use_fp16 = force_use_fp16, + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py new file mode 100644 index 00000000..a4113f93 --- /dev/null +++ b/tests/layer/propagation_algs_test.py @@ -0,0 +1,124 @@ +import pyjuice as juice +import torch +import numpy as np +import time +import random + +import pyjuice.nodes.distributions as dists +from pyjuice.utils import BitSet +from pyjuice.nodes import multiply, summate, inputs +from pyjuice.model import TensorCircuit + +from pyjuice.layer import InputLayer, ProdLayer, SumLayer + +import pytest + + +def general_ll_prop_test(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [4, 8, 16]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + alphas = [1.2, 2.0, 3.0] + + for alpha in alphas: + layer(node_mars, element_mars, params, propagation_alg = "GeneralLL", alpha = alpha) + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + +def mpe_prop_test(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [4, 8, 16]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "MPE") + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + +if __name__ == "__main__": + general_ll_prop_test() + mpe_prop_test() From 851bd643cd28c7cff54b8230c5da7f0708846cc2 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 01:41:40 +0800 Subject: [PATCH 02/53] `MPE` and `GeneralLL` for forward pass (sparse kernels) --- src/pyjuice/layer/sum_layer.py | 75 +++++++++++++++++++++------- tests/layer/propagation_algs_test.py | 4 +- 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 7bd58282..c184c71c 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -902,7 +902,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten @FastJITFunction def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, local_ids, batch_size, partial_eval: tl.constexpr, num_edges: tl.constexpr, - BLOCK_B: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(axis = 0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(axis = 1) # ID of size-`BLOCK_SIZE_M` nodes @@ -930,9 +930,15 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, offs_batch[None,:] emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [num_edges, BLOCK_B] - # Compute max and subtract - emars_max = tl.max(emars, axis = 0) - emars = tl.exp(emars - emars_max[None,:]) + # Compute max and subtract (only when using LL or GeneralLL propagation method) + if propagation_alg_id == 0: + emars_max = tl.max(emars, axis = 0) + emars = tl.exp(emars - emars_max[None,:]) + + if propagation_alg_id == 2: + emars_max = tl.max(emars, axis = 0) + emars = tl.exp((emars - emars_max[None,:]) * alpha) + emars_max *= alpha # Initialize pointers to `node_mars` off_nids = tl.load(nids + nblock_id) @@ -944,7 +950,16 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, for i in range(0, BLOCK_SIZE_M): epars = tl.load(epars_ptr) - nmars = tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max + if propagation_alg_id == 0: + nmars = tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max + + if propagation_alg_id == 1: + nmars = tl.max(emars + tl.log(epars)[:,None], axis = 0) + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + nmars = (tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max) * (1.0 / alpha) tl.store(nmars_ptr, nmars, mask = mask_batch) @@ -957,10 +972,9 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, @staticmethod # @triton.jit @FastJITFunction - def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, - local_ids, batch_size, num_nodes, pid_m_offset, partial_eval: tl.constexpr, num_edges: tl.constexpr, - BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr): + def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, local_ids, batch_size, + num_nodes, pid_m_offset, partial_eval: tl.constexpr, num_edges: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(axis = 0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(axis = 1) + pid_m_offset # ID of size-`TILE_SIZE_M` nodes @@ -992,10 +1006,27 @@ def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, offs_batch[None,None,:] # [TILE_SIZE_M, num_edges, BLOCK_B] emars = tl.load(emars_ptr, mask = (mask_m[:,None,None] & mask_batch[None,None,:]), other = 0.0) # [TILE_SIZE_M, num_edges, BLOCK_B] - # Compute max and subtract - emars_max = tl.max(emars, axis = 1) - emars = tl.exp(emars - emars_max[:,None,:]) - nmars = tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max + # Compute max and subtract (only when using LL or GeneralLL propagation method) + if propagation_alg_id == 0: + emars_max = tl.max(emars, axis = 1) + emars = tl.exp(emars - emars_max[:,None,:]) + + if propagation_alg_id == 2: + emars_max = tl.max(emars, axis = 1) + emars = tl.exp((emars - emars_max[:,None,:]) * alpha) + emars_max *= alpha + + # Compute sum node marginals + if propagation_alg_id == 0: + nmars = tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max + + if propagation_alg_id == 1: + nmars = tl.max(emars + tl.log(epars)[:,:,None], axis = 1) + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + nmars = (tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max) * (1.0 / alpha) # Initialize pointers to `node_mars` off_nids = tl.load(nids + nblock_ids) # [TILE_SIZE_M] @@ -1008,7 +1039,7 @@ def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1) -> None: + partition_id: int = -1, propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers with the sparse processing kernel. @@ -1027,7 +1058,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) - # assert num_edges <= 16384, "The sparse forward kernel only support nodes with # edges smaller than 16384." + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) if triton.cdiv(layer_n_nodes, self.block_size) <= 2048: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -1049,7 +1082,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, partial_eval = partial_eval, num_edges = num_edges, BLOCK_B = BLOCK_B, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -1077,7 +1112,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, num_edges = num_edges, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: for pid_m_start in range(0, grid[1], 32768): @@ -1100,7 +1137,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, num_edges = num_edges, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py index a4113f93..cda8d94f 100644 --- a/tests/layer/propagation_algs_test.py +++ b/tests/layer/propagation_algs_test.py @@ -20,7 +20,7 @@ def general_ll_prop_test(): batch_size = 16 - for block_size in [4, 8, 16]: + for block_size in [1, 4, 8, 16]: with juice.set_block_size(block_size): @@ -74,7 +74,7 @@ def mpe_prop_test(): batch_size = 16 - for block_size in [4, 8, 16]: + for block_size in [1, 4, 8, 16]: with juice.set_block_size(block_size): From 5f48d1245000b099681616bcd807d014806b0f0a Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 01:51:55 +0800 Subject: [PATCH 03/53] `MPE` and `GeneralLL` for forward pass (pytorch kernels) --- src/pyjuice/layer/sum_layer.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index c184c71c..de64ce8c 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1148,7 +1148,7 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, @torch.compile def _forward_pytorch_kernel(node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, - local_ids: torch.Tensor): + local_ids: torch.Tensor, propagation_alg_id: int, alpha: float = 0.0): if local_ids is not None: nids = nids[local_ids] @@ -1164,18 +1164,33 @@ def _forward_pytorch_kernel(node_mars: torch.Tensor, element_mars: torch.Tensor, torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges) ch_mars = element_mars[cids] - maxval = ch_mars.max(dim = 1, keepdim = True).values - node_mars[nids] = (((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)).sum( - dim = 1).clamp(min = 1e-10)).log() + maxval.squeeze(1) + + if propagation_alg_id == 0: + maxval = ch_mars.max(dim = 1, keepdim = True).values + node_mars[nids] = (((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)).sum( + dim = 1).clamp(min = 1e-10)).log() + maxval.squeeze(1) + + elif propagation_alg_id == 1: + node_mars[nids] = (ch_mars + params[pids].log().unsqueeze(-1)).max(dim = 1).values + + elif propagation_alg_id == 2: + maxval = ch_mars.max(dim = 1, keepdim = True).values + node_mars[nids] = ((((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)) ** alpha).sum( + dim = 1).clamp(min = 1e-10)).log() ** (1.0 / alpha) + maxval.squeeze(1) return None def _forward_pytorch(node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, - local_ids: torch.Tensor): + local_ids: torch.Tensor, propagation_alg: str = "LL", **kwargs): + + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) self._forward_pytorch_kernel( - node_mars, element_mars, params, nids, cids, pids, local_ids + node_mars, element_mars, params, nids, cids, pids, local_ids, + propagation_alg_id = propagation_alg_id, **propagation_alg_kwargs ) def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, From 4d6674d3fc6979d27c38862a5233a3160969d96e Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 04:10:29 +0800 Subject: [PATCH 04/53] `MPE` and `GeneralLL` for backward pass (block sparse kernels) --- src/pyjuice/layer/sum_layer.py | 405 +++++++++++++++++++-------- src/pyjuice/model/tensorcircuit.py | 5 +- tests/layer/propagation_algs_test.py | 98 ++++++- 3 files changed, 394 insertions(+), 114 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index de64ce8c..4db2f049 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -254,7 +254,7 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, param_flows: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Computes the forward pass of a sum layer: ``` @@ -286,7 +286,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, nids = self.partitioned_nids[partition_id] self._bk_triton_modify_flow( - node_flows, node_mars, nids, local_ids = None + node_flows, node_mars, nids, local_ids = None, + propagation_alg = propagation_alg, **kwargs ) ## Compute flows w.r.t. elements (i.e., product nodes) ## @@ -304,7 +305,9 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + **kwargs ) else: @@ -322,7 +325,9 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + **kwargs ) ## Compute flows w.r.t. sum parameters ## @@ -338,7 +343,9 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + **kwargs ) return None @@ -1202,7 +1209,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, parpids: Optional[torch.Tensor] = None, cs_block_size: int = 0, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Back pass of sum layers. @@ -1246,14 +1254,16 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, self._backward_block_sparse( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) elif mode == self.SPARSE: self._backward_sparse( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) elif mode == self.PYTORCH: @@ -1262,7 +1272,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, self._backward_pytorch( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, - chids, parids, parpids, cs_block_size + chids, parids, parpids, cs_block_size, + propagation_alg = propagation_alg, **kwargs ) else: raise ValueError(f"Not supported mode `{mode}`.") @@ -1273,7 +1284,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, # @triton.jit @FastJITFunction def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_size: tl.constexpr, partial_eval: tl.constexpr, - BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` examples pid_m = tl.program_id(1) # ID of size-`BLOCK_M` nodes @@ -1298,7 +1309,14 @@ def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_ nmars = tl.load(node_mars + offs_nmfs, mask = mask_batch[None,:]) nflows = tl.load(node_flows + offs_nmfs, mask = mask_batch[None,:]) - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + if propagation_alg_id == 0: + uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + + if propagation_alg_id == 1: + uflows = nflows + + if propagation_alg_id == 2: + uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = mask_batch[None,:]) @@ -1306,7 +1324,7 @@ def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_ # @triton.jit @FastJITFunction def _bk_triton_large_modify_flow_kernel(node_flows, node_mars, local_ids, nids, num_nodes, batch_size: tl.constexpr, partial_eval: tl.constexpr, - BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` examples pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1333,12 +1351,20 @@ def _bk_triton_large_modify_flow_kernel(node_flows, node_mars, local_ids, nids, nmars = tl.load(node_mars + offs_nmfs, mask = (mask_m[:,None] & mask_batch[None,:])) nflows = tl.load(node_flows + offs_nmfs, mask = (mask_m[:,None] & mask_batch[None,:])) - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + if propagation_alg_id == 0: + uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + + if propagation_alg_id == 1: + uflows = nflows + + if propagation_alg_id == 2: + uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = (mask_m[:,None] & mask_batch[None,:])) def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tensor, - nids: torch.Tensor, local_ids: Optional[torch.Tensor] = None): + nids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, + propagation_alg_id: str = "LL", **kwargs): """ Replace `node_flows[nids]` with `node_flows[nids].log() - node_mars[nids]` """ @@ -1348,6 +1374,10 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + if triton.cdiv(layer_n_nodes, self.block_size) <= 4096: if BATCH_SIZE_NP2 >= 64 and self.block_size >= 64: @@ -1371,7 +1401,9 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens partial_eval = partial_eval, BLOCK_B = BLOCK_B, BLOCK_M = BLOCK_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **kwargs ) else: @@ -1394,7 +1426,9 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens partial_eval = partial_eval, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **kwargs ) return None @@ -1405,7 +1439,7 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Back pass of sum layers with block-sparse processing kernel. @@ -1427,7 +1461,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. node_flows, element_flows, params, node_mars, element_mars, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) # Flows w.r.t. parameters @@ -1435,7 +1470,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. self._backward_block_sparse_par_flows( node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) return None @@ -1448,7 +1484,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr): + BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1488,40 +1524,66 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele parids_inc_ptr = parids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid + # Initialize pointers to `element_mars` (only when using MPE propagation method) + if propagation_alg_id == 1: + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + if propagation_alg_id != 1: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, TILE_SIZE_M] - log_n_fdm_max = tl.max(log_n_fdm, axis = 0)[None,:] - n_fdm_sub = tl.where(log_n_fdm_max != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max), 0.0) + acc += tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) - if TL_DOT == 1: - partial_flows = tl.dot(epars, n_fdm_sub) else: - partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) - acc = tl.where(log_n_fdm_max == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 0)[None,:] + n_fdm_sub = tl.where(log_n_fdm_max != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max), 0.0) + + if TL_DOT == 1: + partial_flows = tl.dot(epars, n_fdm_sub) + else: + partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) + + acc = tl.where(log_n_fdm_max == acc, + acc + 0.69314718056, # log(2) + tl.where(log_n_fdm_max > acc, + tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, + tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc + ) ) - ) - # neginf_flag = (log_n_fdm_max == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - # tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + # neginf_flag = (log_n_fdm_max == -float("inf")) & (acc == -float("inf")) + # acc = tl.where(log_n_fdm_max > acc, + # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, + # tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc + # ) + # acc = tl.where(neginf_flag, -float("inf"), acc) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1535,12 +1597,19 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele nflows_ptr += parids_inc[:,None] * batch_size parids_inc_ptr += ptr_inc_step - # Initialize pointers to `element_mars` - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + # Initialize pointers to `element_mars` (only when NOT using MPE propagation method) + if propagation_alg_id != 1: + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] - eflows = tl.exp(acc + emars) + if propagation_alg_id == 2: + emars *= alpha + + if propagation_alg_id != 1: + eflows = tl.exp(acc + emars) + else: + eflows = acc # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] @@ -1554,7 +1623,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr): + BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1594,37 +1663,65 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar parids_inc_ptr = parids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid + # Initialize pointers to `element_mars` + if propagation_alg_id == 1: + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + if propagation_alg_id != 1: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, TILE_SIZE_M] + + eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1) + + acc += tl.trans(eflows) + + else: - log_n_fdm_max = tl.max(log_n_fdm, axis = 1) - n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) - partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - acc = tl.where(log_n_fdm_max[None,:] == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max[None,:] > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 1) + n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + + partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) + + acc = tl.where(log_n_fdm_max[None,:] == acc, + acc + 0.69314718056, # log(2) + tl.where(log_n_fdm_max[None,:] > acc, + tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], + tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc + ) ) - ) - # neginf_flag = (log_n_fdm_max[None,:] == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max[None,:] > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - # tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + # neginf_flag = (log_n_fdm_max[None,:] == -float("inf")) & (acc == -float("inf")) + # acc = tl.where(log_n_fdm_max[None,:] > acc, + # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], + # tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc + # ) + # acc = tl.where(neginf_flag, -float("inf"), acc) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1639,11 +1736,18 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar parids_inc_ptr += ptr_inc_step # Initialize pointers to `element_mars` - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + if propagation_alg_id != 1: + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + + if propagation_alg_id == 2: + emars *= alpha - eflows = tl.exp(acc + emars) + if propagation_alg_id != 1: + eflows = tl.exp(acc + emars) + else: + eflows = acc # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] @@ -1653,7 +1757,8 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." @@ -1663,6 +1768,10 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo batch_size = node_flows.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 64) if base_size >= 64: @@ -1752,7 +1861,9 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, num_warps = 2, # TODO: test for different devices - num_stages = 1 + num_stages = 1, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: self._bk_triton_block_sparse_ele_csmm2_kernel[grid]( @@ -1779,7 +1890,9 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, num_warps = 2, # TODO: test for different devices - num_stages = 1 + num_stages = 1, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -1790,7 +1903,8 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1816,30 +1930,53 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para nmars_ptr = node_mars + (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] nflows_ptr = node_flows + (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] + # Initialize `params` (only when using MPE propagation method) + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32) for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - log_n_fdm_max = tl.max(log_n_fdm, axis = 0) - n_fdm_sub = tl.where(log_n_fdm_max[None,:] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[None,:]), 0.0) - scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - tl.trans(nmars)[:,:,None]) < 1e-6, tl.trans(nflows)[:,:,None], 0.0), axis = 0) - if TL_DOT == 1: - partial_flows = tl.dot(n_fdm_sub, scaled_emars) else: - partial_flows = tl.sum(n_fdm_sub[:,:,None] * scaled_emars[None,:,:], axis = 1) - acc += partial_flows + if propagation_alg_id == 2: + emars *= alpha + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 0) + n_fdm_sub = tl.where(log_n_fdm_max[None,:] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[None,:]), 0.0) + + scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + + if TL_DOT == 1: + partial_flows = tl.dot(n_fdm_sub, scaled_emars) + else: + partial_flows = tl.sum(n_fdm_sub[:,:,None] * scaled_emars[None,:,:], axis = 1) + + acc += partial_flows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += TILE_SIZE_B @@ -1851,12 +1988,19 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para offs_batch += TILE_SIZE_B mask_batch = offs_batch < batch_size - # Initialize `params` - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - epars = tl.load(params + epars_offsets) + # Initialize `params` (only when NOT using MPE propagation method) + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) - pflows = acc * epars + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + if propagation_alg_id != 1: + pflows = acc * epars + else: + pflows = acc parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] @@ -1869,7 +2013,8 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1895,27 +2040,50 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars nmars_ptr = node_mars + (off_nids + offs_node[None,:]) * batch_size + offs_batch[:,None] nflows_ptr = node_flows + (off_nids + offs_node[None,:]) * batch_size + offs_batch[:,None] + # Initialize `params` (only when using MPE propagation method) + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32) for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - log_n_fdm_max = tl.max(log_n_fdm, axis = 1) - n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 0) + + else: + + if propagation_alg_id == 2: + emars *= alpha + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] + + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) - scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + log_n_fdm_max = tl.max(log_n_fdm, axis = 1) + n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) - partial_flows = tl.sum(tl.trans(n_fdm_sub)[:,:,None] * scaled_emars[None,:,:], axis = 1) + scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) - acc += partial_flows + partial_flows = tl.sum(tl.trans(n_fdm_sub)[:,:,None] * scaled_emars[None,:,:], axis = 1) + + acc += partial_flows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += TILE_SIZE_B @@ -1927,12 +2095,19 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars offs_batch += TILE_SIZE_B mask_batch = offs_batch < batch_size - # Initialize `params` - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - epars = tl.load(params + epars_offsets) + # Initialize `params` (only when NOT using MPE propagation method) + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) - pflows = acc * epars + if propagation_alg_id != 1: + pflows = acc * epars + else: + pflows = acc parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] @@ -1942,7 +2117,7 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -1966,6 +2141,10 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2) if base_size >= 64: @@ -2009,7 +2188,9 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor TILE_SIZE_K = TILE_SIZE_K, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = self.block_size, - TL_DOT = TL_DOT + TL_DOT = TL_DOT, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: self._bk_triton_block_sparse_par_csmm2_kernel[grid]( @@ -2030,7 +2211,9 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor TILE_SIZE_K = TILE_SIZE_K, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = self.block_size, - TL_DOT = TL_DOT + TL_DOT = TL_DOT, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor, diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 183a6a8f..a9e000f6 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -28,6 +28,7 @@ def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, flows_memory = pc._optim_hyperparams["flows_memory"], record_cudagraph = record_cudagraph, apply_cudagraph = apply_cudagraph, + propagation_alg = propagation_alg, **kwargs ) @@ -252,6 +253,7 @@ def _run_inner_layers(): inputs = inputs, record_cudagraph = record_cudagraph, apply_cudagraph = apply_cudagraph, + propagation_alg = propagation_alg, **kwargs ) ) @@ -271,6 +273,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None, record_cudagraph: bool = False, apply_cudagraph: bool = True, allow_modify_flows: bool = True, + propagation_alg: str = "LL", **kwargs): """ Backward evaluation of the PC that computes node flows as well as parameter flows. @@ -345,7 +348,7 @@ def _run_inner_layers(): # Backward sum layer layer_group.backward(self.node_flows, self.element_flows, self.node_mars, self.element_mars, self.params, param_flows = self.param_flows if compute_param_flows else None, - allow_modify_flows = allow_modify_flows) + allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py index cda8d94f..7928a3ac 100644 --- a/tests/layer/propagation_algs_test.py +++ b/tests/layer/propagation_algs_test.py @@ -20,7 +20,7 @@ def general_ll_prop_test(): batch_size = 16 - for block_size in [1, 4, 8, 16]: + for block_size in [4, 8, 16]: with juice.set_block_size(block_size): @@ -67,6 +67,53 @@ def general_ll_prop_test(): scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + propagation_alg = "GeneralLL", alpha = alpha) + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (epars[:,None] * emars[None,:]) ** alpha / nmars ** alpha).sum(dim = 0) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = node_flows[(j+1)*block_size+i,:] + pflows = epars ** alpha * (nflows[None,:] * emars ** alpha / nmars[None,:] ** alpha).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + def mpe_prop_test(): @@ -74,7 +121,7 @@ def mpe_prop_test(): batch_size = 16 - for block_size in [1, 4, 8, 16]: + for block_size in [4, 8, 16]: with juice.set_block_size(block_size): @@ -118,6 +165,53 @@ def mpe_prop_test(): scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = False, propagation_alg = "MPE") + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (((epars[:,None] * emars[None,:]) - nmars).abs() < 1e-6).float()).sum(dim = 0) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = node_flows[(j+1)*block_size+i,:] + pflows = (nflows[None,:] * ((epars[:,None] * emars - nmars[None,:]).abs() < 1e-6).float()).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + if __name__ == "__main__": general_ll_prop_test() From 5da7471dd2ef856dc6c9d5ee6ef37054fa56fb04 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 18:29:56 +0800 Subject: [PATCH 05/53] `MPE` and `GeneralLL` backward pass for sparse kernels + triton bug fix of inaccurate `tr.log(a) - b` --- src/pyjuice/layer/layer.py | 4 +- src/pyjuice/layer/sum_layer.py | 319 ++++++++++------------ src/pyjuice/model/tensorcircuit.py | 2 +- tests/layer/propagation_algs_test.py | 380 ++++++++++++++++++--------- 4 files changed, 402 insertions(+), 303 deletions(-) diff --git a/src/pyjuice/layer/layer.py b/src/pyjuice/layer/layer.py index 00677336..f218cb12 100644 --- a/src/pyjuice/layer/layer.py +++ b/src/pyjuice/layer/layer.py @@ -70,9 +70,9 @@ def provided(self, var_name): def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs): if propagation_alg == "LL": - return {} + return {"alpha": 0.0} elif propagation_alg == "MPE": - return {} + return {"alpha": 0.0} elif propagation_alg == "GeneralLL": return {"alpha": kwargs["alpha"]} else: diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 4db2f049..41cf391d 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1077,22 +1077,25 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, grid = (triton.cdiv(batch_size, BLOCK_B), triton.cdiv(layer_n_nodes, BLOCK_SIZE_M)) - self._fw_triton_sparse_kernel[grid]( - node_mars = node_mars, - element_mars = element_mars, - params = params, - nids = nids, - cids = cids, - pids = pids, - local_ids = local_ids, - batch_size = batch_size, - partial_eval = partial_eval, - num_edges = num_edges, - BLOCK_B = BLOCK_B, - BLOCK_SIZE_M = BLOCK_SIZE_M, - propagation_alg_id = propagation_alg_id, - **propagation_alg_kwargs - ) + try: + self._fw_triton_sparse_kernel[grid]( + node_mars = node_mars, + element_mars = element_mars, + params = params, + nids = nids, + cids = cids, + pids = pids, + local_ids = local_ids, + batch_size = batch_size, + partial_eval = partial_eval, + num_edges = num_edges, + BLOCK_B = BLOCK_B, + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs + ) + except TypeError: + import pdb; pdb.set_trace() else: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -1310,13 +1313,15 @@ def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_ nflows = tl.load(node_flows + offs_nmfs, mask = mask_batch[None,:]) if propagation_alg_id == 0: - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars, -float("inf")) if propagation_alg_id == 1: uflows = nflows if propagation_alg_id == 2: - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars * alpha, -float("inf")) + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = mask_batch[None,:]) @@ -1352,19 +1357,21 @@ def _bk_triton_large_modify_flow_kernel(node_flows, node_mars, local_ids, nids, nflows = tl.load(node_flows + offs_nmfs, mask = (mask_m[:,None] & mask_batch[None,:])) if propagation_alg_id == 0: - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars, -float("inf")) if propagation_alg_id == 1: uflows = nflows if propagation_alg_id == 2: - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars * alpha, -float("inf")) + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = (mask_m[:,None] & mask_batch[None,:])) def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tensor, nids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, - propagation_alg_id: str = "LL", **kwargs): + propagation_alg: str = "LL", **kwargs): """ Replace `node_flows[nids]` with `node_flows[nids].log() - node_mars[nids]` """ @@ -2222,7 +2229,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Back pass of sum layers with sparse processing kernel. @@ -2244,7 +2252,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor node_flows, element_flows, params, node_mars, element_mars, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) # Flows w.r.t. parameters @@ -2252,7 +2261,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor self._backward_sparse_par_flows( node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, **kwargs ) return None @@ -2263,7 +2273,7 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids, parpids, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`BLOCK_M` nodes @@ -2310,10 +2320,30 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m epars = tl.load(epars_ptr) # [num_edges] emars = tl.load(emars_ptr, mask = mask_batch) # [BLOCK_B] - if allow_modify_flows == 1: - eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] + log_n_fdm), axis = 0) + if propagation_alg_id == 1: + if allow_modify_flows: + nflows = log_n_fdm + + lpars = tl.log(epars) + eflows = tl.sum(tl.where(tl.abs(lpars[:,None] + emars[None,:] - nmars) < 1e-6, nflows, 0.0), axis = 0) + else: - eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + if propagation_alg_id == 2: + lpars = tl.log(epars) + epars = tl.exp(lpars * alpha) + + if allow_modify_flows == 1: + if propagation_alg_id == 0: + eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] + log_n_fdm), axis = 0) + + if propagation_alg_id == 2: + eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] * alpha + log_n_fdm), axis = 0) + else: + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,None] * tl.exp((emars[None,:] - nmars) * alpha), axis = 0) tl.store(eflows_ptr, eflows, mask = mask_batch) @@ -2331,7 +2361,8 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele chids, parids, parpids, local_ids, num_eles, pid_m_offset, batch_size: tl.constexpr, partial_eval: tl.constexpr, n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) + pid_m_offset # ID of size-`TILE_SIZE_M` nodes @@ -2379,10 +2410,32 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele mask = (mask_m[:,None] & mask_batch[None,:])) # [TILE_SIZE_M, BLOCK_B] # Compute eflows - if allow_modify_flows == 1: - eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] + log_n_fdm), axis = 1) + if propagation_alg_id == 1: + if allow_modify_flows: + nflows = log_n_fdm + + lpars = tl.log(epars) + eflows = tl.sum(tl.where(tl.abs(lpars[:,:,None] + emars[:,None,:] - nmars) < 1e-6, nflows, 0.0), axis = 1) + else: - eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + + if propagation_alg_id == 2: + lpars = tl.log(epars) + epars = tl.exp(lpars * alpha) + + if allow_modify_flows == 1: + if propagation_alg_id == 0: + eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] + log_n_fdm), axis = 1) + + if propagation_alg_id == 2: + eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] * alpha + log_n_fdm), axis = 1) + else: + + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp((emars[:,None,:] - nmars) * alpha), axis = 0) tl.store(eflows_ptr, eflows, mask = (mask_m[:,None] & mask_batch[None,:])) @@ -2390,7 +2443,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." @@ -2401,6 +2454,10 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to batch_size = node_flows.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384." if triton.cdiv(layer_n_nodes, cs_block_size) <= 32768: @@ -2428,7 +2485,9 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to allow_modify_flows = allow_modify_flows, BLOCK_B = BLOCK_B, BLOCK_M = BLOCK_M, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -2460,7 +2519,9 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -2488,7 +2549,9 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -2499,7 +2562,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, pid_m_offset, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr): + TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples pid_e = tl.program_id(1) # ID of size-`BLOCK_K` edges @@ -2525,6 +2588,12 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa nmars_ptr = node_mars + (off_nids + tile_id) * batch_size + offs_batch # [BLOCK_B] nflows_ptr = node_flows + (off_nids + tile_id) * batch_size + offs_batch # [BLOCK_B] + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_ptr = params + par_start + tile_id + epars = tl.load(epars_ptr) # [BLOCK_K] + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([BLOCK_K], dtype = tl.float32) @@ -2535,107 +2604,59 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa emars = tl.load(emars_ptr, mask = mask_batch[None,:], other = -float("inf")) # [BLOCK_K, BLOCK_B] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] - pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) - else: + if propagation_alg_id == 1: nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] - pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) - - acc += pflows - - # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` - emars_ptr += BLOCK_B - nmars_ptr += BLOCK_B - nflows_ptr += BLOCK_B - - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_ptr = params + par_start + tile_id - epars = tl.load(epars_ptr) # [BLOCK_K] - - parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) - eparflows_ptr = param_flows + parflow_start + tile_id - - curr_pflows = acc * epars - - tl.atomic_add(eparflows_ptr, curr_pflows) - - @staticmethod - # @triton.jit - @FastJITFunction - def _bk_triton_large_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, - num_nodes, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): - - pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples - pid_e = tl.program_id(1) # ID of size-`BLOCK_K` edges - pid_m = tl.program_id(2) # ID of size-`TILE_SIZE_M` nodes - - offs_m = tl.arange(0, TILE_SIZE_M) + pid_m * TILE_SIZE_M - mask_m = offs_m < num_nodes - - # Get inferred node block id from `pid_m` - nblock_ids = offs_m // BLOCK_SIZE_M - tile_ids = offs_m % BLOCK_SIZE_M - - # Batch offsets and mask - offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B - mask_batch = offs_batch < batch_size - # Initialize pointers to `element_mars` - offs_edge = tl.arange(0, BLOCK_K) + pid_e * BLOCK_K - edge_start = tl.load(cids + nblock_ids[:,None] * num_edges + offs_edge[None,:], mask = mask_m[:,None]) - emars_ptr = element_mars + \ - edge_start[:,:,None] * batch_size + \ - offs_batch[None,None,:] # [TILE_SIZE_M, BLOCK_K, BLOCK_B] + acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, nflows[None,:], 0.0), axis = 1) - # Initialize pointers to `node_flows` and `node_mars` - off_nids = tl.load(nids + nblock_ids, mask = mask_m) - nmars_ptr = node_mars + (off_nids + tile_ids)[:,None] * batch_size + offs_batch[None,:] # [TILE_SIZE_M, BLOCK_B] - nflows_ptr = node_flows + (off_nids + tile_ids)[:,None] * batch_size + offs_batch[None,:] # [TILE_SIZE_M, BLOCK_B] + else: - # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_K], dtype = tl.float32) + 0.1 + if propagation_alg_id == 2: + emars *= alpha - for b in range(0, B_NUM_BLOCKS): - # Batch offsets and mask - offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B + b * BLOCK_B - mask_batch = offs_batch < batch_size + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] + pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) + else: + nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] + nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] - emars = tl.load(emars_ptr, mask = (mask_m[:,None,None] & mask_batch[None,None,:]), other = -float("inf")) # [TILE_SIZE_M, BLOCK_K, BLOCK_B] + if propagation_alg_id == 0: + pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = -float("inf")) # [TILE_SIZE_M, BLOCK_B] - pflows = tl.sum(tl.exp(emars + log_n_fdm[:,None,:]), axis = 2) - else: - nmars = tl.load(nmars_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = 0.0) # [TILE_SIZE_M, BLOCK_B] - nflows = tl.load(nflows_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = 0.0) # [TILE_SIZE_M, BLOCK_B] - pflows = tl.sum(nflows[:,None,:] * tl.exp(emars - nmars[:,None,:]), axis = 2) + if propagation_alg_id == 2: + pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:] * alpha), axis = 1) - acc += pflows + acc += pflows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += BLOCK_B nmars_ptr += BLOCK_B nflows_ptr += BLOCK_B - par_start = tl.load(pids + nblock_ids[:,None] * num_edges + offs_edge[None,:]) - epars_ptr = params + par_start + tile_ids[:,None] - epars = tl.load(epars_ptr, mask = mask_m[:,None]) # [TILE_SIZE_M, BLOCK_K] + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_ptr = params + par_start + tile_id + epars = tl.load(epars_ptr) # [BLOCK_K] + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) - parflow_start = tl.load(pfids + nblock_ids[:,None] * num_edges + offs_edge[None,:]) - eparflows_ptr = param_flows + parflow_start + tile_ids[:,None] + parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) + eparflows_ptr = param_flows + parflow_start + tile_id - curr_pflows = acc * epars + if propagation_alg_id != 1: + curr_pflows = acc * epars + else: + curr_pflows = acc - tl.atomic_add(eparflows_ptr, curr_pflows, mask = mask_m[:,None]) + tl.atomic_add(eparflows_ptr, curr_pflows) def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2659,7 +2680,9 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) - # assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384." + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) if num_edges <= 1024: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -2705,7 +2728,9 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, TILE_SIZE_B = TILE_SIZE_B, - B_NUM_BLOCKS = B_NUM_BLOCKS + B_NUM_BLOCKS = B_NUM_BLOCKS, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -2733,63 +2758,11 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, TILE_SIZE_B = TILE_SIZE_B, - B_NUM_BLOCKS = B_NUM_BLOCKS + B_NUM_BLOCKS = B_NUM_BLOCKS, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) - # else: - - # if num_edges <= 1024: - # BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) - # BLOCK_K = num_edges - # TILE_SIZE_M = max(min(4096 // num_edges, triton.next_power_of_2(layer_n_nodes)), 1) - # else: - # BLOCK_B = min(512, BATCH_SIZE_NP2) - # BLOCK_K = min(2048 // BLOCK_B, num_edges) - # TILE_SIZE_M = max(min(2048 // num_edges, triton.next_power_of_2(layer_n_nodes)), 1) - # B_NUM_BLOCKS = triton.cdiv(batch_size, BLOCK_B) - # K_NUM_BLOCKS = triton.cdiv(num_edges, BLOCK_K) - - # # When a thread-block is allocated for too much work, the overhead - # # outweigh that incurred by `atomic_add`. Add more thread-blocks - # # for parallel processing in this case. - # if B_NUM_BLOCKS >= 4: - # TILE_SIZE_B = 4 * BLOCK_B - # B_NUM_BLOCKS = 4 - # else: - # TILE_SIZE_B = batch_size - # B_NUM_TILES = triton.cdiv(batch_size, TILE_SIZE_B) - - # allow_modify_flows = 1 if allow_modify_flows else 0 - - # grid = (B_NUM_TILES, K_NUM_BLOCKS, triton.cdiv(layer_n_nodes, TILE_SIZE_M)) - - # print(">>>G", grid, "in") - # # TODO: This kernel gets stuck for some input configurations. Fix it. - # if grid[0] == 2 and grid[1] == 1 and grid[2] == 308: - # import pdb; pdb.set_trace() - # self._bk_triton_large_sparse_par_kernel[grid]( - # node_flows = node_flows, - # node_mars = node_mars, - # element_mars = element_mars, - # params = params, - # param_flows = param_flows, - # nids = nids, - # cids = cids, - # pids = pids, - # pfids = pfids, - # num_nodes = layer_n_nodes, - # num_edges = num_edges, - # batch_size = batch_size, - # allow_modify_flows = allow_modify_flows, - # TILE_SIZE_M = TILE_SIZE_M, - # BLOCK_K = BLOCK_K, - # BLOCK_B = BLOCK_B, - # TILE_SIZE_B = TILE_SIZE_B, - # B_NUM_BLOCKS = B_NUM_BLOCKS, - # BLOCK_SIZE_M = self.block_size - # ) - # print(">>>G", grid, "out") - return None def _backward_pytorch(self, node_flows, element_flows, params, node_mars, diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index a9e000f6..c3995dd5 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -19,7 +19,7 @@ normalize_parameters -def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, **kwargs): +def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, propagation_alg, **kwargs): grad = grad.permute(1, 0) pc.backward( inputs = inputs, diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py index 7928a3ac..03f37399 100644 --- a/tests/layer/propagation_algs_test.py +++ b/tests/layer/propagation_algs_test.py @@ -14,205 +14,331 @@ import pytest -def general_ll_prop_test(): +def ll_prop_test(): device = torch.device("cuda:0") batch_size = 16 - for block_size in [4, 8, 16]: - - with juice.set_block_size(block_size): + for block_size in [1, 4, 8, 16]: - ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + for allow_modify_flows in [True, False]: + + with juice.set_block_size(block_size): - np0 = multiply(ni0, ni1) - np1 = multiply(ni2, ni3) - np2 = multiply(ni1, ni2) + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ns0 = summate(np0, num_node_blocks = 2) - ns1 = summate(np1, num_node_blocks = 2) - ns2 = summate(np2, num_node_blocks = 2) + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) - input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) - prod_layer = ProdLayer([np0, np1, np2]) + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) - layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, - global_pid_start = block_size ** 2, - global_pfid_start = 0, node2tiednodes = dict()) + prod_layer = ProdLayer([np0, np1, np2]) - layer.to(device) + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) - ## Forward pass ## + layer.to(device) - element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) - element_mars[:block_size,:] = -float("inf") - node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + ## Forward pass ## - params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) - alphas = [1.2, 2.0, 3.0] + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) - for alpha in alphas: - layer(node_mars, element_mars, params, propagation_alg = "GeneralLL", alpha = alpha) + layer(node_mars, element_mars, params, propagation_alg = "LL") for i in range(block_size): for j in range(6): cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() epars = params[layer.partitioned_pids[0][j,:]+i] - scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) - assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + scaled_lls = (epars[:,None] * cmars).sum(dim = 0).log() + + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 2e-3) - ## Backward pass ## + ## Backward pass ## - node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) - element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) - param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) - layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, - propagation_alg = "GeneralLL", alpha = alpha) + origin_node_flows = node_flows.clone() - chids = layer.partitioned_chids[0] - parids = layer.partitioned_parids[0] - parpids = layer.partitioned_parpids[0] + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "LL") - num_nblocks = chids.size(0) - num_eblocks = parids.size(1) - parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) - parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( - num_nblocks, num_eblocks * block_size) + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] - for j in range(6): - parpids = parpids_start.clone() - for i in range(block_size): - nmars = node_mars[parids[j,:]].exp() - nflows = node_flows[parids[j,:]] - emars = element_mars[(j+1)*block_size+i,:].exp() - epars = params[parpids[j,:]] - eflows = (nflows * (epars[:,None] * emars[None,:]) ** alpha / nmars ** alpha).sum(dim = 0) + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) - assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (epars[:,None] * emars[None,:]) / nmars).sum(dim = 0) - parpids += block_size + if allow_modify_flows: + uflows1 = node_flows[parids[j,:]] + uflows2 = origin_node_flows[parids[j,:]].log() - nmars.log() - my_pflows = torch.zeros_like(param_flows) + assert torch.all(torch.abs(uflows1 - uflows2) < 1e-3) - for i in range(block_size): - for j in range(6): - emars = element_mars[layer.partitioned_cids[0][j,:]].exp() - epars = params[layer.partitioned_pids[0][j,:]+i] - nmars = node_mars[(j+1)*block_size+i,:].exp() - nflows = node_flows[(j+1)*block_size+i,:] - pflows = epars ** alpha * (nflows[None,:] * emars ** alpha / nmars[None,:] ** alpha).sum(dim = 1) + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) - my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + parpids += block_size - assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + my_pflows = torch.zeros_like(param_flows) + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = epars * (nflows[None,:] * emars / nmars[None,:]).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def mpe_prop_test(): +def general_ll_prop_test(): + device = torch.device("cuda:0") batch_size = 16 - for block_size in [4, 8, 16]: + for block_size in [1, 4, 8, 16]: + + for allow_modify_flows in [True, False]: - with juice.set_block_size(block_size): + with juice.set_block_size(block_size): - ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) - np0 = multiply(ni0, ni1) - np1 = multiply(ni2, ni3) - np2 = multiply(ni1, ni2) + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) - ns0 = summate(np0, num_node_blocks = 2) - ns1 = summate(np1, num_node_blocks = 2) - ns2 = summate(np2, num_node_blocks = 2) + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) - input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) - prod_layer = ProdLayer([np0, np1, np2]) + prod_layer = ProdLayer([np0, np1, np2]) - layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, - global_pid_start = block_size ** 2, - global_pfid_start = 0, node2tiednodes = dict()) + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) - layer.to(device) + layer.to(device) - ## Forward pass ## + alphas = [1.2, 2.0, 3.0] - element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) - element_mars[:block_size,:] = -float("inf") - node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + for alpha in alphas: - params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + ## Forward pass ## - layer(node_mars, element_mars, params, propagation_alg = "MPE") + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) - for i in range(block_size): - for j in range(6): - cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() - epars = params[layer.partitioned_pids[0][j,:]+i] - scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() - assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "GeneralLL", alpha = alpha) - ## Backward pass ## + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) - node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) - element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 2e-3) - param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + ## Backward pass ## - layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, - allow_modify_flows = False, propagation_alg = "MPE") + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) - chids = layer.partitioned_chids[0] - parids = layer.partitioned_parids[0] - parpids = layer.partitioned_parpids[0] + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) - num_nblocks = chids.size(0) - num_eblocks = parids.size(1) - parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) - parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( - num_nblocks, num_eblocks * block_size) + origin_node_flows = node_flows.clone() + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "GeneralLL", alpha = alpha) + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (epars[:,None] * emars[None,:]) ** alpha / nmars ** alpha).sum(dim = 0) + + if allow_modify_flows: + uflows1 = node_flows[parids[j,:]] + uflows2 = origin_node_flows[parids[j,:]].log() - nmars.log() * alpha + assert torch.all(torch.abs(uflows1 - uflows2) < 1e-3) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = epars ** alpha * (nflows[None,:] * emars ** alpha / nmars[None,:] ** alpha).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + + +def mpe_prop_test(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [1, 4, 8, 16]: + + for allow_modify_flows in [True, False]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "MPE") - for j in range(6): - parpids = parpids_start.clone() for i in range(block_size): - nmars = node_mars[parids[j,:]].exp() - nflows = node_flows[parids[j,:]] - emars = element_mars[(j+1)*block_size+i,:].exp() - epars = params[parpids[j,:]] - eflows = (nflows * (((epars[:,None] * emars[None,:]) - nmars).abs() < 1e-6).float()).sum(dim = 0) + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) - assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + origin_node_flows = node_flows.clone() - parpids += block_size + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "MPE") - my_pflows = torch.zeros_like(param_flows) + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) - for i in range(block_size): for j in range(6): - emars = element_mars[layer.partitioned_cids[0][j,:]].exp() - epars = params[layer.partitioned_pids[0][j,:]+i] - nmars = node_mars[(j+1)*block_size+i,:].exp() - nflows = node_flows[(j+1)*block_size+i,:] - pflows = (nflows[None,:] * ((epars[:,None] * emars - nmars[None,:]).abs() < 1e-6).float()).sum(dim = 1) + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (((epars[:,None] * emars[None,:]) - nmars).abs() < 1e-6).float()).sum(dim = 0) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = (nflows[None,:] * ((epars[:,None] * emars - nmars[None,:]).abs() < 1e-6).float()).sum(dim = 1) - my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows - assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) if __name__ == "__main__": + torch.manual_seed(280) + ll_prop_test() general_ll_prop_test() mpe_prop_test() From ab27e0e76a4b98b42f88e8dbce1ebf8ecc473a01 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 18:37:08 +0800 Subject: [PATCH 06/53] fix arg passing --- src/pyjuice/layer/sum_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 41cf391d..8f832dfe 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1410,7 +1410,7 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens BLOCK_M = BLOCK_M, BLOCK_SIZE_M = BLOCK_SIZE_M, propagation_alg_id = propagation_alg_id, - **kwargs + **propagation_alg_kwargs ) else: @@ -1435,7 +1435,7 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, propagation_alg_id = propagation_alg_id, - **kwargs + **propagation_alg_kwargs ) return None From c2bd0558de2e7fadec893c04c2b00c84f9827b0f Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 21:11:58 +0800 Subject: [PATCH 07/53] fix visualization functions --- src/pyjuice/visualize/visualize.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/pyjuice/visualize/visualize.py b/src/pyjuice/visualize/visualize.py index e88a4cae..ca90cd39 100644 --- a/src/pyjuice/visualize/visualize.py +++ b/src/pyjuice/visualize/visualize.py @@ -33,6 +33,8 @@ def plot_pc(ns, """ G = nx.DiGraph() node_list = serialize_nodes(ns) + for item in node_list: + item["num_nodes"] = item["num_node_blocks"] * item["block_size"] pos = {} nx.set_node_attributes(G, [], "node_type") nx.set_node_attributes(G, [], "num_nodes") @@ -102,19 +104,13 @@ def plot_pc(ns, def plot_tensor_node_connection(ns, node_id : int = 0): G = nx.DiGraph() node_list = serialize_nodes(ns) + for item in node_list: + item["num_nodes"] = item["num_node_blocks"] * item["block_size"] node_target = node_list[node_id] pos = {} if node_target['type']=='Input': - # print(f'The target node {node_id} is a Input node...') - # input_node_list = [] - # for i in range(node_target['num_nodes']): - # G.add_node(f'i{i}') - # input_node_list.append(f'i{i}') - # pos = nx.circular_layout(G) - # nx.draw(G, pos, node_color="#A5CE9D", with_labels=True, font_size=6) # green is sum nodes - # return input_node_list #return input node list - print(f"\n>>The target node {node_id} is a Input node, it has {node_target['num_nodes']} nodes, & no connection among them.<<\n") + print(f"\n>>The target node {node_id} is an Input node, it has {node_target['num_nodes']} nodes, & no connection among them.<<\n") return elif node_target['type']=='Product': @@ -158,7 +154,7 @@ def plot_tensor_node_connection(ns, node_id : int = 0): # nx.draw_networkx_nodes(G, pos, nodelist=prod_node_list, node_color="#A2C8DD", **options) # blue is product nodes # nx.draw_networkx_edges(G, pos, width=0.1) - #adjacency matrix heatmap plot + # Adjacency matrix heatmap plot fig, ax = plt.subplots() ax.spy(adjacency_manual) ax.set_aspect('auto') From 8b49e7482bd7f8cb551fc1b8febf41fc5a838ea3 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 00:40:12 +0800 Subject: [PATCH 08/53] improve numerical stability of hclt correctness tests --- tests/structures/hclt_correctness_test.py | 162 ++++++++++++++++++---- 1 file changed, 134 insertions(+), 28 deletions(-) diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index 45f58f60..d3823499 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -7,7 +7,7 @@ import pyjuice.nodes.distributions as dists -def hclt_forward_test(): +def test_hclt_forward(): device = torch.device("cuda:0") @@ -97,7 +97,7 @@ def hclt_forward_test(): assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 4e-3) -def hclt_backward_test(): +def test_hclt_single_layer_backward(): device = torch.device("cuda:0") @@ -128,19 +128,123 @@ def hclt_backward_test(): data_cpu = batch_data.cpu().long() batch_size = batch_data.size(0) + pc.init_param_flows(flows_memory = 0.0) + lls = pc(batch_data) pc.backward(batch_data.permute(1, 0), allow_modify_flows = False) pc.update_param_flows() - node_mars = pc.node_mars.cpu() - node_flows = pc.node_flows.cpu() + for layer_id in range(1, len(pc.inner_layer_groups) - 2, 2): + + node_mars = pc.node_mars.clone() + node_flows = pc.node_flows.clone() + element_mars = pc.element_mars.clone() + element_flows = pc.element_flows.clone() + params = pc.params.clone() + param_flows = pc.param_flows.clone().zero_() + + my_layer = pc.inner_layer_groups[layer_id][0] + previous_layer = pc.inner_layer_groups[layer_id-1][0] + + previous_layer.forward(node_mars, element_mars, _for_backward = True) + + my_layer.backward(node_flows, element_flows, node_mars, element_mars, params, + param_flows = param_flows, allow_modify_flows = False, propagation_alg = "LL") + + chids = my_layer.partitioned_chids[0] + parids = my_layer.partitioned_parids[0] + parpids = my_layer.partitioned_parpids[0] + + nids = my_layer.partitioned_nids[0] + cids = my_layer.partitioned_cids[0] + pids = my_layer.partitioned_pids[0] + pfids = my_layer.partitioned_pfids[0] + + for i in range(chids.size(0)): + eflows = torch.zeros([block_size, batch_size], dtype = torch.float32, device = device) + + for j in range(parids.size(1)): + nflows = node_flows[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[chids[i]:chids[i]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[parpids[i,j]:parpids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + curr_eflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + eflows += curr_eflows + + assert torch.all(torch.abs(eflows - element_flows[chids[i]:chids[i]+block_size,:]) < 1e-3) + + for i in range(nids.size(0)): + for j in range(0, cids.size(1), block_size): + nflows = node_flows[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[cids[i,j]:cids[i,j]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[pids[i,j]:pids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + pflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2) + + assert torch.all(torch.abs(fpars - pflows) < 3e-4 * batch_size) + + +def test_hclt_backward(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.long() + batch_size = batch_data.size(0) + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False) + + pc.update_param_flows() + + node_mars = pc.node_mars + node_flows = pc.node_flows ns2flows = dict() - ns2flows[root_ns] = torch.ones([1, batch_size]) + ns2flows[root_ns] = torch.ones([1, batch_size], device = device) + + ch2par = dict() + for ns in root_ns: + for cs in ns.chs: + if cs not in ch2par: + ch2par[cs] = set() + ch2par[cs].add(ns) + + visited = set() with torch.no_grad(): for ns in root_ns(reverse = True): + visited.add(ns) if ns == root_ns: sid, eid = ns._output_ind_range @@ -150,14 +254,14 @@ def hclt_backward_test(): nmars = node_mars[sid:eid,:] for i, cs in enumerate(ns.chs): - params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) - param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) if cs.is_prod(): - emars = torch.zeros([num_latents, batch_size]) + emars = torch.zeros([num_latents, batch_size], device = device) for cns in cs.chs: sid, eid = cns._output_ind_range emars += node_mars[sid:eid,:] @@ -170,58 +274,59 @@ def hclt_backward_test(): assert torch.all(torch.abs(pflows - param_flows[0,:]) < 6e-3) if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows elif ns.is_prod(): nflows = ns2flows[ns] for cs in ns.chs: if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += nflows elif ns.is_sum(): + for par_cs in ch2par[ns]: + assert par_cs in visited + nflows = ns2flows[ns] sid, eid = ns._output_ind_range - assert (torch.abs(nflows - node_flows[sid:eid,:]) > 1e-3).float().mean() < 0.02 + if len(ns.scope) > 2: + assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-3) nmars = node_mars[sid:eid,:] for i, cs in enumerate(ns.chs): - params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) - param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) if cs.is_prod(): - emars = torch.zeros([num_latents, batch_size]) + emars = torch.zeros([num_latents, batch_size], device = device) for cns in cs.chs: sid, eid = cns._output_ind_range emars += node_mars[sid:eid,:] else: raise ValueError() - log_n_fdm = nflows.log() - nmars - log_n_fdm_max = log_n_fdm.max(dim = 0).values - n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() - - eflows = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() + eflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + pflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) - scaled_emars = (emars + log_n_fdm_max[None,:]).exp() - pflows = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params + if not torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size): + import pdb; pdb.set_trace() - assert torch.all(torch.abs(pflows - param_flows) < 0.5) + assert torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size) if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows -def hclt_em_test(): +def test_hclt_em(): device = torch.device("cuda:0") @@ -303,7 +408,8 @@ def hclt_em_test(): if __name__ == "__main__": - torch.manual_seed(320942) - hclt_forward_test() - hclt_backward_test() - hclt_em_test() + # torch.manual_seed(320942) + test_hclt_forward() + test_hclt_single_layer_backward() + test_hclt_backward() + test_hclt_em() From b925bdc8b3447b6d48a488fe89ef0acf5ba4ce72 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 01:29:57 +0800 Subject: [PATCH 09/53] setup pytest --- pyproject.toml | 13 +++++++- tests/io/io_test.py | 8 ++--- tests/layer/input_layer_test.py | 13 ++++---- tests/layer/layer_compilation_test.py | 10 +++--- tests/layer/prod_layer_test.py | 9 +++--- tests/layer/propagation_algs_test.py | 12 ++++---- tests/layer/sparse_prod_layer_test.py | 4 +-- tests/layer/sum_layer_test.py | 18 ++++++----- tests/lvd/counting_lvd_test.py | 4 +-- tests/model/backward_test.py | 12 ++++---- tests/model/block_sparse_pc_test.py | 6 ++-- tests/model/compilation_speed_test.py | 10 +++--- tests/model/forward_test.py | 16 +++++----- tests/model/homogeneous_hmm_test.py | 4 +-- tests/model/non_sd_pcs_test.py | 4 +-- tests/model/parameter_tying_test.py | 8 ++--- tests/model/partial_eval_test.py | 8 ++--- tests/model/simple_model_test.py | 4 +-- tests/model/structured_blk_sparse_pc_test.py | 4 +-- tests/nodes/input_dists_test.py | 32 ++++++++++---------- tests/nodes/nodes_test.py | 4 +-- tests/queries/cond_test.py | 4 +-- tests/queries/marginal_test.py | 8 ++--- tests/structures/hclt_correctness_test.py | 5 ++- tests/structures/hclt_test.py | 8 ++--- tests/structures/hmm_correctness_test.py | 4 +-- tests/structures/hmm_speed_test.py | 7 +++-- tests/structures/pd_hclt_test.py | 10 +++--- tests/structures/pd_test.py | 6 ++-- tests/structures/rat_spn_test.py | 6 ++-- tests/transformations/blockify_test.py | 8 ++--- tests/transformations/copy_test.py | 4 +-- tests/transformations/merge_test.py | 12 ++++---- tests/transformations/pruning_test.py | 12 ++++---- tests/visualize/plots_test.py | 32 ++++++++++++-------- 35 files changed, 181 insertions(+), 148 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a8a612a8..51adc70a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,14 +21,25 @@ authors = [ {name="StarAI", email="guyvdb@cs.ucla.edu"}, ] +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-xdist", + "pytest-skip-slow", + "torchvision", + "matplotlib" +] + [options.packages.find] where = "src" [tool.setuptools.dynamic] readme = {file = "README.md"} - [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", ] +testpaths = [ + "tests/" +] \ No newline at end of file diff --git a/tests/io/io_test.py b/tests/io/io_test.py index 59313a83..6a50b3d6 100644 --- a/tests/io/io_test.py +++ b/tests/io/io_test.py @@ -11,7 +11,7 @@ import pytest -def io_test(): +def test_io(): num_node_blocks = 2 block_size = 4 @@ -46,7 +46,7 @@ def io_test(): assert n0.chs[1].chs[1].dist.num_cats == n0_dup.chs[1].chs[1].dist.num_cats -def io_param_test(): +def test_io_param(): num_node_blocks = 2 block_size = 4 @@ -79,5 +79,5 @@ def io_param_test(): if __name__ == "__main__": - io_test() - io_param_test() + test_io() + test_io_param() diff --git a/tests/layer/input_layer_test.py b/tests/layer/input_layer_test.py index e53200ff..d0e4b47b 100644 --- a/tests/layer/input_layer_test.py +++ b/tests/layer/input_layer_test.py @@ -13,7 +13,7 @@ import pytest -def input_layer_test(): +def test_input_layer(): device = torch.device("cuda:0") @@ -111,7 +111,7 @@ def input_layer_test(): assert torch.all(torch.abs(new_params - layer.params) < 1e-4) -def tied_bp_test(): +def test_tied_bp(): device = torch.device("cuda:0") @@ -166,7 +166,8 @@ def tied_bp_test(): assert torch.all(torch.abs(param_flows - layer.param_flows) < 1e-4) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -283,6 +284,6 @@ def speed_test(): if __name__ == "__main__": - input_layer_test() - tied_bp_test() - speed_test() \ No newline at end of file + test_input_layer() + test_tied_bp() + test_speed() \ No newline at end of file diff --git a/tests/layer/layer_compilation_test.py b/tests/layer/layer_compilation_test.py index f3c4eed6..01072d7c 100644 --- a/tests/layer/layer_compilation_test.py +++ b/tests/layer/layer_compilation_test.py @@ -14,7 +14,7 @@ import pytest -def prod_layer_compilation_test(): +def test_prod_layer_compilation(): for block_size in [1, 8, 16]: @@ -38,7 +38,7 @@ def prod_layer_compilation_test(): prod_layer_cpu = ProdLayer([np0, np1, np2, np3, np4, np5], layer_sparsity_tol = 0.1, disable_gpu_compilation = True) prod_layer_gpu = ProdLayer([np0, np1, np2, np3, np4, np5], layer_sparsity_tol = 0.1, force_gpu_compilation = True) - for i in range(3): + for i in range(2): assert torch.all(prod_layer_cpu.partitioned_nids[i] == prod_layer_gpu.partitioned_nids[i]) assert torch.all(prod_layer_cpu.partitioned_cids[i] == prod_layer_gpu.partitioned_cids[i]) @@ -47,7 +47,7 @@ def prod_layer_compilation_test(): assert torch.all(prod_layer_cpu.partitioned_parids[i] == prod_layer_gpu.partitioned_parids[i]) -def sum_layer_compilation_test(): +def test_sum_layer_compilation(): for block_size in [1, 8, 16]: @@ -112,5 +112,5 @@ def sum_layer_compilation_test(): if __name__ == "__main__": - # prod_layer_compilation_test() - sum_layer_compilation_test() + test_prod_layer_compilation() + test_sum_layer_compilation() diff --git a/tests/layer/prod_layer_test.py b/tests/layer/prod_layer_test.py index 9203e464..0daba678 100644 --- a/tests/layer/prod_layer_test.py +++ b/tests/layer/prod_layer_test.py @@ -14,7 +14,7 @@ import pytest -def prod_layer_test(): +def test_prod_layer(): device = torch.device("cuda:0") @@ -91,7 +91,8 @@ def prod_layer_test(): assert torch.all(torch.abs(node_flows[8*block_size+i,:] - element_flows[4*block_size+i,:]) < 1e-4) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -166,5 +167,5 @@ def speed_test(): if __name__ == "__main__": torch.manual_seed(2390) - prod_layer_test() - speed_test() + test_prod_layer() + test_speed() diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py index 03f37399..c312a0d2 100644 --- a/tests/layer/propagation_algs_test.py +++ b/tests/layer/propagation_algs_test.py @@ -14,7 +14,7 @@ import pytest -def ll_prop_test(): +def test_ll_prop(): device = torch.device("cuda:0") @@ -123,7 +123,7 @@ def ll_prop_test(): assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def general_ll_prop_test(): +def test_general_ll_prop(): device = torch.device("cuda:0") @@ -235,7 +235,7 @@ def general_ll_prop_test(): assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def mpe_prop_test(): +def test_mpe_prop(): device = torch.device("cuda:0") @@ -339,6 +339,6 @@ def mpe_prop_test(): if __name__ == "__main__": torch.manual_seed(280) - ll_prop_test() - general_ll_prop_test() - mpe_prop_test() + test_ll_prop() + test_general_ll_prop() + test_mpe_prop() diff --git a/tests/layer/sparse_prod_layer_test.py b/tests/layer/sparse_prod_layer_test.py index d34c0343..4b02357b 100644 --- a/tests/layer/sparse_prod_layer_test.py +++ b/tests/layer/sparse_prod_layer_test.py @@ -14,7 +14,7 @@ import pytest -def sparse_prod_layer_test(): +def test_sparse_prod_layer(): device = torch.device("cuda:0") @@ -104,4 +104,4 @@ def sparse_prod_layer_test(): if __name__ == "__main__": torch.manual_seed(2390) - sparse_prod_layer_test() + test_sparse_prod_layer() diff --git a/tests/layer/sum_layer_test.py b/tests/layer/sum_layer_test.py index 683dacf2..85870d24 100644 --- a/tests/layer/sum_layer_test.py +++ b/tests/layer/sum_layer_test.py @@ -14,7 +14,7 @@ import pytest -def sum_layer_test(): +def test_sum_layer(): device = torch.device("cuda:0") @@ -141,7 +141,7 @@ def sum_layer_test(): assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def corner_case_test(): +def test_corner_case(): device = torch.device("cuda:0") @@ -280,7 +280,8 @@ def corner_case_test(): assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -360,7 +361,8 @@ def speed_test(): print("--------------------------------------------------------------") -def block_sparse_speed_test(): +@pytest.mark.slow +def test_block_sparse_speed(): device = torch.device("cuda:0") @@ -448,7 +450,7 @@ def block_sparse_speed_test(): if __name__ == "__main__": torch.manual_seed(3890) - sum_layer_test() - corner_case_test() - speed_test() - block_sparse_speed_test() \ No newline at end of file + test_sum_layer() + test_corner_case() + test_speed() + test_block_sparse_speed() \ No newline at end of file diff --git a/tests/lvd/counting_lvd_test.py b/tests/lvd/counting_lvd_test.py index 71aac6e6..16cd10e9 100644 --- a/tests/lvd/counting_lvd_test.py +++ b/tests/lvd/counting_lvd_test.py @@ -9,7 +9,7 @@ import pytest -def counting_lvd_test(): +def test_counting_lvd(): num_nodes = 2 with juice.LVDistiller(backend = "counting", pseudocount = 0.0): @@ -32,4 +32,4 @@ def counting_lvd_test(): if __name__ == "__main__": - counting_lvd_test() \ No newline at end of file + test_counting_lvd() \ No newline at end of file diff --git a/tests/model/backward_test.py b/tests/model/backward_test.py index 3fd59541..8da0ac33 100644 --- a/tests/model/backward_test.py +++ b/tests/model/backward_test.py @@ -10,7 +10,7 @@ import pytest -def backward_test(): +def test_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -119,7 +119,7 @@ def backward_test(): assert torch.abs(inner_param_flows[13] + inner_param_flows[14] - 1.0) < 1e-4 -def non_sd_pc_backward_test(): +def test_non_sd_pc_backward(): ni00 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni10 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni20 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -177,7 +177,7 @@ def non_sd_pc_backward_test(): assert torch.abs(pc.param_flows[11] - fp4) < 1e-3 -def sparse_pc_backward_test(): +def test_sparse_pc_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -242,6 +242,6 @@ def sparse_pc_backward_test(): if __name__ == "__main__": - backward_test() - non_sd_pc_backward_test() - sparse_pc_backward_test() \ No newline at end of file + test_backward() + test_non_sd_pc_backward() + test_sparse_pc_backward() \ No newline at end of file diff --git a/tests/model/block_sparse_pc_test.py b/tests/model/block_sparse_pc_test.py index 7f62d4df..83153640 100644 --- a/tests/model/block_sparse_pc_test.py +++ b/tests/model/block_sparse_pc_test.py @@ -14,7 +14,7 @@ import pytest -def block_sparse_pc_test(): +def test_block_sparse_pc(): device = torch.device("cuda:0") @@ -120,9 +120,9 @@ def block_sparse_pc_test(): else: curr_par_flows = torch.matmul(1.0 / ns_vals[ni], np2_vals[ci-num_node_blocks*2].permute(1, 0)) * params[i] - assert torch.all(torch.abs(param_flows[i] - curr_par_flows) < 1e-2) + assert torch.all(torch.abs(param_flows[i] - curr_par_flows) < 1e-3 * curr_par_flows) if __name__ == "__main__": torch.manual_seed(3890) - block_sparse_pc_test() + test_block_sparse_pc() diff --git a/tests/model/compilation_speed_test.py b/tests/model/compilation_speed_test.py index 3f2b48b3..5012f501 100644 --- a/tests/model/compilation_speed_test.py +++ b/tests/model/compilation_speed_test.py @@ -11,7 +11,8 @@ import pytest -def compile_dense_pc_test(): +@pytest.mark.slow +def test_compile_dense_pc(): num_latents = 2048 num_cats = 512 num_vars = 64 @@ -33,7 +34,8 @@ def compile_dense_pc_test(): assert t1 - t0 < 60 -def compile_sparse_pc_test(): +@pytest.mark.slow +def test_compile_sparse_pc(): num_latents = 4096 num_cats = 200 num_vars = 16 @@ -71,5 +73,5 @@ def compile_sparse_pc_test(): if __name__ == "__main__": - compile_dense_pc_test() - compile_sparse_pc_test() \ No newline at end of file + test_compile_dense_pc() + test_compile_sparse_pc() \ No newline at end of file diff --git a/tests/model/forward_test.py b/tests/model/forward_test.py index 3dfdaa96..32f9263e 100644 --- a/tests/model/forward_test.py +++ b/tests/model/forward_test.py @@ -10,7 +10,7 @@ import pytest -def forward_test(): +def test_forward(): ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) @@ -71,7 +71,7 @@ def forward_test(): assert torch.abs(pc.node_mars[13,0] - torch.log(s)) < 1e-4 -def non_sd_pc_forward_test(): +def test_non_sd_pc_forward(): ni00 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni10 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni20 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -138,7 +138,7 @@ def non_sd_pc_forward_test(): assert torch.abs(torch.exp(pc.node_mars[17,0]) - (f9 + f10 + f11 + f12)) < 1e-3 -def sparse_pc_forward_test(): +def test_sparse_pc_forward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -199,7 +199,7 @@ def sparse_pc_forward_test(): assert torch.abs(pc.node_mars[13,0] - torch.log(s)) < 1e-4 -def non_sd_pc2_forward_test(): +def test_non_sd_pc2_forward(): n0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) n1 = multiply(inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2))) n2 = multiply(inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2))) @@ -221,7 +221,7 @@ def non_sd_pc2_forward_test(): if __name__ == "__main__": - forward_test() - non_sd_pc_forward_test() - sparse_pc_forward_test() - non_sd_pc2_forward_test() \ No newline at end of file + test_forward() + test_non_sd_pc_forward() + test_sparse_pc_forward() + test_non_sd_pc2_forward() \ No newline at end of file diff --git a/tests/model/homogeneous_hmm_test.py b/tests/model/homogeneous_hmm_test.py index e6cb8ef9..deb9119e 100644 --- a/tests/model/homogeneous_hmm_test.py +++ b/tests/model/homogeneous_hmm_test.py @@ -11,7 +11,7 @@ import pytest -def homogeneous_hmm_test(): +def test_homogeneous_hmm(): block_size = 1 @@ -217,4 +217,4 @@ def homogeneous_hmm_test(): if __name__ == "__main__": torch.manual_seed(2390) - homogeneous_hmm_test() + test_homogeneous_hmm() diff --git a/tests/model/non_sd_pcs_test.py b/tests/model/non_sd_pcs_test.py index c7be5e37..22b570ff 100644 --- a/tests/model/non_sd_pcs_test.py +++ b/tests/model/non_sd_pcs_test.py @@ -10,7 +10,7 @@ import pytest -def non_sd_test(): +def test_non_sd(): ni1 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni2 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni3 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -103,4 +103,4 @@ def non_sd_test(): if __name__ == "__main__": torch.manual_seed(129) - non_sd_test() \ No newline at end of file + test_non_sd() \ No newline at end of file diff --git a/tests/model/parameter_tying_test.py b/tests/model/parameter_tying_test.py index 4a0f149f..02874cfb 100644 --- a/tests/model/parameter_tying_test.py +++ b/tests/model/parameter_tying_test.py @@ -11,7 +11,7 @@ import pytest -def simple_structure_test_block1(): +def test_simple_structure_block1(): block_size = 1 @@ -302,7 +302,7 @@ def simple_structure_test_block1(): assert torch.all(torch.abs(param_flows1.reshape(-1) - pc.params[5:13].cpu()) < 1e-4) -def simple_structure_test_block16(): +def test_simple_structure_block16(): block_size = 16 @@ -610,5 +610,5 @@ def simple_structure_test_block16(): if __name__ == "__main__": torch.manual_seed(2390) - simple_structure_test_block1() - simple_structure_test_block16() + test_simple_structure_block1() + test_simple_structure_block16() diff --git a/tests/model/partial_eval_test.py b/tests/model/partial_eval_test.py index e654c787..f6f3751e 100644 --- a/tests/model/partial_eval_test.py +++ b/tests/model/partial_eval_test.py @@ -12,7 +12,7 @@ import pytest -def partial_eval_forward_test(): +def test_partial_eval_forward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -68,7 +68,7 @@ def partial_eval_forward_test(): assert torch.all(torch.abs(lls - lls2) < 1e-4) -def partial_eval_backward_test(): +def test_partial_eval_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -112,5 +112,5 @@ def partial_eval_backward_test(): if __name__ == "__main__": - partial_eval_forward_test() - partial_eval_backward_test() + test_partial_eval_forward() + test_partial_eval_backward() diff --git a/tests/model/simple_model_test.py b/tests/model/simple_model_test.py index 0afbdc82..7f633f9c 100644 --- a/tests/model/simple_model_test.py +++ b/tests/model/simple_model_test.py @@ -14,7 +14,7 @@ import pytest -def simple_model_test(): +def test_simple_model(): device = torch.device("cuda:0") @@ -525,4 +525,4 @@ def simple_model_test(): if __name__ == "__main__": torch.manual_seed(23892) - simple_model_test() + test_simple_model() diff --git a/tests/model/structured_blk_sparse_pc_test.py b/tests/model/structured_blk_sparse_pc_test.py index 0a4041ba..f9a3ede4 100644 --- a/tests/model/structured_blk_sparse_pc_test.py +++ b/tests/model/structured_blk_sparse_pc_test.py @@ -11,7 +11,7 @@ import pytest -def structured_blk_sparse_pc_test(): +def test_structured_blk_sparse_pc(): device = torch.device("cuda:0") @@ -217,4 +217,4 @@ def structured_blk_sparse_pc_test(): if __name__ == "__main__": torch.manual_seed(89172) - structured_blk_sparse_pc_test() + test_structured_blk_sparse_pc() diff --git a/tests/nodes/input_dists_test.py b/tests/nodes/input_dists_test.py index a3040d60..3a8b27b2 100644 --- a/tests/nodes/input_dists_test.py +++ b/tests/nodes/input_dists_test.py @@ -11,7 +11,7 @@ import pytest -def categorical_nodes_test(): +def test_categorical_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -81,7 +81,7 @@ def categorical_nodes_test(): assert torch.all(torch.abs(new_params - pc.input_layer_group[0].params) < 1e-4) -def bernoulli_nodes_test(): +def test_bernoulli_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Bernoulli()) ni1 = inputs(1, num_nodes = 2, dist = dists.Bernoulli()) @@ -152,7 +152,7 @@ def bernoulli_nodes_test(): assert torch.all(torch.abs(new_params - pc.input_layer_group[0].params) < 1e-4) -def gaussian_nodes_test(): +def test_gaussian_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Gaussian(mu = 0.0, sigma = 1.0)) ni1 = inputs(1, num_nodes = 2, dist = dists.Gaussian(mu = 0.0, sigma = 1.0)) @@ -229,7 +229,7 @@ def gaussian_nodes_test(): assert torch.all(torch.abs(updated_sigma.clamp(min = 0.01) - pc.input_layer_group[0].params.reshape(8, 2)[:,1]) < 1e-4) -def discrete_logistic_nodes_test(): +def test_discrete_logistic_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) ni1 = inputs(1, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) @@ -326,7 +326,7 @@ def discrete_logistic_nodes_test(): assert torch.all(torch.abs(updated_s - pc.input_layer_group[0].params.reshape(8, 2)[:,1]) < 1e-4) -def discrete_logistic_nodes_behavior_test(): +def test_discrete_logistic_nodes_behavior(): ni0 = inputs(0, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) ni1 = inputs(1, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) @@ -373,7 +373,7 @@ def discrete_logistic_nodes_behavior_test(): assert (ni3._params[0] > 0.65 and ni3._params[2] < 0.4) or (ni3._params[2] > 0.65 and ni3._params[0] < 0.4) -def masked_categorical_nodes_range_test(): +def test_masked_categorical_nodes_range(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "range"), mask = torch.tensor([[2, 4], [3, 5]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "range"), mask = torch.tensor([[2, 4], [3, 5]])) @@ -495,7 +495,7 @@ def masked_categorical_nodes_range_test(): assert torch.all(torch.abs(updated_params - pc.input_layer_group[0].params.reshape(8, 8)) < 1e-4) -def masked_categorical_nodes_full_mask_test(): +def test_masked_categorical_nodes_full_mask(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "full_mask"), mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "full_mask"), mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]])) @@ -617,7 +617,7 @@ def masked_categorical_nodes_full_mask_test(): assert torch.all(torch.abs(updated_params - pc.input_layer_group[0].params.reshape(8, 11)) < 1e-4) -def masked_categorical_nodes_rev_range_test(): +def test_masked_categorical_nodes_rev_range(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "rev_range"), mask = torch.tensor([[2, 4], [3, 5]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "rev_range"), mask = torch.tensor([[2, 4], [3, 5]])) @@ -733,11 +733,11 @@ def masked_categorical_nodes_rev_range_test(): if __name__ == "__main__": torch.manual_seed(2390) - categorical_nodes_test() - bernoulli_nodes_test() - gaussian_nodes_test() - discrete_logistic_nodes_test() - discrete_logistic_nodes_behavior_test() - masked_categorical_nodes_range_test() - masked_categorical_nodes_full_mask_test() - masked_categorical_nodes_rev_range_test() + test_categorical_nodes() + test_bernoulli_nodes() + test_gaussian_nodes() + test_discrete_logistic_nodes() + test_discrete_logistic_nodes_behavior() + test_masked_categorical_nodes_range() + test_masked_categorical_nodes_full_mask() + test_masked_categorical_nodes_rev_range() diff --git a/tests/nodes/nodes_test.py b/tests/nodes/nodes_test.py index 3ec1d88f..ba492494 100644 --- a/tests/nodes/nodes_test.py +++ b/tests/nodes/nodes_test.py @@ -9,7 +9,7 @@ import pytest -def nodes_test(): +def test_nodes(): device = torch.device("cuda:0") @@ -45,4 +45,4 @@ def nodes_test(): if __name__ == "__main__": - nodes_test() \ No newline at end of file + test_nodes() \ No newline at end of file diff --git a/tests/queries/cond_test.py b/tests/queries/cond_test.py index 6bc31e8c..8298fe07 100644 --- a/tests/queries/cond_test.py +++ b/tests/queries/cond_test.py @@ -13,7 +13,7 @@ import pytest -def cat_soft_cond_test(): +def test_cat_soft_cond(): device = torch.device("cuda:0") @@ -54,4 +54,4 @@ def cat_soft_cond_test(): torch.manual_seed(123) torch.cuda.manual_seed(123) - cat_soft_cond_test() + test_cat_soft_cond() diff --git a/tests/queries/marginal_test.py b/tests/queries/marginal_test.py index e4ac1c11..ddb8035e 100644 --- a/tests/queries/marginal_test.py +++ b/tests/queries/marginal_test.py @@ -10,7 +10,7 @@ import pytest -def cat_hard_marginal_test(): +def test_cat_hard_marginal(): device = torch.device("cuda:0") @@ -38,7 +38,7 @@ def cat_hard_marginal_test(): assert torch.all((torch.log(p0 * pc.params[1] + p1 * pc.params[2]) - lls[:,0]).abs() < 1e-4) -def cat_soft_marginal_test(): +def test_cat_soft_marginal(): device = torch.device("cuda:0") @@ -71,5 +71,5 @@ def cat_soft_marginal_test(): torch.manual_seed(123) torch.cuda.manual_seed(123) - cat_hard_marginal_test() - cat_soft_marginal_test() \ No newline at end of file + test_cat_hard_marginal() + test_cat_soft_marginal() \ No newline at end of file diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index d3823499..2b7fb03e 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -191,6 +191,8 @@ def test_hclt_single_layer_backward(): def test_hclt_backward(): + torch.manual_seed(18329) + device = torch.device("cuda:0") train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) @@ -316,9 +318,6 @@ def test_hclt_backward(): eflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) pflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) - if not torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size): - import pdb; pdb.set_trace() - assert torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size) if cs not in ns2flows: diff --git a/tests/structures/hclt_test.py b/tests/structures/hclt_test.py index 6d93d22b..c881dd56 100644 --- a/tests/structures/hclt_test.py +++ b/tests/structures/hclt_test.py @@ -64,7 +64,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def hclt_test(): +def test_hclt(): device = torch.device("cuda:0") @@ -146,7 +146,7 @@ def hclt_test(): assert test_ll > -770 -def hclt_logistic_test(): +def test_hclt_logistic(): device = torch.device("cuda:0") @@ -204,5 +204,5 @@ def hclt_logistic_test(): if __name__ == "__main__": torch.manual_seed(3289) - hclt_test() - hclt_logistic_test() + test_hclt() + test_hclt_logistic() diff --git a/tests/structures/hmm_correctness_test.py b/tests/structures/hmm_correctness_test.py index be551b36..6c1ba89f 100644 --- a/tests/structures/hmm_correctness_test.py +++ b/tests/structures/hmm_correctness_test.py @@ -7,7 +7,7 @@ import pyjuice.nodes.distributions as dists -def hmm_forward_backward_test(): +def test_hmm_forward_backward(): device = torch.device("cuda:0") @@ -206,4 +206,4 @@ def hmm_forward_backward_test(): if __name__ == "__main__": torch.manual_seed(23289) - hmm_forward_backward_test() + test_hmm_forward_backward() diff --git a/tests/structures/hmm_speed_test.py b/tests/structures/hmm_speed_test.py index 3f6ea6cb..2c9b58bc 100644 --- a/tests/structures/hmm_speed_test.py +++ b/tests/structures/hmm_speed_test.py @@ -3,8 +3,11 @@ import torch import time +import pytest -def hmm_speed_test(): + +@pytest.mark.slow +def test_hmm_speed(): device = torch.device("cuda:0") @@ -82,5 +85,5 @@ def hmm_speed_test(): if __name__ == "__main__": - hmm_speed_test() + test_hmm_speed() diff --git a/tests/structures/pd_hclt_test.py b/tests/structures/pd_hclt_test.py index e9fcc11a..f72c4fda 100644 --- a/tests/structures/pd_hclt_test.py +++ b/tests/structures/pd_hclt_test.py @@ -4,6 +4,8 @@ import time from torch.utils.data import TensorDataset, DataLoader +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -65,7 +67,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def pd_hclt_degenerative_case_test(): +def test_pd_hclt_degenerative_case(): device = torch.device("cuda:0") @@ -118,7 +120,7 @@ def pd_hclt_degenerative_case_test(): assert test_ll > -690.0 -def pd_hclt_test(): +def test_pd_hclt(): device = torch.device("cuda:0") @@ -180,5 +182,5 @@ def pd_hclt_test(): if __name__ == "__main__": torch.manual_seed(2391) - pd_hclt_degenerative_case_test() - pd_hclt_test() + test_pd_hclt_degenerative_case() + test_pd_hclt() diff --git a/tests/structures/pd_test.py b/tests/structures/pd_test.py index 4cefa80b..6fb83592 100644 --- a/tests/structures/pd_test.py +++ b/tests/structures/pd_test.py @@ -4,6 +4,8 @@ import time from torch.utils.data import TensorDataset, DataLoader +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -65,7 +67,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def pd_test(): +def test_pd(): device = torch.device("cuda:0") @@ -135,4 +137,4 @@ def pd_test(): if __name__ == "__main__": torch.manual_seed(2391) - pd_test() + test_pd() diff --git a/tests/structures/rat_spn_test.py b/tests/structures/rat_spn_test.py index f616a973..27271c50 100644 --- a/tests/structures/rat_spn_test.py +++ b/tests/structures/rat_spn_test.py @@ -5,6 +5,8 @@ from torch.utils.data import TensorDataset, DataLoader import pyjuice.nodes.distributions as dists +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -64,7 +66,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def rat_spn_test(): +def test_rat_spn(): device = torch.device("cuda:0") @@ -124,4 +126,4 @@ def rat_spn_test(): if __name__ == "__main__": torch.manual_seed(3289) - rat_spn_test() + test_rat_spn() diff --git a/tests/transformations/blockify_test.py b/tests/transformations/blockify_test.py index 97e141ac..aced75af 100644 --- a/tests/transformations/blockify_test.py +++ b/tests/transformations/blockify_test.py @@ -11,7 +11,7 @@ import pytest -def block_test(): +def test_block(): with set_block_size(block_size = 2): @@ -78,7 +78,7 @@ def block_test(): assert torch.all(new_n2._params[0][2:4,2:4] == n2._params[3]) -def block_sparse_block_test(): +def test_block_sparse_block(): with set_block_size(block_size = 4): @@ -159,5 +159,5 @@ def block_sparse_block_test(): if __name__ == "__main__": - block_test() - block_sparse_block_test() + test_block() + test_block_sparse_block() diff --git a/tests/transformations/copy_test.py b/tests/transformations/copy_test.py index cac99794..e3dc9578 100644 --- a/tests/transformations/copy_test.py +++ b/tests/transformations/copy_test.py @@ -10,7 +10,7 @@ import pytest -def copy_test(): +def test_copy(): num_node_blocks_candidates = [2, 4, 7] block_size_candidates = [1, 4, 8] @@ -66,4 +66,4 @@ def copy_test(): if __name__ == "__main__": - copy_test() \ No newline at end of file + test_copy() \ No newline at end of file diff --git a/tests/transformations/merge_test.py b/tests/transformations/merge_test.py index fde18945..d7d95053 100644 --- a/tests/transformations/merge_test.py +++ b/tests/transformations/merge_test.py @@ -10,7 +10,7 @@ import pytest -def sum_nodes_merge_test(): +def test_sum_nodes_merge(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -41,7 +41,7 @@ def sum_nodes_merge_test(): assert n_new.chs[0] == m00 -def prod_nodes_merge_test(): +def test_prod_nodes_merge(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -70,7 +70,7 @@ def prod_nodes_merge_test(): assert m_new.chs[1] == i10 -def merge_by_region_node_test(): +def test_merge_by_region_node(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -121,6 +121,6 @@ def merge_by_region_node_test(): if __name__ == "__main__": - sum_nodes_merge_test() - prod_nodes_merge_test() - merge_by_region_node_test() \ No newline at end of file + test_sum_nodes_merge() + test_prod_nodes_merge() + test_merge_by_region_node() \ No newline at end of file diff --git a/tests/transformations/pruning_test.py b/tests/transformations/pruning_test.py index 51f51072..79e284e5 100644 --- a/tests/transformations/pruning_test.py +++ b/tests/transformations/pruning_test.py @@ -7,7 +7,7 @@ import pytest -def pruning_test(): +def test_pruning(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -50,7 +50,7 @@ def pruning_test(): assert torch.all(torch.abs(new_n2._params.sum(dim = 2) - 1.0) < 1e-4) -def pruning_with_param_tying_test(): +def test_pruning_with_param_tying(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -91,7 +91,7 @@ def pruning_with_param_tying_test(): assert torch.all(torch.abs(new_n2._source_node._params[0].sum(dim = 1) - 1.0) < 1e-4) -def pruning_by_flow_test(): +def test_pruning_by_flow(): num_nodes = 2 i0 = inputs(0, num_nodes, dists.Categorical(num_cats = 5)) @@ -130,6 +130,6 @@ def pruning_by_flow_test(): if __name__ == "__main__": - pruning_test() - pruning_with_param_tying_test() - pruning_by_flow_test() \ No newline at end of file + test_pruning() + test_pruning_with_param_tying() + test_pruning_by_flow() \ No newline at end of file diff --git a/tests/visualize/plots_test.py b/tests/visualize/plots_test.py index 9006896a..7db7e48f 100644 --- a/tests/visualize/plots_test.py +++ b/tests/visualize/plots_test.py @@ -5,6 +5,8 @@ import pyjuice.visualize as juice_vis +import pytest + def simple_pc_gen(): n0 = inputs(0, num_nodes=256, dist=dists.Categorical(num_cats=5)) @@ -26,19 +28,25 @@ def simple_pc_gen(): ns.init_parameters() return ns -ns = simple_pc_gen() -# case 1 -plt.figure() -juice_vis.plot_pc(ns, node_id=True, node_num_label=True) -plt.show() +def test_plots(): + ns = simple_pc_gen() + + # case 1 + plt.figure() + juice_vis.plot_pc(ns, node_id=True, node_num_label=True) + plt.show() + + # case 2 + juice_vis.plot_tensor_node_connection(ns, node_id=3) + + # case 3 + juice_vis.plot_tensor_node_connection(ns, node_id=4) + plt.show() -# case 2 -juice_vis.plot_tensor_node_connection(ns, node_id=3) + # case 4 + juice_vis.plot_tensor_node_connection(ns, node_id=0) -# case 3 -juice_vis.plot_tensor_node_connection(ns, node_id=4) -plt.show() -# case 4 -juice_vis.plot_tensor_node_connection(ns, node_id=0) +if __name__ == "__main__": + test_plots() \ No newline at end of file From 05b9160a527e0f62251f5d942c8d2c8d0317a37d Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 02:17:39 +0800 Subject: [PATCH 10/53] stabilize runtests --- src/pyjuice/structures/pd.py | 20 ++-- src/pyjuice/structures/rat_spn.py | 6 +- tests/model/simple_model_test.py | 26 ++--- tests/nodes/input_dists_test.py | 6 +- tests/structures/hclt_test.py | 157 ++++++++++++++++++++++++------ 5 files changed, 162 insertions(+), 53 deletions(-) diff --git a/src/pyjuice/structures/pd.py b/src/pyjuice/structures/pd.py index 541b672f..a14861ec 100644 --- a/src/pyjuice/structures/pd.py +++ b/src/pyjuice/structures/pd.py @@ -22,8 +22,8 @@ def PD(data_shape: Tuple, num_latents: int, structure_type: str = "sum_dominated", input_layer_fn: Optional[Callable] = None, input_dist: Optional[Distribution] = None, - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: Dict = {"num_cats": 256}, + input_node_type: Type[Distribution] = Categorical, + input_node_params: Dict = {"num_cats": 256}, use_linear_mixing: bool = False, block_size: Optional[int] = None): """ @@ -124,7 +124,7 @@ def create_input_ns(hypercube): else: input_nodes = [] for var in scope: - ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_layer_type(**input_layer_params)) + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) input_nodes.append(ns) edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1) @@ -194,21 +194,25 @@ def PDHCLT(data: torch.Tensor, data_shape: Tuple, num_latents: int, max_split_depth: Optional[int] = None, max_prod_block_conns: int = 4, structure_type: str = "sum_dominated", - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: Dict = {"num_cats": 256}, + input_dist: Optional[Distribution] = None, + input_node_type: Type[Distribution] = Categorical, + input_node_params: Dict = {"num_cats": 256}, hclt_kwargs: Dict = {"num_bins": 32, "sigma": 0.5 / 32, "chunk_size": 32}, block_size: Optional[int] = None): assert data.dim() == 2 assert data.size(1) == reduce(lambda x, y: x * y, data_shape) + if input_dist is not None: + input_node_type, input_node_params = input_dist._get_constructor() + def input_layer_fn(scope, num_latents, block_size): vars = torch.tensor(scope.to_list()).sort().values ns = HCLT( x = data[:,vars], num_latents = num_latents, - input_layer_type = input_layer_type, - input_layer_params = input_layer_params, + input_node_type = input_node_type, + input_node_params = input_node_params, num_root_ns = num_latents, block_size = block_size, **hclt_kwargs @@ -224,7 +228,7 @@ def input_layer_fn(scope, num_latents, block_size): split_intervals = split_intervals, split_points = split_points, max_split_depth = max_split_depth, max_prod_block_conns = max_prod_block_conns, structure_type = structure_type, input_layer_fn = input_layer_fn, - input_layer_type = input_layer_type, input_layer_params = input_layer_params, + input_node_type = input_node_type, input_node_params = input_node_params, block_size = block_size) if ns.num_node_blocks > 1: diff --git a/src/pyjuice/structures/rat_spn.py b/src/pyjuice/structures/rat_spn.py index 7af44608..7d011e21 100644 --- a/src/pyjuice/structures/rat_spn.py +++ b/src/pyjuice/structures/rat_spn.py @@ -14,8 +14,8 @@ def RAT_SPN(num_vars: int, num_latents: int, depth: int, num_repetitions: int, num_pieces: int = 2, input_dist: Optional[Distribution] = None, - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: dict = {"num_cats": 256}, + input_node_type: Type[Distribution] = Categorical, + input_node_params: dict = {"num_cats": 256}, block_size: Optional[int] = None): """ Generate Random and Tensorized SPNs (https://proceedings.mlr.press/v115/peharz20a/peharz20a.pdf) @@ -57,7 +57,7 @@ def RAT_SPN(num_vars: int, num_latents: int, depth: int, num_repetitions: int, n # Input nodes input_ns = [] for v in range(num_vars): - ns = inputs(v, num_node_blocks = num_node_blocks, dist = input_layer_type(**input_layer_params)) + ns = inputs(v, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) input_ns.append(ns) # Top-down partition diff --git a/tests/model/simple_model_test.py b/tests/model/simple_model_test.py index 7f633f9c..df4d13b6 100644 --- a/tests/model/simple_model_test.py +++ b/tests/model/simple_model_test.py @@ -190,7 +190,9 @@ def test_simple_model(): pc.to(device) - data = torch.randint(0, 4, [512, 4], device = device) + batch_size = 512 + + data = torch.randint(0, 4, [batch_size, 4], device = device) lls = pc(data) @@ -303,7 +305,7 @@ def test_simple_model(): ns_parflows = eflows.sum(dim = 1) ref_parflows = param_flows[4096:4192] - assert torch.all(torch.abs(ns_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns_parflows - ref_parflows) < 1e-4 * batch_size) sid, eid = ns0._output_ind_range ns0_flows = np4_flows @@ -339,7 +341,7 @@ def test_simple_model(): ns0_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[0:2048].reshape(2, 64, 16).permute(0, 2, 1).reshape(32, 64) - assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-4 * batch_size) ch_lls = np1_lls epars = ns1._params.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) @@ -355,7 +357,7 @@ def test_simple_model(): ns1_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[2048:3072].reshape(2, 32, 16).permute(0, 2, 1).reshape(32, 32) - assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-4 * batch_size) ch_lls = np2_lls epars = ns2._params.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) @@ -371,7 +373,7 @@ def test_simple_model(): ns2_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[3072:4096].reshape(2, 32, 16).permute(0, 2, 1).reshape(32, 32) - assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-4 * batch_size) sid, eid = ni0._output_ind_range ni0_flows = np0_flows + np3_flows + np5_flows + np6_flows @@ -430,13 +432,13 @@ def test_simple_model(): pc.update_param_flows() ref_parflows = ns0._param_flows.reshape(2, 4, 16, 16).permute(0, 2, 1, 3).reshape(32, 64) - assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-4 * batch_size) ref_parflows = ns1._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-4 * batch_size) ref_parflows = ns2._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-4 * batch_size) par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, global_nids, nchs, cum_pflows, metadata = pc.par_update_kwargs @@ -493,9 +495,9 @@ def test_simple_model(): ns1_parflows = ns1._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) ns2_parflows = ns2._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns0_parflows.sum(dim = 1) - cum_pflows[0:32]) < 1e-3) - assert torch.all(torch.abs(ns1_parflows.sum(dim = 1) - cum_pflows[32:64]) < 1e-3) - assert torch.all(torch.abs(ns2_parflows.sum(dim = 1) - cum_pflows[64:96]) < 1e-3) + assert torch.all(torch.abs(ns0_parflows.sum(dim = 1) - cum_pflows[0:32]) < 1e-4 * batch_size) + assert torch.all(torch.abs(ns1_parflows.sum(dim = 1) - cum_pflows[32:64]) < 1e-4 * batch_size) + assert torch.all(torch.abs(ns2_parflows.sum(dim = 1) - cum_pflows[64:96]) < 1e-4 * batch_size) assert torch.abs(ns_parflows.sum() - cum_pflows[96]) < 1e-3 ns0_new_params = (ns0_parflows + pseudocount / 64) / (ns0_parflows.sum(dim = 1, keepdim = True) + pseudocount) @@ -524,5 +526,5 @@ def test_simple_model(): if __name__ == "__main__": - torch.manual_seed(23892) + # torch.manual_seed(23892) test_simple_model() diff --git a/tests/nodes/input_dists_test.py b/tests/nodes/input_dists_test.py index 3a8b27b2..bf60f0c5 100644 --- a/tests/nodes/input_dists_test.py +++ b/tests/nodes/input_dists_test.py @@ -361,7 +361,9 @@ def test_discrete_logistic_nodes_behavior(): pc.backward(data.permute(1, 0), flows_memory = 0.0) - pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.001) + + lls = pc(data) assert lls.mean() > -3.5 @@ -732,7 +734,7 @@ def test_masked_categorical_nodes_rev_range(): if __name__ == "__main__": - torch.manual_seed(2390) + # torch.manual_seed(235) test_categorical_nodes() test_bernoulli_nodes() test_gaussian_nodes() diff --git a/tests/structures/hclt_test.py b/tests/structures/hclt_test.py index c881dd56..4377da58 100644 --- a/tests/structures/hclt_test.py +++ b/tests/structures/hclt_test.py @@ -5,6 +5,8 @@ from torch.utils.data import TensorDataset, DataLoader import pyjuice.nodes.distributions as dists +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -93,9 +95,10 @@ def test_hclt(): train_data.float().to(device), num_bins = 32, sigma = 0.5 / 32, - num_latents = 256, + num_latents = 128, chunk_size = 32 ) + ns.init_parameters(perturbation = 2.0) pc = juice.TensorCircuit(ns) pc.to(device) @@ -115,35 +118,133 @@ def test_hclt(): lls.mean().backward() break - # for i, batch in enumerate(train_loader): - # x = batch[0].to(device) + mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device) - # lls = pc(x, record_cudagraph = False) - # lls.mean().backward() - # if i > 5: - # break + test_ll = evaluate(pc, test_loader) - # from torch.profiler import profile, record_function, ProfilerActivity - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack = True) as prof: - # for i, batch in enumerate(train_loader): - # x = batch[0].to(device) + assert test_ll > -785 - # lls = pc(x, record_cudagraph = False) - # lls.mean().backward() - # if i > 5: - # break - # prof.export_chrome_trace("trace3.json") - # # torch.autograd.profiler.tensorboard_trace_to_flame_graph('trace.json', 'flamegraph.svg') - # # prof.export_stacks("trace.txt", "cpu_time_total") - # import pdb; pdb.set_trace() - # exit() +@pytest.mark.slow +def test_small_hclt_full(): - mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device) + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 32, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(350, pc, optimizer, scheduler, train_loader, test_loader, device) + full_batch_em_epoch(pc, train_loader, test_loader, device) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -660 + + +@pytest.mark.slow +def test_large_hclt_full(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 128, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(350, pc, optimizer, scheduler, train_loader, test_loader, device) + full_batch_em_epoch(pc, train_loader, test_loader, device) test_ll = evaluate(pc, test_loader) - assert test_ll > -770 + assert test_ll > -640 def test_hclt_logistic(): @@ -175,12 +276,10 @@ def test_hclt_logistic(): train_data.float().to(device), num_bins = 32, sigma = 0.5 / 32, - num_latents = 64, + num_latents = 128, chunk_size = 32, - # input_layer_type = dists.Gaussian, - # input_layer_params = {"mu": 0.0, "sigma": 1.0 / 3,"min_sigma": 0.01} - input_layer_type = dists.DiscreteLogistic, - input_layer_params = {"val_range": (-1.0, 1.0), "num_cats": 256} + input_node_type = dists.DiscreteLogistic, + input_node_params = {"val_range": (-1.0, 1.0), "num_cats": 256} ) ns.init_parameters(perturbation = 4.0) pc = juice.TensorCircuit(ns) @@ -203,6 +302,8 @@ def test_hclt_logistic(): if __name__ == "__main__": - torch.manual_seed(3289) + # torch.manual_seed(3289) test_hclt() + test_small_hclt_full() + test_large_hclt_full() test_hclt_logistic() From 80059777ebf37a515d9493b4b5595a3197fd0ee1 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 03:08:29 +0800 Subject: [PATCH 11/53] fix typo in HMM --- src/pyjuice/structures/hmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyjuice/structures/hmm.py b/src/pyjuice/structures/hmm.py index d897d936..7351cac3 100644 --- a/src/pyjuice/structures/hmm.py +++ b/src/pyjuice/structures/hmm.py @@ -50,7 +50,7 @@ def HMM(seq_length: int, num_latents: int, num_emits: int, homogeneous: bool = T block_size = min(max_cdf_power_of_2(num_latents), 1024) num_node_blocks = num_latents // block_size - with juice.set_block_size(block_size = block_size): + with set_block_size(block_size = block_size): ns_input = inputs( seq_length - 1, num_node_blocks = num_node_blocks, From ee0596b485c62fdc6a649355440893a2e901067b Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 03:59:10 +0800 Subject: [PATCH 12/53] limit tile size for `MPE` propagation method to avoid kernel stall --- src/pyjuice/layer/sum_layer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 8f832dfe..94c36c2a 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1954,7 +1954,8 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - tl.trans(nmars)[:,:,None]) < 1e-6, tl.trans(nflows)[:,:,None], 0.0), axis = 0) + cond = tl.abs(elpars[:,None,:] + emars[None,:,:] - nmars[:,:,None]) < 1e-6 + acc += tl.sum(tl.where(cond, nflows[:,:,None], 0.0), axis = 1) else: @@ -2161,6 +2162,13 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor TILE_SIZE_B = min(2048 // remainder, base_size * remainder, BATCH_SIZE_NP2) TILE_SIZE_M = min(2048 // TILE_SIZE_B, self.block_size) TILE_SIZE_K = min(2048 // TILE_SIZE_B, num_edges) + + if propagation_alg_id == 1: + # The kernel will stale if the tile sizes are too large + TILE_SIZE_M = 16 + TILE_SIZE_K = 16 + TILE_SIZE_B = 16 + B_NUM_TILES = batch_size // TILE_SIZE_B allow_modify_flows = 1 if allow_modify_flows else 0 From a419fdbdf9b995bed14c362e88d83c5008ac39aa Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 03:59:27 +0800 Subject: [PATCH 13/53] typo --- src/pyjuice/layer/sum_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 94c36c2a..eebb5d9b 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -2164,7 +2164,7 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor TILE_SIZE_K = min(2048 // TILE_SIZE_B, num_edges) if propagation_alg_id == 1: - # The kernel will stale if the tile sizes are too large + # The kernel will stall if the tile sizes are too large TILE_SIZE_M = 16 TILE_SIZE_K = 16 TILE_SIZE_B = 16 From 3fd1c74b1d9ca3ba79a32f9054f86abdbc874ce8 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 03:59:49 +0800 Subject: [PATCH 14/53] runtests for the viterbi algorithm and the generalized EM algorithm --- pyproject.toml | 1 + src/pyjuice/layer/input_layer.py | 4 +- src/pyjuice/layer/prod_layer.py | 4 +- tests/optim/hmm_general_em_test.py | 132 +++++++++++++++++++++++++++++ tests/optim/hmm_viterbi_test.py | 132 +++++++++++++++++++++++++++++ 5 files changed, 269 insertions(+), 4 deletions(-) create mode 100644 tests/optim/hmm_general_em_test.py create mode 100644 tests/optim/hmm_viterbi_test.py diff --git a/pyproject.toml b/pyproject.toml index 51adc70a..6aea82d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dev = [ "pytest-xdist", "pytest-skip-slow", "torchvision", + "torchtext", "matplotlib" ] diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index 7994206f..ede8601f 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -213,7 +213,7 @@ def init_param_flows(self, flows_memory: float = 1.0): def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[Dict] = None, missing_mask: Optional[torch.Tensor] = None, _batch_first: bool = True, - _apply_missing_mask_only: bool = False): + _apply_missing_mask_only: bool = False, **kwargs): self._used_external_params = (params is not None) if params is None: @@ -300,7 +300,7 @@ def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[ raise NotImplementedError("CPU forward fn for input nodes is not implemented.") def backward(self, data: torch.Tensor, node_flows: torch.Tensor, - node_mars: torch.Tensor, params: Optional[Dict] = None): + node_mars: torch.Tensor, params: Optional[Dict] = None, **kwargs): """ data: [num_vars, B] node_flows: [num_nodes, B] diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 1236e043..1d4d901d 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -149,7 +149,7 @@ def __init__(self, nodes: Sequence[ProdNodes], global_nid_start: Optional[int] = self.partitioned_u_cids = FastParamList([nn.Parameter(tensor, requires_grad = False) for tensor in u_cids]) self.partitioned_parids = FastParamList([nn.Parameter(tensor, requires_grad = False) for tensor in parids]) - def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_backward: bool = False) -> None: + def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_backward: bool = False, **kwargs) -> None: """ Computes the forward pass of a product layer. If `block_size == 1`, it is equivalent to the following: ``` @@ -195,7 +195,7 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_back return None - def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor) -> None: + def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwargs) -> None: """ Computes the backward pass of a product layer: ``` diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py new file mode 100644 index 00000000..3d48933b --- /dev/null +++ b/tests/optim/hmm_general_em_test.py @@ -0,0 +1,132 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch tensors + train_data = torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +@pytest.mark.slow +def test_hmm_viterbi(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = -10000.0 + for epoch in range(1, 40 + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "GeneralLL", alpha = 1.2) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + assert best_valid_ll > -85.0 + + +if __name__ == "__main__": + test_hmm_viterbi() \ No newline at end of file diff --git a/tests/optim/hmm_viterbi_test.py b/tests/optim/hmm_viterbi_test.py new file mode 100644 index 00000000..7c703447 --- /dev/null +++ b/tests/optim/hmm_viterbi_test.py @@ -0,0 +1,132 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch tensors + train_data = torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +@pytest.mark.slow +def test_hmm_viterbi(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = -10000.0 + for epoch in range(1, 20 + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "MPE") + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + assert best_valid_ll > -90.0 + + +if __name__ == "__main__": + test_hmm_viterbi() \ No newline at end of file From e78318799bc3f0c5c801a64116358f2127ac20eb Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 04:10:10 +0800 Subject: [PATCH 15/53] fast runtest for `GeneralLL` with HMMs --- tests/optim/hmm_general_em_test.py | 103 +++++++++++++++++++++-------- 1 file changed, 76 insertions(+), 27 deletions(-) diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py index 3d48933b..6ca100d0 100644 --- a/tests/optim/hmm_general_em_test.py +++ b/tests/optim/hmm_general_em_test.py @@ -52,8 +52,49 @@ def load_penn_treebank(seq_length = 32): return train_data, valid_data, test_data +def train(pc, num_epochs, train_loader, valid_loader, device): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "GeneralLL", alpha = 1.2) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + @pytest.mark.slow -def test_hmm_viterbi(): +def test_hmm_general_ll(): device = torch.device("cuda:0") @@ -89,44 +130,52 @@ def test_hmm_viterbi(): pc = juice.compile(root_ns) pc.to(device) - best_valid_ll = -10000.0 - for epoch in range(1, 40 + 1): - t0 = time.time() - train_ll = 0.0 - for batch in train_loader: - x = batch[0].to(device) + best_valid_ll = train(pc, 40, train_loader, valid_loader, device) - lls = pc(x, propagation_alg = "GeneralLL", alpha = 1.2) - lls.mean().backward() - - train_ll += lls.mean().detach().cpu().numpy().item() + assert best_valid_ll > -85.0 - train_ll /= len(train_loader) - pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) +def test_hmm_general_ll_fast(): + + device = torch.device("cuda:0") - t1 = time.time() + seq_length = 32 - with torch.no_grad(): - valid_ll = 0.0 - for batch in valid_loader: - x = batch[0].to(device) + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) - lls = pc(x, propagation_alg = "LL") + vocab_size = train_data.max().item() + 1 - valid_ll += lls.mean().detach().cpu().numpy().item() + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) - valid_ll /= len(valid_loader) + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") - t2 = time.time() + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 64, + num_emits = vocab_size, + homogeneous = True + ) - print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + pc = juice.compile(root_ns) + pc.to(device) - if valid_ll > best_valid_ll: - best_valid_ll = valid_ll + best_valid_ll = train(pc, 10, train_loader, valid_loader, device) - assert best_valid_ll > -85.0 + assert best_valid_ll > -92.0 if __name__ == "__main__": - test_hmm_viterbi() \ No newline at end of file + test_hmm_general_ll() + test_hmm_general_ll_fast() \ No newline at end of file From d78cbf1cd1a4b0a2bcf0be42cc0f4c96540f154a Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 04:20:56 +0800 Subject: [PATCH 16/53] add fast runtest for `MPE` propagation method + fix tile size allocation for `MPE` in backward pass --- src/pyjuice/layer/sum_layer.py | 43 +++++++------- tests/optim/hmm_viterbi_test.py | 99 ++++++++++++++++++++++++--------- 2 files changed, 95 insertions(+), 47 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index eebb5d9b..7b9abcfd 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1077,25 +1077,22 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, grid = (triton.cdiv(batch_size, BLOCK_B), triton.cdiv(layer_n_nodes, BLOCK_SIZE_M)) - try: - self._fw_triton_sparse_kernel[grid]( - node_mars = node_mars, - element_mars = element_mars, - params = params, - nids = nids, - cids = cids, - pids = pids, - local_ids = local_ids, - batch_size = batch_size, - partial_eval = partial_eval, - num_edges = num_edges, - BLOCK_B = BLOCK_B, - BLOCK_SIZE_M = BLOCK_SIZE_M, - propagation_alg_id = propagation_alg_id, - **propagation_alg_kwargs - ) - except TypeError: - import pdb; pdb.set_trace() + self._fw_triton_sparse_kernel[grid]( + node_mars = node_mars, + element_mars = element_mars, + params = params, + nids = nids, + cids = cids, + pids = pids, + local_ids = local_ids, + batch_size = batch_size, + partial_eval = partial_eval, + num_edges = num_edges, + BLOCK_B = BLOCK_B, + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs + ) else: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -2165,9 +2162,9 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor if propagation_alg_id == 1: # The kernel will stall if the tile sizes are too large - TILE_SIZE_M = 16 - TILE_SIZE_K = 16 - TILE_SIZE_B = 16 + TILE_SIZE_M = min(TILE_SIZE_M, 16) + TILE_SIZE_K = min(TILE_SIZE_K, 16) + TILE_SIZE_B = min(TILE_SIZE_B, 16) B_NUM_TILES = batch_size // TILE_SIZE_B @@ -2231,6 +2228,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor **propagation_alg_kwargs ) + return None + def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, diff --git a/tests/optim/hmm_viterbi_test.py b/tests/optim/hmm_viterbi_test.py index 7c703447..748cf75f 100644 --- a/tests/optim/hmm_viterbi_test.py +++ b/tests/optim/hmm_viterbi_test.py @@ -52,6 +52,47 @@ def load_penn_treebank(seq_length = 32): return train_data, valid_data, test_data +def train(pc, num_epochs, train_loader, valid_loader, device): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "MPE") + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + @pytest.mark.slow def test_hmm_viterbi(): @@ -89,44 +130,52 @@ def test_hmm_viterbi(): pc = juice.compile(root_ns) pc.to(device) - best_valid_ll = -10000.0 - for epoch in range(1, 20 + 1): - t0 = time.time() - train_ll = 0.0 - for batch in train_loader: - x = batch[0].to(device) - - lls = pc(x, propagation_alg = "MPE") - lls.mean().backward() + best_valid_ll = train(pc, 20, train_loader, valid_loader, device) - train_ll += lls.mean().detach().cpu().numpy().item() + assert best_valid_ll > -90.0 - train_ll /= len(train_loader) - pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) +def test_hmm_viterbi_fast(): + + device = torch.device("cuda:0") - t1 = time.time() + seq_length = 32 - with torch.no_grad(): - valid_ll = 0.0 - for batch in valid_loader: - x = batch[0].to(device) + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) - lls = pc(x, propagation_alg = "LL") + vocab_size = train_data.max().item() + 1 - valid_ll += lls.mean().detach().cpu().numpy().item() + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) - valid_ll /= len(valid_loader) + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") - t2 = time.time() + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 64, + num_emits = vocab_size, + homogeneous = True + ) - print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + pc = juice.compile(root_ns) + pc.to(device) - if valid_ll > best_valid_ll: - best_valid_ll = valid_ll + best_valid_ll = train(pc, 5, train_loader, valid_loader, device) assert best_valid_ll > -90.0 if __name__ == "__main__": - test_hmm_viterbi() \ No newline at end of file + # test_hmm_viterbi() + test_hmm_viterbi_fast() \ No newline at end of file From c7caf3f5dc065d51cbf9047cf0d8a858fb0c5b53 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 04:39:34 +0800 Subject: [PATCH 17/53] speedup runtests --- tests/structures/pd_hclt_test.py | 4 ++-- tests/structures/rat_spn_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/structures/pd_hclt_test.py b/tests/structures/pd_hclt_test.py index f72c4fda..84a18e9c 100644 --- a/tests/structures/pd_hclt_test.py +++ b/tests/structures/pd_hclt_test.py @@ -148,7 +148,7 @@ def test_pd_hclt(): ns = juice.structures.PDHCLT( train_data.cuda(), data_shape = (28, 28), - num_latents = 128, + num_latents = 64, split_intervals = (4, 4), structure_type = "sum_dominated" ) @@ -177,7 +177,7 @@ def test_pd_hclt(): test_ll = evaluate(pc, test_loader) - assert test_ll > -680.0 + assert test_ll > -692.0 if __name__ == "__main__": diff --git a/tests/structures/rat_spn_test.py b/tests/structures/rat_spn_test.py index 27271c50..dc220c17 100644 --- a/tests/structures/rat_spn_test.py +++ b/tests/structures/rat_spn_test.py @@ -93,7 +93,7 @@ def test_rat_spn(): ns = juice.structures.RAT_SPN( num_vars = 28 * 28, - num_latents = 256, + num_latents = 64, depth = 5, num_repetitions = 4, num_pieces = 2 @@ -121,7 +121,7 @@ def test_rat_spn(): test_ll = evaluate(pc, test_loader) - assert test_ll > -1015 + assert test_ll > -1020 if __name__ == "__main__": From 9bb3bc91b72d26f46dad85f3310a579608681e86 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 11 Mar 2024 15:41:12 +0800 Subject: [PATCH 18/53] update optim tests --- tests/optim/hmm_em_test.py | 181 +++++++++++++++++++++++++++++ tests/optim/hmm_general_em_test.py | 47 +++++++- tests/optim/hmm_viterbi_test.py | 53 ++++++++- 3 files changed, 274 insertions(+), 7 deletions(-) create mode 100644 tests/optim/hmm_em_test.py diff --git a/tests/optim/hmm_em_test.py b/tests/optim/hmm_em_test.py new file mode 100644 index 00000000..dc6498fb --- /dev/null +++ b/tests/optim/hmm_em_test.py @@ -0,0 +1,181 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch tensors + train_data = torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +def train(pc, num_epochs, train_loader, valid_loader, device, propagation_alg, **kwargs): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = propagation_alg, **kwargs) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + +def test_hmm_em(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 20, train_loader, valid_loader, device, propagation_alg = "LL") + + assert best_valid_ll > -95.0 + + +@pytest.mark.slow +def test_hmm_em_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 200, train_loader, valid_loader, device, propagation_alg = "LL") + + assert best_valid_ll > -95.0 + + +if __name__ == "__main__": + test_hmm_em() + test_hmm_em_slow() \ No newline at end of file diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py index 6ca100d0..d71c10ba 100644 --- a/tests/optim/hmm_general_em_test.py +++ b/tests/optim/hmm_general_em_test.py @@ -135,6 +135,48 @@ def test_hmm_general_ll(): assert best_valid_ll > -85.0 +@pytest.mark.slow +def test_hmm_general_ll_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 300, train_loader, valid_loader, device) + + assert best_valid_ll > -85.0 + + def test_hmm_general_ll_fast(): device = torch.device("cuda:0") @@ -177,5 +219,6 @@ def test_hmm_general_ll_fast(): if __name__ == "__main__": - test_hmm_general_ll() - test_hmm_general_ll_fast() \ No newline at end of file + # test_hmm_general_ll() + # test_hmm_general_ll_fast() + test_hmm_general_ll_slow() \ No newline at end of file diff --git a/tests/optim/hmm_viterbi_test.py b/tests/optim/hmm_viterbi_test.py index 748cf75f..f265b82a 100644 --- a/tests/optim/hmm_viterbi_test.py +++ b/tests/optim/hmm_viterbi_test.py @@ -52,7 +52,7 @@ def load_penn_treebank(seq_length = 32): return train_data, valid_data, test_data -def train(pc, num_epochs, train_loader, valid_loader, device): +def train(pc, num_epochs, train_loader, valid_loader, device, propagation_alg, **kwargs): best_valid_ll = -10000.0 for epoch in range(1, num_epochs + 1): @@ -61,7 +61,7 @@ def train(pc, num_epochs, train_loader, valid_loader, device): for batch in train_loader: x = batch[0].to(device) - lls = pc(x, propagation_alg = "MPE") + lls = pc(x, propagation_alg = propagation_alg, **kwargs) lls.mean().backward() train_ll += lls.mean().detach().cpu().numpy().item() @@ -130,7 +130,49 @@ def test_hmm_viterbi(): pc = juice.compile(root_ns) pc.to(device) - best_valid_ll = train(pc, 20, train_loader, valid_loader, device) + best_valid_ll = train(pc, 20, train_loader, valid_loader, device, propagation_alg = "MPE") + + assert best_valid_ll > -90.0 + + +@pytest.mark.slow +def test_hmm_viterbi_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 200, train_loader, valid_loader, device, propagation_alg = "MPE") assert best_valid_ll > -90.0 @@ -171,11 +213,12 @@ def test_hmm_viterbi_fast(): pc = juice.compile(root_ns) pc.to(device) - best_valid_ll = train(pc, 5, train_loader, valid_loader, device) + best_valid_ll = train(pc, 5, train_loader, valid_loader, device, propagation_alg = "MPE") assert best_valid_ll > -90.0 if __name__ == "__main__": # test_hmm_viterbi() - test_hmm_viterbi_fast() \ No newline at end of file + test_hmm_viterbi_slow() + # test_hmm_viterbi_fast() \ No newline at end of file From b3ac66a0f88f3956089254c7ab30c5c6c560103a Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 12 Mar 2024 17:20:00 +0800 Subject: [PATCH 19/53] use `bfloat16` in forward pass instead of `float16` --- src/pyjuice/layer/sum_layer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 7b9abcfd..daccab0d 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -494,9 +494,9 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c if use_fp16 == 1: # Built-in matmul kernel of triton + float16 - epars_fp16 = (epars * (2**12)).to(tl.float16) + epars_fp16 = epars.to(tl.float16) emars_fp16 = emars_sub.to(tl.float16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) else: # Built-in matmul kernel of triton + float32 nmars = tl.dot(epars, emars_sub) @@ -610,9 +610,9 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, if use_fp16 == 1: # Simulated matmul kernel + float16 - epars = (epars * (2**4)).to(tl.float16) - emars_sub = emars_sub.to(tl.float16) - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) + epars = epars.to(tl.bfloat16) + emars_sub = emars_sub.to(tl.bfloat16) + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) else: # Simulated matmul kernel + float32 nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) From 3c763b1fbbe7acadf68c83f99680e70789e715ae Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 12 Mar 2024 17:22:40 +0800 Subject: [PATCH 20/53] use `bfloat16` instead of `float16` in forward pas --- src/pyjuice/layer/sum_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index daccab0d..5831242c 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -494,8 +494,8 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c if use_fp16 == 1: # Built-in matmul kernel of triton + float16 - epars_fp16 = epars.to(tl.float16) - emars_fp16 = emars_sub.to(tl.float16) + epars_fp16 = epars.to(tl.bfloat16) + emars_fp16 = emars_sub.to(tl.bfloat16) nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) else: # Built-in matmul kernel of triton + float32 From 16c41e0ee6c28e329630da78cbb92b7e3a49f0ce Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 12 Mar 2024 22:41:46 +0800 Subject: [PATCH 21/53] fix general EM parameter update kernels --- src/pyjuice/layer/sum_layer.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 5831242c..b3e15a12 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1956,11 +1956,14 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para else: - if propagation_alg_id == 2: - emars *= alpha + # if propagation_alg_id == 2: + # emars *= alpha if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + + log_n_fdm += (1.0 - alpha) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] @@ -1969,7 +1972,7 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) if propagation_alg_id == 2: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) log_n_fdm_max = tl.max(log_n_fdm, axis = 0) n_fdm_sub = tl.where(log_n_fdm_max[None,:] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[None,:]), 0.0) @@ -1999,8 +2002,8 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] epars = tl.load(params + epars_offsets) - if propagation_alg_id == 2: - epars = tl.exp(tl.log(epars) * alpha) + # if propagation_alg_id == 2: + # epars = tl.exp(tl.log(epars) * alpha) if propagation_alg_id != 1: pflows = acc * epars @@ -2066,11 +2069,14 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars else: - if propagation_alg_id == 2: - emars *= alpha + # if propagation_alg_id == 2: + # emars *= alpha if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] + + log_n_fdm += (1.0 - alpha) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] @@ -2079,7 +2085,7 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) if propagation_alg_id == 2: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) log_n_fdm_max = tl.max(log_n_fdm, axis = 1) n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) @@ -2106,8 +2112,8 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] epars = tl.load(params + epars_offsets) - if propagation_alg_id == 2: - epars = tl.exp(tl.log(epars) * alpha) + # if propagation_alg_id == 2: + # epars = tl.exp(tl.log(epars) * alpha) if propagation_alg_id != 1: pflows = acc * epars From ae66bc6c609f91bce5a40acaf04e56ec5faf1b55 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 12 Mar 2024 22:41:55 +0800 Subject: [PATCH 22/53] temp commit --- src/pyjuice/model/tensorcircuit.py | 9 +++++++++ src/pyjuice/optim/optim.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index c3995dd5..f26390fb 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -407,6 +407,15 @@ def _run_inner_layers(): else: return None + def forward_ll(self, *args, **kwargs): + self.forward(*args, propagation_alg = "LL", **kwargs) + + def forward_mpe(self, *args, **kwargs): + self.forward(*args, propagation_alg = "MPE", **kwargs) + + def forward_general_ll(self, *args, alpha: float = 1.0, **kwargs): + self.forward(*args, propagation_alg = "GeneralLL", **kwargs) + def mini_batch_em(self, step_size: float, pseudocount: float = 0.0, keep_zero_params: bool = False): """ Perform an EM parameter update step using the accumulated parameter flows. diff --git a/src/pyjuice/optim/optim.py b/src/pyjuice/optim/optim.py index 16258395..3c89cb19 100644 --- a/src/pyjuice/optim/optim.py +++ b/src/pyjuice/optim/optim.py @@ -8,10 +8,10 @@ class CircuitOptimizer(): - SUPPORTED_OPTIM_METHODS = ["EM"] + SUPPORTED_OPTIM_METHODS = ["EM", "Viterbi", "GeneralEM"] def __init__(self, pc: TensorCircuit, base_optimizer: Optional[Optimizer] = None, method: str = "EM", lr: float = 0.1, - pseudocount: float = 0.1): + pseudocount: float = 0.1, **kwargs): self.pc = pc From 6299bdcd144b1eedfe372d7a5cc46fece5eeeba9 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 12 Mar 2024 23:54:03 +0800 Subject: [PATCH 23/53] fix runtests except `hclt_correctness_test` --- src/pyjuice/layer/sum_layer.py | 40 ++++++++++++++--------- tests/layer/propagation_algs_test.py | 4 +-- tests/optim/hmm_general_em_test.py | 4 +-- tests/structures/hclt_correctness_test.py | 8 +++-- 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index b3e15a12..7ee7f3a2 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -494,9 +494,9 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c if use_fp16 == 1: # Built-in matmul kernel of triton + float16 - epars_fp16 = epars.to(tl.bfloat16) - emars_fp16 = emars_sub.to(tl.bfloat16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) + epars_fp16 = (epars * (2**12)).to(tl.float16) + emars_fp16 = emars_sub.to(tl.float16) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) else: # Built-in matmul kernel of triton + float32 nmars = tl.dot(epars, emars_sub) @@ -610,9 +610,9 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, if use_fp16 == 1: # Simulated matmul kernel + float16 - epars = epars.to(tl.bfloat16) - emars_sub = emars_sub.to(tl.bfloat16) - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) + epars = (epars * (2**4)).to(tl.float16) + emars_sub = emars_sub.to(tl.float16) + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) else: # Simulated matmul kernel + float32 nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) @@ -1961,9 +1961,10 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] - nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - log_n_fdm += (1.0 - alpha) * nmars + if propagation_alg_id == 2: + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + log_n_fdm += (alpha - 1.0) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] @@ -2074,9 +2075,10 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] - nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - log_n_fdm += (1.0 - alpha) * nmars + if propagation_alg_id == 2: + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] + log_n_fdm += (alpha - 1.0) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] @@ -2625,12 +2627,18 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa else: - if propagation_alg_id == 2: - emars *= alpha + # if propagation_alg_id == 2: + # emars *= alpha if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] - pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) + + if propagation_alg_id == 0: + pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) + + if propagation_alg_id == 2: + nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] + pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:] + (alpha - 1.0) * nmars[None,:]), axis = 1) else: nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] @@ -2639,7 +2647,7 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) if propagation_alg_id == 2: - pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:] * alpha), axis = 1) + pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) acc += pflows @@ -2653,8 +2661,8 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa epars_ptr = params + par_start + tile_id epars = tl.load(epars_ptr) # [BLOCK_K] - if propagation_alg_id == 2: - epars = tl.exp(tl.log(epars) * alpha) + # if propagation_alg_id == 2: + # epars = tl.exp(tl.log(epars) * alpha) parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_ptr = param_flows + parflow_start + tile_id diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py index c312a0d2..db90f9bf 100644 --- a/tests/layer/propagation_algs_test.py +++ b/tests/layer/propagation_algs_test.py @@ -228,11 +228,11 @@ def test_general_ll_prop(): epars = params[layer.partitioned_pids[0][j,:]+i] nmars = node_mars[(j+1)*block_size+i,:].exp() nflows = origin_node_flows[(j+1)*block_size+i,:] - pflows = epars ** alpha * (nflows[None,:] * emars ** alpha / nmars[None,:] ** alpha).sum(dim = 1) + pflows = epars * (nflows[None,:] * emars / nmars[None,:]).sum(dim = 1) my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows - assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + assert torch.all(torch.abs(my_pflows - param_flows) < 4e-3) def test_mpe_prop(): diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py index d71c10ba..008ab4f9 100644 --- a/tests/optim/hmm_general_em_test.py +++ b/tests/optim/hmm_general_em_test.py @@ -219,6 +219,6 @@ def test_hmm_general_ll_fast(): if __name__ == "__main__": - # test_hmm_general_ll() - # test_hmm_general_ll_fast() + test_hmm_general_ll() + test_hmm_general_ll_fast() test_hmm_general_ll_slow() \ No newline at end of file diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index 2b7fb03e..fa284ffd 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -295,8 +295,12 @@ def test_hclt_backward(): sid, eid = ns._output_ind_range - if len(ns.scope) > 2: - assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-3) + # if len(ns.scope) > 2: + # if not torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-3): + # import pdb; pdb.set_trace() + # assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-3) + + # ns2flows[ns] = node_flows[sid:eid,:] nmars = node_mars[sid:eid,:] From 2787361e186eb5a443bacd6ab2384ba52d14db07 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 04:17:02 +0800 Subject: [PATCH 24/53] stabilize hclt correctness test --- tests/structures/hclt_correctness_test.py | 83 ++++++++++++++++++++++- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index fa284ffd..3886a6e5 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -191,8 +191,6 @@ def test_hclt_single_layer_backward(): def test_hclt_backward(): - torch.manual_seed(18329) - device = torch.device("cuda:0") train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) @@ -232,9 +230,18 @@ def test_hclt_backward(): node_mars = pc.node_mars node_flows = pc.node_flows + temp_node_mars = pc.node_mars.clone() + temp_node_flows = pc.node_flows.clone() + temp_element_mars = pc.element_mars.clone() + temp_element_flows = pc.element_flows.clone() + temp_params = pc.params + temp_param_flows = pc.param_flows.clone() + ns2flows = dict() ns2flows[root_ns] = torch.ones([1, batch_size], device = device) + gt_ch_flows = dict() + ch2par = dict() for ns in root_ns: for cs in ns.chs: @@ -279,8 +286,14 @@ def test_hclt_backward(): ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows + gt_ch_flows[cs] = eflows.detach().clone() + elif ns.is_prod(): nflows = ns2flows[ns] + gt_flows = gt_ch_flows[ns] + + assert torch.all(torch.abs(gt_flows - nflows) < 1e-4) + for cs in ns.chs: if cs not in ns2flows: ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) @@ -301,9 +314,15 @@ def test_hclt_backward(): # assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-3) # ns2flows[ns] = node_flows[sid:eid,:] + # print(">>>>>>", torch.abs(nflows - node_flows[sid:eid,:]).max()) + + nflows = node_flows[sid:eid,:] nmars = node_mars[sid:eid,:] + ch_eflows = [] + ch_pflows = [] + for i, cs in enumerate(ns.chs): params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) @@ -322,12 +341,71 @@ def test_hclt_backward(): eflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) pflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) + # log_n_fdm = nflows.log() - nmars + # log_n_fdm_max = log_n_fdm.max(dim = 0).values + # n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + + # eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() + + # From `pc` + pc_eflows = (node_flows[sid:eid,:][None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + pc_pflows = (node_flows[sid:eid,:][None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) + + # log_n_fdm = node_flows[sid:eid,:].log() - nmars + # log_n_fdm_max = log_n_fdm.max(dim = 0).values + # n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + + # pc_eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() + + # print(torch.abs(eflows - pc_eflows).max()) + + ch_eflows.append(eflows) + ch_pflows.append(pflows) + assert torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size) if cs not in ns2flows: ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows + ## Run the actual layer ## + + curr_layer_id = -1 + curr_layer = None + for layer_id in range(1, len(pc.inner_layer_groups), 2): + layer = pc.inner_layer_groups[layer_id][0] + if ns in layer.nodes: + curr_layer_id = layer_id + curr_layer = layer + + assert curr_layer is not None + + nsid, neid = ns._output_ind_range + + temp_node_flows[nsid:neid,:] = nflows + temp_param_flows[:] = 0.0 + + pc.inner_layer_groups[curr_layer_id - 1].forward(temp_node_mars, temp_element_mars, _for_backward = True) + + curr_layer.backward(temp_node_flows, temp_element_flows, temp_node_mars, temp_element_mars, temp_params, + param_flows = temp_param_flows, allow_modify_flows = False, propagation_alg = "LL") + + pfsid, pfeid = ns._param_flow_range + + for i, cs in enumerate(ns.chs): + eflows = ch_eflows[i] + pflows = ch_pflows[i] + + csid, ceid = cs._output_ind_range + + # print("value", torch.abs(eflows - temp_element_flows[csid:ceid,:]).max()) + + assert torch.all(torch.abs(eflows - temp_element_flows[csid:ceid,:]) < 1e-3) + assert torch.all(torch.abs(temp_param_flows[pfsid:pfeid].reshape(num_latents, num_latents) - pflows.permute(1, 0)) < batch_size * 1e-4) + + assert cs not in gt_ch_flows + gt_ch_flows[cs] = eflows.detach().clone() + def test_hclt_em(): @@ -411,7 +489,6 @@ def test_hclt_em(): if __name__ == "__main__": - # torch.manual_seed(320942) test_hclt_forward() test_hclt_single_layer_backward() test_hclt_backward() From 80ef1a2fc408fa80158678ee30d6b23d806580e0 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 04:18:49 +0800 Subject: [PATCH 25/53] improve sum layer ele backward pass (compute on arithmetic space) --- src/pyjuice/layer/sum_layer.py | 96 ++++++++-------------------------- 1 file changed, 22 insertions(+), 74 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 7ee7f3a2..ca9aa623 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1529,16 +1529,15 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid # Initialize pointers to `element_mars` (only when using MPE propagation method) - if propagation_alg_id == 1: - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + + if propagation_alg_id == 2: + emars *= alpha # Inner loop - if propagation_alg_id != 1: - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") - else: - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] @@ -1575,19 +1574,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele else: partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) - acc = tl.where(log_n_fdm_max == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc - ) - ) - # neginf_flag = (log_n_fdm_max == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - # tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + acc += partial_flows * tl.exp(emars + log_n_fdm_max) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1601,23 +1588,9 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele nflows_ptr += parids_inc[:,None] * batch_size parids_inc_ptr += ptr_inc_step - # Initialize pointers to `element_mars` (only when NOT using MPE propagation method) - if propagation_alg_id != 1: - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] - - if propagation_alg_id == 2: - emars *= alpha - - if propagation_alg_id != 1: - eflows = tl.exp(acc + emars) - else: - eflows = acc - # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - tl.store(element_flows + offs_elemfs, eflows, mask = mask_batch[None,:]) + tl.store(element_flows + offs_elemfs, acc, mask = mask_batch[None,:]) @staticmethod # @triton.jit @@ -1668,16 +1641,15 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid # Initialize pointers to `element_mars` - if propagation_alg_id == 1: - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + + if propagation_alg_id == 2: + emars *= alpha # Inner loop - if propagation_alg_id != 1: - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") - else: - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] @@ -1713,19 +1685,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) - acc = tl.where(log_n_fdm_max[None,:] == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max[None,:] > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc - ) - ) - # neginf_flag = (log_n_fdm_max[None,:] == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max[None,:] > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - # tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1739,23 +1699,9 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar nflows_ptr += parids_inc[None,:] * batch_size parids_inc_ptr += ptr_inc_step - # Initialize pointers to `element_mars` - if propagation_alg_id != 1: - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] - - if propagation_alg_id == 2: - emars *= alpha - - if propagation_alg_id != 1: - eflows = tl.exp(acc + emars) - else: - eflows = acc - # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - tl.store(element_flows + offs_elemfs, eflows, mask = mask_batch[None,:]) + tl.store(element_flows + offs_elemfs, acc, mask = mask_batch[None,:]) def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, @@ -1959,15 +1905,17 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para # if propagation_alg_id == 2: # emars *= alpha + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] if propagation_alg_id == 2: - nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + # nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] log_n_fdm += (alpha - 1.0) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + # nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] if propagation_alg_id == 0: log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) From 42c80d59121301d53cec04ce151c4e2ee2f8c4b6 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 16:33:31 +0800 Subject: [PATCH 26/53] ensure nodes in a layer are different --- src/pyjuice/layer/input_layer.py | 2 ++ src/pyjuice/layer/prod_layer.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index ede8601f..41ce3c8d 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -30,6 +30,8 @@ def __init__(self, nodes: Sequence[InputNodes], cum_nodes: int = 0, pc_num_vars: in gradient accumulation. """ + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." + nn.Module.__init__(self) Layer.__init__(self, nodes, disable_block_size_check = True) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 1d4d901d..1e7b9534 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -27,6 +27,7 @@ def __init__(self, nodes: Sequence[ProdNodes], global_nid_start: Optional[int] = nn.Module.__init__(self) assert len(nodes) > 0, "No input node." + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." use_block_sparse_edges = True for nid in range(0, len(nodes)): From 4741c3f531705c35e4b63c92cfb61511adcc8eb5 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 16:34:02 +0800 Subject: [PATCH 27/53] fix layering when adding new nodes --- src/pyjuice/layer/sum_layer.py | 1 + src/pyjuice/model/tensorcircuit.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index ca9aa623..f00660cc 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -40,6 +40,7 @@ def __init__(self, nodes: Sequence[SumNodes], global_nid_start: int, nn.Module.__init__(self) assert len(nodes) > 0, "No input node." + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." self.nodes = nodes diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index f26390fb..a8848677 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -847,12 +847,12 @@ def dfs(ns: CircuitNodes): ns.chs[idx] = pass_prod_ns depth2nodes[cs_depth]["sum"].append(pass_sum_ns) - depth2nodes[depth]["prod"].append(pass_prod_ns) nodes2depth[pass_sum_ns] = cs_depth nodes2depth[pass_prod_ns] = depth depth2nodes[depth]["sum"].append(ns) + if ns.block_size > max_node_block_size: max_node_block_size = ns.block_size elif ns.is_prod(): @@ -872,6 +872,7 @@ def dfs(ns: CircuitNodes): assert pns2layer[id(cs)] == layer, "Disallowed circumstance: a product node requested by sum nodes at different layers." else: depth2nodes[layer]["prod"].append(cs) + pns2layer[id(cs)] = layer return depth2nodes, num_layers, max_node_block_size, max_ele_block_size From 4de40523b0dd81ea226efa201e2b5e9b2fcbfa05 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 16:34:18 +0800 Subject: [PATCH 28/53] more runtests --- tests/model/non_sd_pcs_test.py | 42 ++++- tests/structures/hclt_correctness_test.py | 14 +- tests/structures/hmm_correctness_test.py | 214 +++++++++++++++++++++- 3 files changed, 263 insertions(+), 7 deletions(-) diff --git a/tests/model/non_sd_pcs_test.py b/tests/model/non_sd_pcs_test.py index 22b570ff..b39578d0 100644 --- a/tests/model/non_sd_pcs_test.py +++ b/tests/model/non_sd_pcs_test.py @@ -101,6 +101,46 @@ def test_non_sd(): assert torch.all(torch.abs(bk78 - pc.node_flows[7:9,0].cpu()) < 1e-4) +def test_non_sd_generalized_em(): + ni1 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni4 = inputs(3, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + + np12 = multiply(ni1, ni2) + np23 = multiply(ni2, ni3) + np34 = multiply(ni3, ni4) + + ns12 = summate(np12, num_nodes = 2) + ns23 = summate(np23, num_nodes = 2) + ns34 = summate(np34, num_nodes = 2) + + np1 = multiply(ns12, ns34) + np2 = multiply(ni1, ns23, ni4) + np3 = multiply(ni1, ni2, ns34) + + ns = summate(np1, np2, np3, num_nodes = 1) + + pc = TensorCircuit(ns) + + device = torch.device("cuda:0") + pc.to(device) + + data = torch.randint(0, 2, [16, 4]).to(device) + + alpha = 2.0 + + lls = pc(data, propagation_alg = "GeneralLL", alpha = alpha) + + pc.backward(data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_parameters() + + import pdb; pdb.set_trace() + + if __name__ == "__main__": torch.manual_seed(129) - test_non_sd() \ No newline at end of file + # test_non_sd() + test_non_sd_generalized_em() \ No newline at end of file diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index 3886a6e5..aac42071 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -341,11 +341,14 @@ def test_hclt_backward(): eflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) pflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) - # log_n_fdm = nflows.log() - nmars - # log_n_fdm_max = log_n_fdm.max(dim = 0).values - # n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + log_n_fdm = nflows.log() - nmars + log_n_fdm_max = log_n_fdm.max(dim = 0).values + n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + + eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() - # eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() + scaled_emars = (emars + log_n_fdm_max[None,:]).exp() + pflows_prim = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params # From `pc` pc_eflows = (node_flows[sid:eid,:][None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) @@ -357,7 +360,8 @@ def test_hclt_backward(): # pc_eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() - # print(torch.abs(eflows - pc_eflows).max()) + # print(torch.abs(eflows - eflows_prim).max()) + # print(torch.abs(pflows - pflows_prim).max()) ch_eflows.append(eflows) ch_pflows.append(pflows) diff --git a/tests/structures/hmm_correctness_test.py b/tests/structures/hmm_correctness_test.py index 6c1ba89f..4a9df879 100644 --- a/tests/structures/hmm_correctness_test.py +++ b/tests/structures/hmm_correctness_test.py @@ -204,6 +204,218 @@ def test_hmm_forward_backward(): assert torch.all(torch.abs(pflows - cum_pflows) < 1e-4) +def test_hmm_forward_backward_with_generalized_em(): + + device = torch.device("cuda:0") + + seq_length = 16 + vocab_size = 1023 + batch_size = 32 + + num_node_blocks = 4 # 4096 // 32 # 4 + block_size = 1024 # 32 # 1024 + num_latents = block_size * num_node_blocks + + with juice.set_block_size(block_size = block_size): + ns_input = juice.inputs(seq_length - 1, num_node_blocks = num_node_blocks, + dist = dists.Categorical(num_cats = vocab_size)) + + ns_sum = None + curr_zs = ns_input + for var in range(seq_length - 2, -1, -1): + curr_xs = ns_input.duplicate(var, tie_params = True) + + if ns_sum is None: + ns = juice.summate( + curr_zs, num_node_blocks = num_node_blocks) + ns_sum = ns + else: + ns = ns_sum.duplicate(curr_zs, tie_params=True) + + curr_zs = juice.multiply(curr_xs, ns) + + root_ns = juice.summate(curr_zs, num_node_blocks = 1, block_size = 1) + + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + pc.to(device) + + data = torch.randint(0, vocab_size, (batch_size, seq_length)).to(device) + data_cpu = data.cpu() + + ## Forward tests ## + + alpha = 1.2 + + lls = pc(data, propagation_alg = "GeneralLL", alpha = alpha) + + ns2mars = dict() + + node_mars = pc.node_mars.detach().cpu() + + with torch.no_grad(): + for ns in root_ns: + if ns.is_input(): + v = ns.scope.to_list()[0] + params = ns.get_source_ns()._params.reshape(num_latents, vocab_size) + + mars = params[:,data_cpu[:,v]].log() + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(mars - node_mars[sid:eid,:]) < 1e-4) + + ns2mars[ns] = mars + + elif ns.is_prod(): + mars = torch.zeros([num_latents, batch_size]) + for cs in ns.chs: + mars += ns2mars[cs] + + ns2mars[ns] = mars + + elif ns.is_sum() and ns != root_ns: + emars = torch.cat([ns2mars[cs] for cs in ns.chs], dim = 0) + params = ns.get_source_ns()._params.reshape(num_node_blocks, num_node_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = params.reshape(num_latents, num_latents * ns.num_chs) + + emars = emars * alpha + params = (params.log() * alpha).exp() + + emars_max = torch.max(emars, dim = 0).values[None,:] + emars = (emars - emars_max).exp() + + nmars = torch.matmul(params, emars) + nmars = nmars.log() + emars_max + + nmars *= (1.0 / alpha) + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 1e-3) + + ns2mars[ns] = nmars + + else: + assert ns == root_ns + + emars = torch.cat([ns2mars[cs] for cs in ns.chs], dim = 0) + params = ns._params.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = params.reshape(1, num_latents * ns.num_chs) + + emars = emars * alpha + params = (params.log() * alpha).exp() + + emars_max = torch.max(emars, dim = 0).values[None,:] + emars = (emars - emars_max).exp() + + nmars = torch.matmul(params, emars) + nmars = nmars.log() + emars_max + + nmars *= (1.0 / alpha) + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 1e-3) + + ## Backward tests ## + + pc.backward(data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_param_flows() + + node_mars = pc.node_mars.cpu() + node_flows = pc.node_flows.cpu() + + ns2flows = dict() + ns2flows[root_ns] = torch.ones([1, batch_size]) + + cum_pflows = 0.0 + + with torch.no_grad(): + for ns in root_ns(reverse = True): + if ns == root_ns: + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(node_flows[sid:eid,:] - 1.0) < 1e-4) + + nflows = ns2flows[ns] + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = params[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(1, num_latents) + + param_flows = ns._param_flows.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + param_flows = param_flows[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(1, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size]) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = nflows * (params.log() * alpha).exp().permute(1, 0) * ((emars - nmars) * alpha).exp() + pflows = (nflows * params.permute(1, 0) * (emars - nmars).exp()).sum(dim = 1) + + assert torch.all(torch.abs(pflows - param_flows[0,:]) < 1e-4 * batch_size) + + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += eflows + + elif ns.is_prod(): + nflows = ns2flows[ns] + for cs in ns.chs: + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += nflows + + elif ns.is_sum(): + + nflows = ns2flows[ns] + + sid, eid = ns._output_ind_range + + assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-5) + + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns.get_source_ns()._params.reshape(num_node_blocks, num_node_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = params[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(num_latents, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size]) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + log_n_fdm = nflows.log() - nmars * alpha + log_n_fdm_max = log_n_fdm.max(dim = 0).values + n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + + eflows = torch.matmul((params.log() * alpha).exp().permute(1, 0), n_fdm_sub) * (emars * alpha + log_n_fdm_max[None,:]).exp() + + log_n_fdm = nflows.log() - nmars + log_n_fdm_max = log_n_fdm.max(dim = 0).values + scaled_emars = (emars + log_n_fdm_max[None,:]).exp() + pflows = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params + + cum_pflows = cum_pflows + pflows + + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += eflows + + pflows = ns_sum._param_flows.reshape(num_node_blocks, num_node_blocks, block_size, block_size).permute( + 0, 2, 1, 3).flatten(2, 3).flatten(0, 1) + assert torch.all(torch.abs(pflows - cum_pflows) < 1e-4) + + if __name__ == "__main__": - torch.manual_seed(23289) test_hmm_forward_backward() + test_hmm_forward_backward_with_generalized_em() From d92f3427e66a25c64594478532c27cd9f2821f04 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 17:21:43 +0800 Subject: [PATCH 29/53] fix runtests --- tests/model/non_sd_pcs_test.py | 4 +- tests/optim/hmm_general_em_test.py | 4 +- tests/structures/hclt_correctness_test.py | 97 +++++++++++++++++++++++ 3 files changed, 100 insertions(+), 5 deletions(-) diff --git a/tests/model/non_sd_pcs_test.py b/tests/model/non_sd_pcs_test.py index b39578d0..03cac563 100644 --- a/tests/model/non_sd_pcs_test.py +++ b/tests/model/non_sd_pcs_test.py @@ -137,10 +137,8 @@ def test_non_sd_generalized_em(): pc.update_parameters() - import pdb; pdb.set_trace() - if __name__ == "__main__": torch.manual_seed(129) - # test_non_sd() + test_non_sd() test_non_sd_generalized_em() \ No newline at end of file diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py index 008ab4f9..210ceed5 100644 --- a/tests/optim/hmm_general_em_test.py +++ b/tests/optim/hmm_general_em_test.py @@ -213,12 +213,12 @@ def test_hmm_general_ll_fast(): pc = juice.compile(root_ns) pc.to(device) - best_valid_ll = train(pc, 10, train_loader, valid_loader, device) + best_valid_ll = train(pc, 20, train_loader, valid_loader, device) assert best_valid_ll > -92.0 if __name__ == "__main__": - test_hmm_general_ll() + # test_hmm_general_ll() test_hmm_general_ll_fast() test_hmm_general_ll_slow() \ No newline at end of file diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index aac42071..ccc9df2f 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -189,6 +189,102 @@ def test_hclt_single_layer_backward(): assert torch.all(torch.abs(fpars - pflows) < 3e-4 * batch_size) +def test_hclt_single_layer_backward_general_em(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.cpu().long() + batch_size = batch_data.size(0) + + alpha = 2.0 + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data, propagation_alg = "GeneralLL", alpha = alpha) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_param_flows() + + for layer_id in range(1, len(pc.inner_layer_groups) - 2, 2): + + node_mars = pc.node_mars.clone() + node_flows = pc.node_flows.clone() + element_mars = pc.element_mars.clone() + element_flows = pc.element_flows.clone() + params = pc.params.clone() + param_flows = pc.param_flows.clone().zero_() + + my_layer = pc.inner_layer_groups[layer_id][0] + previous_layer = pc.inner_layer_groups[layer_id-1][0] + + previous_layer.forward(node_mars, element_mars, _for_backward = True) + + my_layer.backward(node_flows, element_flows, node_mars, element_mars, params, + param_flows = param_flows, allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + chids = my_layer.partitioned_chids[0] + parids = my_layer.partitioned_parids[0] + parpids = my_layer.partitioned_parpids[0] + + nids = my_layer.partitioned_nids[0] + cids = my_layer.partitioned_cids[0] + pids = my_layer.partitioned_pids[0] + pfids = my_layer.partitioned_pfids[0] + + for i in range(chids.size(0)): + eflows = torch.zeros([block_size, batch_size], dtype = torch.float32, device = device) + + for j in range(parids.size(1)): + nflows = node_flows[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[chids[i]:chids[i]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[parpids[i,j]:parpids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + curr_eflows = (nflows[None,:,:] * ((epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]) * alpha).exp()).sum(dim = 1) + eflows += curr_eflows + + assert torch.all(torch.abs(eflows - element_flows[chids[i]:chids[i]+block_size,:]) < 1e-3) + + for i in range(nids.size(0)): + for j in range(0, cids.size(1), block_size): + nflows = node_flows[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[cids[i,j]:cids[i,j]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[pids[i,j]:pids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + pflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2) + + assert torch.all(torch.abs(fpars - pflows) < 3e-4 * batch_size) + + def test_hclt_backward(): device = torch.device("cuda:0") @@ -497,3 +593,4 @@ def test_hclt_em(): test_hclt_single_layer_backward() test_hclt_backward() test_hclt_em() + test_hclt_single_layer_backward_general_em() From 5b34c303ebf94c47f0d8e75f904b1215acf51797 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 13 Mar 2024 19:44:51 +0800 Subject: [PATCH 30/53] improve numerical stability of forward pass --- src/pyjuice/layer/sum_layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index f00660cc..76524f42 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -503,7 +503,7 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c nmars = tl.dot(epars, emars_sub) acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-12) + emars_max, tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc ) @@ -619,7 +619,7 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-12) + emars_max, tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc ) @@ -729,7 +729,7 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) acc = tl.where(emars_max[None,:] > acc, - tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], + tl.log(nmars + tl.exp(acc - emars_max[None,:]) + 1e-12) + emars_max[None,:], tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc ) From 9f13826c65dc0a49a6371c960cc428d8bd38995c Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 06:16:29 +0800 Subject: [PATCH 31/53] temporarily fix conditional query --- src/pyjuice/queries/conditional.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/pyjuice/queries/conditional.py b/src/pyjuice/queries/conditional.py index 43ae22f8..9b9a833a 100644 --- a/src/pyjuice/queries/conditional.py +++ b/src/pyjuice/queries/conditional.py @@ -114,7 +114,7 @@ def _categorical_forward(layer, inputs: torch.Tensor, node_mars: torch.Tensor, @triton.jit def _categorical_backward_kernel(cat_probs_ptr, node_flows_ptr, local_ids_ptr, rev_vars_mapping_ptr, vids_ptr, psids_ptr, node_nchs_ptr, params_ptr, sid, eid, num_target_nodes, batch_size: tl.constexpr, - num_cats: tl.constexpr, BLOCK_SIZE: tl.constexpr): + num_cats: tl.constexpr, partial_eval: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) block_start = pid * BLOCK_SIZE @@ -123,7 +123,10 @@ def _categorical_backward_kernel(cat_probs_ptr, node_flows_ptr, local_ids_ptr, r # Get node offsets and batch offsets local_offsets = (offsets // batch_size) - local_node_offsets = tl.load(local_ids_ptr + local_offsets, mask = mask, other = 0) + if partial_eval == 1: + local_node_offsets = tl.load(local_ids_ptr + local_offsets, mask = mask, other = 0) + else: + local_node_offsets = local_offsets batch_offsets = (offsets % batch_size) global_node_offsets = local_node_offsets + sid @@ -182,20 +185,27 @@ def _categorical_backward(layer, inputs: torch.Tensor, node_flows: torch.Tensor, cat_probs = torch.zeros([num_target_vars * num_cats * batch_size], dtype = torch.float32, device = node_flows.device) - local_ids = layer.enable_partial_evaluation(bk_scopes = target_vars, return_ids = True).to(node_flows.device) - num_target_nodes = local_ids.size(0) + if len(target_vars) < num_vars: + local_ids = layer.enable_partial_evaluation(bk_scopes = target_vars, return_ids = True).to(node_flows.device) + num_target_nodes = local_ids.size(0) + partial_eval = 1 + else: + local_ids = None + num_target_nodes = eid - sid + partial_eval = 0 node_nchs = layer.metadata[layer.s_mids] grid = lambda meta: (triton.cdiv(num_target_nodes * batch_size, meta['BLOCK_SIZE']),) + _categorical_backward_kernel[grid]( cat_probs, node_flows, local_ids, rev_vars_mapping, layer.vids, layer.s_pids, node_nchs, layer.params, - sid, eid, num_target_nodes, batch_size, num_cats, BLOCK_SIZE = 512 + sid, eid, num_target_nodes, batch_size, num_cats, partial_eval = partial_eval, BLOCK_SIZE = 512 ) cat_probs = cat_probs.reshape(num_target_vars, num_cats, batch_size) - cat_probs /= cat_probs.sum(dim = 1, keepdim = True) + cat_probs /= (cat_probs.sum(dim = 1, keepdim = True) + 1e-12) cat_probs = cat_probs.permute(2, 0, 1) return cat_probs From 5b58875585d7afe11be7be0e7a235b83833a809d Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 06:16:53 +0800 Subject: [PATCH 32/53] change eps to 1e-24 --- src/pyjuice/layer/sum_layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 76524f42..dfacf735 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -503,7 +503,7 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c nmars = tl.dot(epars, emars_sub) acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max) + 1e-12) + emars_max, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max, tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc ) @@ -619,7 +619,7 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max) + 1e-12) + emars_max, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max, tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc ) @@ -729,7 +729,7 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) acc = tl.where(emars_max[None,:] > acc, - tl.log(nmars + tl.exp(acc - emars_max[None,:]) + 1e-12) + emars_max[None,:], + tl.log(nmars + tl.exp(acc - emars_max[None,:]) + 1e-24) + emars_max[None,:], tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc ) From ab66287e470e0b1a3f707bb55542c5873f36e805 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 06:17:29 +0800 Subject: [PATCH 33/53] allows deciding propagation alg per layer --- src/pyjuice/model/tensorcircuit.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index a8848677..c41234b8 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -139,7 +139,7 @@ def set_propagation_alg(self, propagation_alg: str, **kwargs): def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None, cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False, apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, - propagation_alg: Optional[str] = None, **kwargs): + propagation_alg: Optional[Union[str,Sequence[str]]] = None, **kwargs): """ Forward evaluation of the PC. @@ -191,7 +191,7 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla # Inner layers def _run_inner_layers(): - for layer_group in self.inner_layer_groups: + for layer_id, layer_group in enumerate(self.inner_layer_groups): if layer_group.is_prod(): # Prod layer layer_group(self.node_mars, self.element_mars) @@ -201,7 +201,8 @@ def _run_inner_layers(): layer_group(self.node_mars, self.element_mars, self.params, force_use_fp16 = force_use_fp16, force_use_fp32 = force_use_fp32, - propagation_alg = propagation_alg, **kwargs) + propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], + **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") @@ -273,7 +274,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None, record_cudagraph: bool = False, apply_cudagraph: bool = True, allow_modify_flows: bool = True, - propagation_alg: str = "LL", + propagation_alg: Union[str,Sequence[str]] = "LL", **kwargs): """ Backward evaluation of the PC that computes node flows as well as parameter flows. @@ -348,7 +349,9 @@ def _run_inner_layers(): # Backward sum layer layer_group.backward(self.node_flows, self.element_flows, self.node_mars, self.element_mars, self.params, param_flows = self.param_flows if compute_param_flows else None, - allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, **kwargs) + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], + **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") From e298620a00429bfaac1d66ca563616106545aed9 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 06:17:44 +0800 Subject: [PATCH 34/53] add `num_ch_nodes` for sum nodes --- src/pyjuice/nodes/sum_nodes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pyjuice/nodes/sum_nodes.py b/src/pyjuice/nodes/sum_nodes.py index eb64e4ab..117a35bb 100644 --- a/src/pyjuice/nodes/sum_nodes.py +++ b/src/pyjuice/nodes/sum_nodes.py @@ -71,6 +71,10 @@ def num_edges(self): """ return self.edge_ids.size(1) * self.block_size * self.ch_block_size + @property + def num_ch_nodes(self): + return self.num_ch_node_blocks * self.ch_block_size + def duplicate(self, *args, tie_params: bool = False) -> SumNodes: """ Create a duplication of the current node with the same specification (i.e., number of nodes, block size). From c61463d12c6f0fdc17ce4aec72262185fb497cc0 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 14:17:59 +0800 Subject: [PATCH 35/53] add log-space backward option for product layers --- src/pyjuice/layer/prod_layer.py | 157 ++++++++++++++++++++++++++------ 1 file changed, 129 insertions(+), 28 deletions(-) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 1e7b9534..fbbf35d4 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -8,6 +8,13 @@ import time from typing import Sequence, Optional +# In the latest triton, math functions were shuffled around into different modules: +# https://github.com/openai/triton/pull/3172 +if hasattr(tl.extra.cuda, "libdevice"): + tlmath = tl.extra.cuda.libdevic +else: + tlmath = tl.math + from pyjuice.nodes import ProdNodes from pyjuice.utils.parameter_list import FastParamList from pyjuice.utils.kernel_launcher import FastJITFunction @@ -249,7 +256,7 @@ def enable_partial_evaluation(self, fw_scopes: Optional[Sequence[BitSet]] = None @FastJITFunction def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr, - block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr): """ This kernel implements the function with 3d tensors. However, it only work with `triton==2.0.0`. """ @@ -264,7 +271,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, ntile_id = pid_m % (block_size // BLOCK_M) # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -282,17 +289,36 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, offs_evals = offs_egstart + block_nids[:,None] evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = mask_batch[:,None,None]) - # Take the sum of the child nodes' log-probabilities - nvals = tl.sum(evals, axis = 2) + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 2) + nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 2) # Node ids to `node_vals_ptr` nblock_start = tl.load(nids_ptr + nblock_id) offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) - nvals += node_vals + if prop_logsumexp: + # logaddexp + diff = nvals - node_vals + nvals = tl.where( + diff == 0, + nvals + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + nvals + tlmath.log1p(tl.exp(-diff)), + node_vals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None]) @@ -306,7 +332,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nblock_ids = offs_node // block_size # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_ids = tl.load(local_ids_ptr + nblock_ids, mask = mask_node) # Batch offsets and mask @@ -323,17 +349,36 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, offs_evals = offs_egstart + block_nids[:,None] evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = (mask_batch[:,None,None] & mask_node[None,:,None])) - # Take the sum of the child nodes' log-probabilities - nvals = tl.sum(evals, axis = 2) + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 2) + nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 2) # Node ids to `node_vals_ptr` nblock_start = tl.load(nids_ptr + nblock_ids[None,:]) offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) - nvals += node_vals + if prop_logsumexp: + # logaddexp + diff = nvals - node_vals + nvals = tl.where( + diff == 0, + nvals + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + nvals + tlmath.log1p(tl.exp(-diff)), + node_vals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None]) @@ -342,7 +387,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, @FastJITFunction def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr, - block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr): """ This kernel implements the function with 2d tensors. It works for all `triton` versions. """ @@ -355,7 +400,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, ntile_id = pid_m % (block_size // BLOCK_M) # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -376,12 +421,34 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, # Inner loop for i in range(0, BLOCK_M): evals = tl.load(element_vals_ptr + offs_evals, mask = mask_batch[None,:], other = 0) - nvals = tl.sum(evals, axis = 0) + + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 0) + nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 0) # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = nvals - node_vals + nvals = tl.where( + diff == 0, + nvals + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + nvals + tlmath.log1p(tl.exp(-diff)), + node_vals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch) @@ -393,7 +460,8 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, @FastJITFunction def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: tl.constexpr, batch_size, BLOCK_N: tl.constexpr, BLOCK_B: tl.constexpr, - N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, + prop_logsumexp: tl.constexpr): """ This kernel implements the function with 2d tensors. It is designed for nodes with many edges. """ @@ -406,7 +474,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt ntile_id = pid_m % block_size # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -425,11 +493,27 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt nblock_start = tl.load(nids_ptr + nblock_id) offs_nvals = (nblock_start + ntile_id) * batch_size + offs_batch # [BLOCK_B] + # Prepare buffer + if prop_logsumexp: + nvals = tl.zeros([BLOCK_B], dtype = tl.float32) - float("inf") + else: + nvals = tl.zeros([BLOCK_B], dtype = tl.float32) + # Inner loop - nvals = tl.zeros([BLOCK_B], dtype = tl.float32) for i in range(0, N_NUM_BLKS): evals = tl.load(element_vals_ptr + offs_evals, mask = (mask_edge[:,None] & mask_batch[None,:]), other = 0) - nvals += tl.sum(evals, axis = 0) + + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 0) + nvals_sub = tl.sum(tl.exp(evals - evals_max[None,:]), axis = 2) + nvals = tl.where(evals_max > nvals, + tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max, + tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals + ) + else: + # Take the sum of the child nodes' values + nvals += tl.sum(evals, axis = 0) offs_edge += BLOCK_N mask_edge = (offs_edge < num_edges) @@ -439,15 +523,30 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt offs_evals = (offs_egstart[:,None] + ntile_id) * batch_size + offs_batch[None,:] # [BLOCK_N, BLOCK_B] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = nvals - node_vals + nvals = tl.where( + diff == 0, + nvals + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + nvals + tlmath.log1p(tl.exp(-diff)), + node_vals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch) def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, - accum: bool = False) -> None: + accum: bool = False, prop_logsumexp: bool = False) -> None: tot_n_nodes = node_vals.size(0) tot_n_eles = element_vals.size(0) n_nblocks = nids.size(0) if local_ids is None else local_ids.size(0) @@ -455,8 +554,7 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, batch_size = node_vals.size(1) block_size = self.block_size - accum = 1 if accum else 0 - partial_eval = 1 if local_ids is not None else 0 + partial_eval = local_ids is not None assert num_edges & (num_edges - 1) == 0, "`num_edges` must be a power of 2." @@ -484,7 +582,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, N_NUM_BLKS = triton.cdiv(num_edges, BLOCK_B), block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) return None @@ -511,7 +610,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, BLOCK_B = BLOCK_B, block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) else: @@ -536,7 +636,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, BLOCK_B = BLOCK_B, block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) return None From c5d1ff87ca002d0f30b9154de857310687b1c1fe Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 16:01:39 +0800 Subject: [PATCH 36/53] log-space backward for block-sparse layers --- src/pyjuice/layer/sum_layer.py | 173 ++++++++++++++++++++++----------- 1 file changed, 114 insertions(+), 59 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index dfacf735..06c2f667 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1487,9 +1487,10 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids_start, parids_increment, parpids_start, parpids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, - allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, - K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1538,7 +1539,10 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele emars *= alpha # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) + if logspace_flows: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] @@ -1548,8 +1552,23 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, TILE_SIZE_M] - acc += tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) - + eflows = tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) + + if prop_logsumexp: + # logaddexp + diff = acc - eflows + acc = tl.where( + diff == 0, + acc + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + acc + tlmath.log1p(tl.exp(-diff)), + eflows + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + acc += eflows else: if propagation_alg_id == 2: @@ -1561,11 +1580,18 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - if propagation_alg_id == 0: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - if propagation_alg_id == 2: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + if logspace_flows: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars * alpha) + else: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) log_n_fdm_max = tl.max(log_n_fdm, axis = 0)[None,:] n_fdm_sub = tl.where(log_n_fdm_max != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max), 0.0) @@ -1575,7 +1601,14 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele else: partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) - acc += partial_flows * tl.exp(emars + log_n_fdm_max) + if logspace_flows: + partial_flows_max = emars + log_n_fdm_max + acc = tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) + else: + acc += partial_flows * tl.exp(emars + log_n_fdm_max) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1599,9 +1632,10 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids_start, parids_increment, parpids_start, parpids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, - allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, - K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1650,7 +1684,10 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar emars *= alpha # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) + if logspace_flows: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] @@ -1661,8 +1698,23 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, TILE_SIZE_M] eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1) - - acc += tl.trans(eflows) + eflows = tl.trans(eflows) + + if prop_logsumexp: + # logaddexp + diff = acc - eflows + acc = tl.where( + diff == 0, + acc + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + acc + tlmath.log1p(tl.exp(-diff)), + eflows + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + acc += eflows else: @@ -1675,18 +1727,32 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - if propagation_alg_id == 0: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - if propagation_alg_id == 2: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + if logspace_flows: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars * alpha) + else: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) log_n_fdm_max = tl.max(log_n_fdm, axis = 1) n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) - acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) + if logspace_flows: + partial_flows_max = emars + log_n_fdm_max[None,:] + acc = tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) + else: + acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1709,9 +1775,10 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo element_mars: torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, allow_modify_flows: bool = False, - propagation_alg: str = "LL", **kwargs) -> None: + propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -1804,6 +1871,7 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo partial_eval = partial_eval, ptr_inc_step = ptr_inc_step, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_K = TILE_SIZE_K, K_NUM_TILES = K_NUM_TILES, @@ -1833,6 +1901,7 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo partial_eval = partial_eval, ptr_inc_step = ptr_inc_step, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_K = TILE_SIZE_K, K_NUM_TILES = K_NUM_TILES, @@ -1853,9 +1922,9 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo @FastJITFunction def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr, - propagation_alg_id: tl.constexpr, alpha = 0.0): + logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, + TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1890,7 +1959,7 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para # Inner loop acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32) - + for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] @@ -1903,25 +1972,19 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para else: - # if propagation_alg_id == 2: - # emars *= alpha - nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] if propagation_alg_id == 2: - # nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] log_n_fdm += (alpha - 1.0) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - # nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - if propagation_alg_id == 0: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - if propagation_alg_id == 2: + if logspace_flows: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + else: log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) log_n_fdm_max = tl.max(log_n_fdm, axis = 0) @@ -1952,9 +2015,6 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] epars = tl.load(params + epars_offsets) - # if propagation_alg_id == 2: - # epars = tl.exp(tl.log(epars) * alpha) - if propagation_alg_id != 1: pflows = acc * epars else: @@ -1970,9 +2030,9 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para @FastJITFunction def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr, - propagation_alg_id: tl.constexpr, alpha = 0.0): + logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, + TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -2010,32 +2070,26 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 0) else: - # if propagation_alg_id == 2: - # emars *= alpha - if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] if propagation_alg_id == 2: - nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] log_n_fdm += (alpha - 1.0) * nmars else: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - if propagation_alg_id == 0: - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - if propagation_alg_id == 2: + if logspace_flows: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + else: log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) log_n_fdm_max = tl.max(log_n_fdm, axis = 1) @@ -2063,9 +2117,6 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] epars = tl.load(params + epars_offsets) - # if propagation_alg_id == 2: - # epars = tl.exp(tl.log(epars) * alpha) - if propagation_alg_id != 1: pflows = acc * epars else: @@ -2079,7 +2130,8 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2096,6 +2148,7 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor """ assert params.dim() == 1, "Expecting a 1D `params`." + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size @@ -2152,6 +2205,7 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor batch_size = batch_size, num_edges = num_edges, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, TILE_SIZE_B = TILE_SIZE_B, B_NUM_TILES = B_NUM_TILES, TILE_SIZE_K = TILE_SIZE_K, @@ -2174,7 +2228,8 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor pfids = pfids, batch_size = batch_size, num_edges = num_edges, - allow_modify_flows = allow_modify_flows, + allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, TILE_SIZE_B = TILE_SIZE_B, B_NUM_TILES = B_NUM_TILES, TILE_SIZE_K = TILE_SIZE_K, From 917843d87edfb369b0c0d3e34f0bb6a331441ce9 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 16:40:15 +0800 Subject: [PATCH 37/53] logspace backward for sparse layers + fixes --- src/pyjuice/layer/prod_layer.py | 8 ++- src/pyjuice/layer/sum_layer.py | 109 +++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 39 deletions(-) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index fbbf35d4..e83ab3c0 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -292,7 +292,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 2) - nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 2) @@ -304,6 +304,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, # Accumulate the `node_vals` if required if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) + if prop_logsumexp: # logaddexp diff = nvals - node_vals @@ -352,7 +353,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 2) - nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 2) @@ -364,6 +365,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, # Accumulate the `node_vals` if required if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) + if prop_logsumexp: # logaddexp diff = nvals - node_vals @@ -425,7 +427,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals = tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2) + evals_max + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 0) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 06c2f667..61b4333e 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1968,7 +1968,10 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] cond = tl.abs(elpars[:,None,:] + emars[None,:,:] - nmars[:,:,None]) < 1e-6 - acc += tl.sum(tl.where(cond, nflows[:,:,None], 0.0), axis = 1) + if logspace_flows: + acc += tl.sum(tl.where(cond, tl.exp(nflows[:,:,None]), 0.0), axis = 1) + else: + acc += tl.sum(tl.where(cond, nflows[:,:,None], 0.0), axis = 1) else: @@ -2075,7 +2078,10 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 0) + if logspace_flows: + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, tl.exp(nflows[:,:,None]), 0.0), axis = 0) + else: + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 0) else: @@ -2249,7 +2255,7 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, allow_modify_flows: bool = False, - propagation_alg: str = "LL", **kwargs) -> None: + propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: """ Back pass of sum layers with sparse processing kernel. @@ -2272,7 +2278,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) # Flows w.r.t. parameters @@ -2281,7 +2288,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) return None @@ -2291,8 +2299,9 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor @FastJITFunction def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids, parpids, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, - n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, + BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`BLOCK_M` nodes @@ -2347,9 +2356,10 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m eflows = tl.sum(tl.where(tl.abs(lpars[:,None] + emars[None,:] - nmars) < 1e-6, nflows, 0.0), axis = 0) else: + lpars = tl.log(epars) if propagation_alg_id == 2: - lpars = tl.log(epars) - epars = tl.exp(lpars * alpha) + lpars *= alpha + epars = tl.exp(lpars) if allow_modify_flows == 1: if propagation_alg_id == 0: @@ -2358,11 +2368,21 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m if propagation_alg_id == 2: eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] * alpha + log_n_fdm), axis = 0) else: - if propagation_alg_id == 0: - eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + if logspace_flows: + if propagation_alg_id == 0: + elflows = nflows + lpars[:,None] + emars[None,:] - nmars + + if propagation_alg_id == 2: + elflows = nflows + lpars[:,None] + (emars[None,:] - nmars) * alpha - if propagation_alg_id == 2: - eflows = tl.sum(nflows * epars[:,None] * tl.exp((emars[None,:] - nmars) * alpha), axis = 0) + elflows_max = tl.max(elflows, axis = 0) + eflows = tl.log(tl.sum(tl.exp(elflows - elflows_max[None,:]), axis = 0)) + elflows_max + else: + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,None] * tl.exp((emars[None,:] - nmars) * alpha), axis = 0) tl.store(eflows_ptr, eflows, mask = mask_batch) @@ -2379,7 +2399,8 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids, parpids, local_ids, num_eles, pid_m_offset, batch_size: tl.constexpr, partial_eval: tl.constexpr, - n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, + n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, + logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): @@ -2437,10 +2458,10 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele eflows = tl.sum(tl.where(tl.abs(lpars[:,:,None] + emars[:,None,:] - nmars) < 1e-6, nflows, 0.0), axis = 1) else: - + lpars = tl.log(epars) if propagation_alg_id == 2: - lpars = tl.log(epars) - epars = tl.exp(lpars * alpha) + lpars *= alpha + epars = tl.exp(lpars) if allow_modify_flows == 1: if propagation_alg_id == 0: @@ -2449,12 +2470,21 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele if propagation_alg_id == 2: eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] * alpha + log_n_fdm), axis = 1) else: - - if propagation_alg_id == 0: - eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + if logspace_flows: + if propagation_alg_id == 0: + elflows = nflows + lpars[:,:,None] + emars[:,None,:] - nmars - if propagation_alg_id == 2: - eflows = tl.sum(nflows * epars[:,:,None] * tl.exp((emars[:,None,:] - nmars) * alpha), axis = 0) + if propagation_alg_id == 2: + elflows = nflows + lpars[:,:,None] + (emars[:,None,:] - nmars) * alpha + + elflows_max = tl.max(elflows, axis = 1) + eflows = tl.log(tl.sum(tl.exp(elflows - elflows_max[:,None,:]), axis = 1)) + elflows_max + else: + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp((emars[:,None,:] - nmars) * alpha), axis = 0) tl.store(eflows_ptr, eflows, mask = (mask_m[:,None] & mask_batch[None,:])) @@ -2462,9 +2492,11 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -2502,6 +2534,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, BLOCK_M = BLOCK_M, BLOCK_SIZE_K = self.block_size, @@ -2535,6 +2568,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, @@ -2565,6 +2599,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, @@ -2580,7 +2615,7 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to @FastJITFunction def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, pid_m_offset, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, + logspace_flows: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples @@ -2627,13 +2662,13 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] - acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, nflows[None,:], 0.0), axis = 1) + if logspace_flows: + acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, tl.exp(nflows[None,:]), 0.0), axis = 1) + else: + acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, nflows[None,:], 0.0), axis = 1) else: - # if propagation_alg_id == 2: - # emars *= alpha - if allow_modify_flows == 1: log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] @@ -2647,10 +2682,11 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] - if propagation_alg_id == 0: - pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) - - if propagation_alg_id == 2: + if logspace_flows: + plflows = nflows[None,:] + emars - nmars[None,:] + plflows_max = tl.max(plflows, axis = 1) + pflows = tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1) * tl.exp(plflows_max) + else: pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) acc += pflows @@ -2665,9 +2701,6 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa epars_ptr = params + par_start + tile_id epars = tl.load(epars_ptr) # [BLOCK_K] - # if propagation_alg_id == 2: - # epars = tl.exp(tl.log(epars) * alpha) - parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_ptr = param_flows + parflow_start + tile_id @@ -2681,7 +2714,8 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2698,6 +2732,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten """ assert params.dim() == 1, "Expecting a 1D `params`." + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size @@ -2749,6 +2784,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten num_edges = num_edges, batch_size = batch_size, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_M = BLOCK_M, BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, @@ -2779,6 +2815,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten num_edges = num_edges, batch_size = batch_size, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_M = BLOCK_M, BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, From 59b89bc90369b84bd489e1ec49724cb21b1c95dc Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 15 Mar 2024 16:44:58 +0800 Subject: [PATCH 38/53] logspace backward with pytorch kernels --- src/pyjuice/layer/sum_layer.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 61b4333e..afc015b6 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -2829,7 +2829,8 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten def _backward_pytorch(self, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, - chids, parids, parpids, cs_block_size): + chids, parids, parpids, cs_block_size, propagation_alg: str = "LL", + logspace_flows: bool = False): """ Back pass of sum layers with native pytorch. @@ -2845,18 +2846,20 @@ def _backward_pytorch(self, node_flows, element_flows, params, node_mars, `parpids`: [ng, c] """ + assert propagation_alg == "LL" + # Flows w.r.t. input elements (product nodes) if chids is not None: self._backward_pytorch_ele_kernel( node_flows, element_flows, params, node_mars, element_mars, - param_flows, chids, parids, parpids, cs_block_size + param_flows, chids, parids, parpids, cs_block_size, logspace_flows ) # Flows w.r.t. parameters if param_flows is not None and nids is not None: self._backward_pytorch_par_kernel( node_flows, params, node_mars, element_mars, param_flows, - nids, cids, pids, pfids, self.block_size + nids, cids, pids, pfids, self.block_size, logspace_flows ) @torch.compile @@ -2864,7 +2867,7 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: Optional[torch.Tensor], chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, - cs_block_size: int): + cs_block_size: int, logspace_flows: bool): num_nblocks = chids.size(0) num_eblocks = parids.size(1) @@ -2878,15 +2881,20 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: num_nblocks * cs_block_size, num_eblocks * self.block_size ) - element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \ - (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1) + if logspace_flows: + element_flows[chids] = (node_flows[parids] + params[parpids].log().unsqueeze(-1) + \ + element_mars[chids].unsqueeze(1) - node_mars[parids]).logsumexp(dim = 1) + else: + element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \ + (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1) return None @torch.compile def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, - cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int): + cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int, + logspace_flows: bool): num_nblocks = nids.size(0) num_edges = cids.size(1) @@ -2898,7 +2906,10 @@ def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.T pfids = (pfids[:,None,:].repeat(1, self.block_size, 1) + \ torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges) - parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) + if logspace_flows: + parflows = (node_flows[nids].exp().unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) + else: + parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) for i in range(num_nblocks): sid, eid = ns_block_size * i, ns_block_size * (i + 1) From e0e6f9b8319324874a15d29d8010bbbc7e99a8ce Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:46:48 +0800 Subject: [PATCH 39/53] update cudagraph signature --- src/pyjuice/model/tensorcircuit.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index c41234b8..c3a75bbc 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -275,6 +275,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None, apply_cudagraph: bool = True, allow_modify_flows: bool = True, propagation_alg: Union[str,Sequence[str]] = "LL", + logspace_flows: bool = False, **kwargs): """ Backward evaluation of the PC that computes node flows as well as parameter flows. @@ -299,21 +300,24 @@ def backward(self, inputs: Optional[torch.Tensor] = None, ## Initialize buffers for backward pass ## - self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0) - self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0) + self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0 if not logspace_flows else -float("inf")) + self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0 if not logspace_flows else -float("inf")) # Set root node flows def _set_root_node_flows(): nonlocal ll_weights + nonlocal logspace_flows if ll_weights is None: - self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = 1.0 + root_flows = 1.0 if not logspace_flows else 0.0 + self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = root_flows else: if ll_weights.dim() == 1: ll_weights = ll_weights.unsqueeze(1) assert ll_weights.size(0) == self.num_root_nodes - self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = ll_weights + root_flows = ll_weights if not logspace_flows else ll_weights.log() + self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = root_flows _set_root_node_flows() @@ -338,7 +342,7 @@ def _run_inner_layers(): if layer_group.is_prod(): # Prod layer - layer_group.backward(self.node_flows, self.element_flows) + layer_group.backward(self.node_flows, self.element_flows, logspace_flows = logspace_flows) elif layer_group.is_sum(): # Sum layer @@ -351,12 +355,13 @@ def _run_inner_layers(): param_flows = self.param_flows if compute_param_flows else None, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], + logspace_flows = logspace_flows, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") - signature = (1, id(self.node_flows), id(self.element_flows), id(self.node_mars), id(self.element_mars), id(self.params), id(self.param_flows), B) + signature = (1, id(self.node_flows), id(self.element_flows), id(self.node_mars), id(self.element_mars), id(self.params), id(self.param_flows), B, allow_modify_flows, logspace_flows) if record_cudagraph and signature not in self._recorded_cuda_graphs: # Warmup s = torch.cuda.Stream() @@ -387,14 +392,14 @@ def _run_inner_layers(): # Compute backward pass for all input layers for idx, layer in enumerate(self.input_layer_group): if input_layer_fn is None: - layer.backward(inputs, self.node_flows, self.node_mars, **kwargs) + layer.backward(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) elif isinstance(input_layer_fn, str): assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}." - getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, **kwargs) + getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) elif isinstance(input_layer_fn, Callable): - input_layer_fn(layer, inputs, self.node_flows, self.node_mars, **kwargs) + input_layer_fn(layer, inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) else: raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.") From 0db64312bb007b93752ac590953413fc1182cf92 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:52:52 +0800 Subject: [PATCH 40/53] support logspace backward for input layers --- src/pyjuice/layer/input_layer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index 41ce3c8d..fae2c44c 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -302,7 +302,8 @@ def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[ raise NotImplementedError("CPU forward fn for input nodes is not implemented.") def backward(self, data: torch.Tensor, node_flows: torch.Tensor, - node_mars: torch.Tensor, params: Optional[Dict] = None, **kwargs): + node_mars: torch.Tensor, params: Optional[Dict] = None, + logspace_flows: bool = False, **kwargs): """ data: [num_vars, B] node_flows: [num_nodes, B] @@ -357,6 +358,7 @@ def backward(self, data: torch.Tensor, node_flows: torch.Tensor, node_offset = node_offset, BLOCK_SIZE = BLOCK_SIZE, partial_eval = 1 if bk_local_ids is not None else 0, + logspace_flows = logspace_flows, num_warps = 8 ) @@ -683,7 +685,7 @@ def _fw_missing_mask_kernel(missing_mask_ptr, node_mars_ptr, vids_ptr, fw_local_ @staticmethod def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr, node_mars_ptr, data_ptr, vids_ptr, s_pids_ptr, s_pfids_ptr, - metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, layer_num_nodes: tl.constexpr, + metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, logspace_flows: tl.constexpr, layer_num_nodes: tl.constexpr, batch_size: tl.constexpr, num_vars_per_node: tl.constexpr, nv_block_size: tl.constexpr, node_offset: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) @@ -722,6 +724,9 @@ def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr, ns_offsets = (local_offsets + node_offset) * batch_size + batch_offsets flows = tl.load(node_flows_ptr + ns_offsets, mask = mask, other = 0) + if logspace_flows: + flows = tl.exp(flows) + flow_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr, s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE) From 6de9045021acc8a23a3b5d6bd59f12321d048d3e Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:53:16 +0800 Subject: [PATCH 41/53] receive `logspace_flows` in `ProdLayer` --- src/pyjuice/layer/prod_layer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index e83ab3c0..5a8e106c 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -203,7 +203,7 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_back return None - def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwargs) -> None: + def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, logspace_flows: bool = False, **kwargs) -> None: """ Computes the backward pass of a product layer: ``` @@ -222,7 +222,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwar parids = self.partitioned_parids[partition_id] local_ids = self.bk_partition_local_ids[partition_id] - self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True) + self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True, + prop_logsumexp = logspace_flows) else: # Evaluate the whole layer @@ -230,7 +231,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, **kwar u_cids = self.partitioned_u_cids[partition_id] parids = self.partitioned_parids[partition_id] - self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True) + self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True, + prop_logsumexp = logspace_flows) return None @@ -427,7 +429,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 0) From d1e7c02881cadeaa7ff1cc97f402552c7ff18102 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:53:44 +0800 Subject: [PATCH 42/53] receive input `logspace_flows` for `SumLayer` --- src/pyjuice/layer/sum_layer.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index afc015b6..d3ce78db 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -255,7 +255,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, param_flows: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Computes the forward pass of a sum layer: ``` @@ -276,6 +277,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, `params`: [num_params, B] or [num_params] """ + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." + # Disallow modifications of `node_flows` in case of partial evaluation if self.provided("bk_partition_local_ids") and allow_modify_flows: allow_modify_flows = False @@ -308,6 +311,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -328,6 +332,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -346,6 +351,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id = partition_id, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) @@ -1211,7 +1217,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, cs_block_size: int = 0, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, allow_modify_flows: bool = False, - propagation_alg: str = "LL", **kwargs) -> None: + propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Back pass of sum layers. @@ -1256,7 +1263,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs ) elif mode == self.SPARSE: @@ -1264,7 +1271,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs ) elif mode == self.PYTORCH: @@ -1444,7 +1451,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", **kwargs) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: """ Back pass of sum layers with block-sparse processing kernel. @@ -1467,7 +1475,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) # Flows w.r.t. parameters @@ -1476,7 +1485,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) return None @@ -1554,7 +1564,7 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele eflows = tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) - if prop_logsumexp: + if logspace_flows: # logaddexp diff = acc - eflows acc = tl.where( @@ -1700,7 +1710,7 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1) eflows = tl.trans(eflows) - if prop_logsumexp: + if logspace_flows: # logaddexp diff = acc - eflows acc = tl.where( @@ -1778,7 +1788,6 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -2154,7 +2163,6 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor """ assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size @@ -2496,7 +2504,6 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = chids.size(0) if local_ids is None else local_ids.size(0) layer_n_nodes = num_nblocks * cs_block_size @@ -2732,7 +2739,6 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten """ assert params.dim() == 1, "Expecting a 1D `params`." - assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." num_nblocks = nids.size(0) layer_n_nodes = num_nblocks * self.block_size From 7bd6f748490c1240c639fb2e5f5b06901e1723c3 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 17 Mar 2024 23:53:56 +0800 Subject: [PATCH 43/53] add runtests for logspace flows --- tests/structures/hclt_test.py | 75 ++++++++++- tests/structures/logspace_flows_test.py | 170 ++++++++++++++++++++++++ 2 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 tests/structures/logspace_flows_test.py diff --git a/tests/structures/hclt_test.py b/tests/structures/hclt_test.py index 4377da58..089fd0c8 100644 --- a/tests/structures/hclt_test.py +++ b/tests/structures/hclt_test.py @@ -19,7 +19,7 @@ def evaluate(pc, loader): return lls_total -def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device): +def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = False): for epoch in range(num_epochs): t0 = time.time() train_ll = 0.0 @@ -29,7 +29,10 @@ def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test optimizer.zero_grad() lls = pc(x) - lls.mean().backward() + if not logspace_flows: + lls.mean().backward() + else: + pc.backward(x.permute(1, 0), allow_modify_flows = False, logspace_flows = True) train_ll += lls.mean().detach().cpu().numpy().item() @@ -125,6 +128,65 @@ def test_hclt(): assert test_ll > -785 +def test_hclt_logspace_flows(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 128, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = True) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -785 + + @pytest.mark.slow def test_small_hclt_full(): @@ -303,7 +365,8 @@ def test_hclt_logistic(): if __name__ == "__main__": # torch.manual_seed(3289) - test_hclt() - test_small_hclt_full() - test_large_hclt_full() - test_hclt_logistic() + # test_hclt() + test_hclt_logspace_flows() + # test_small_hclt_full() + # test_large_hclt_full() + # test_hclt_logistic() diff --git a/tests/structures/logspace_flows_test.py b/tests/structures/logspace_flows_test.py new file mode 100644 index 00000000..4f095d27 --- /dev/null +++ b/tests/structures/logspace_flows_test.py @@ -0,0 +1,170 @@ +import torch +import torchvision +import time +from torch.utils.data import TensorDataset, DataLoader + +import pyjuice as juice +import pyjuice.distributions as dists + + +def logsubexp(x, y): + """ + Compute log(exp(x) - exp(y)) in a numerically stable way. + """ + x, y = torch.maximum(x, y), torch.minimum(x, y) + + # Compute the maximum value between x and y element-wise + max_val = torch.max(x, y) + + # Compute the result using logsumexp trick + result = max_val + torch.log(torch.exp(x - max_val) - torch.exp(y - max_val)) + + return result + + +def test_logspace_hclt_backward(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.long() + batch_size = batch_data.size(0) + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False, logspace_flows = True) + + pc.update_param_flows() + + node_mars = pc.node_mars + node_flows = pc.node_flows + + temp_node_mars = pc.node_mars.clone() + temp_node_flows = pc.node_flows.clone() + temp_element_mars = pc.element_mars.clone() + temp_element_flows = pc.element_flows.clone() + temp_params = pc.params + temp_param_flows = pc.param_flows.clone() + + ns2flows = dict() + ns2flows[root_ns] = torch.ones([1, batch_size], device = device) + + ch2par = dict() + for ns in root_ns: + for cs in ns.chs: + if cs not in ch2par: + ch2par[cs] = set() + ch2par[cs].add(ns) + + visited = set() + + with torch.no_grad(): + for ns in root_ns(reverse = True): + visited.add(ns) + if ns == root_ns: + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(node_flows[sid:eid,:] - 0.0) < 1e-4) + + nflows = ns2flows[ns] + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) + params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) + + param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) + param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size], device = device) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = nflows.log() + params.log().permute(1, 0) + emars - nmars + pflows = eflows.exp().sum(dim = 1) + + assert torch.all(torch.abs(pflows - param_flows[0,:]) < 1e-4 * batch_size) + + ns2flows[cs] = eflows + + elif ns.is_prod(): + nflows = ns2flows[ns] + + for cs in ns.chs: + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) - float("inf") + ns2flows[cs] = torch.logaddexp(ns2flows[cs], nflows) + + elif ns.is_sum(): + + for par_cs in ch2par[ns]: + assert par_cs in visited + + nflows = ns2flows[ns] + + sid, eid = ns._output_ind_range + + assert torch.all(logsubexp(nflows, node_flows[sid:eid,:]).exp() < 1e-3) + assert (logsubexp(nflows, node_flows[sid:eid,:]).exp() > 1e-5).float().mean() < 0.2 + + nflows = node_flows[sid:eid,:] + + nmars = node_mars[sid:eid,:] + + ch_eflows = [] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) + params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) + + param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) + param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size], device = device) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 1) + pflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 2).permute(1, 0).exp() + + ch_eflows.append(eflows) + + assert torch.all(torch.abs(pflows - param_flows) < 1e-4 * batch_size) + + ns2flows[cs] = eflows + + +if __name__ == "__main__": + test_logspace_hclt_backward() From 83694a644f2937cb961ffcdbfb489036c99eb199 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 18 Mar 2024 01:06:50 +0800 Subject: [PATCH 44/53] homogeneous PD --- src/pyjuice/structures/pd.py | 66 +++++++++++++++++++++++++++++------- tests/structures/pd_test.py | 48 +++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 14 deletions(-) diff --git a/src/pyjuice/structures/pd.py b/src/pyjuice/structures/pd.py index a14861ec..05ee7f6e 100644 --- a/src/pyjuice/structures/pd.py +++ b/src/pyjuice/structures/pd.py @@ -24,7 +24,7 @@ def PD(data_shape: Tuple, num_latents: int, input_dist: Optional[Distribution] = None, input_node_type: Type[Distribution] = Categorical, input_node_params: Dict = {"num_cats": 256}, - use_linear_mixing: bool = False, + tie_homogeneous_params: bool = False, block_size: Optional[int] = None): """ Generate PCs with the PD structure (https://arxiv.org/pdf/1202.3732.pdf). @@ -53,6 +53,9 @@ def PD(data_shape: Tuple, num_latents: int, :param input_dist: input distribution :type input_dist: Distribution + :param tie_homogeneous_params: whether to tie parameters of sum/input nodes with compatible structures + :type tie_homogeneous_params: bool + :param block_size: block size :type block_size: int """ @@ -72,6 +75,9 @@ def PD(data_shape: Tuple, num_latents: int, num_axes = len(data_shape) + # A dictionary of source nodes + source_ns_dict = dict() + # Construct split points if split_intervals is not None: if isinstance(split_intervals, int): @@ -117,6 +123,12 @@ def updated_hypercube(hypercube, axis, s = None, e = None): hypercube = (tuple(hypercube[0]), tuple(hypercube[1])) return hypercube + def hypercube2shape(hypercube): + return tuple(hypercube[1][i] - hypercube[0][i] for i in range(len(hypercube[0]))) + + def get_signature(hypercube, *ch_ns): + return (hypercube2shape(hypercube), tuple(len(ns.scope) for ns in ch_ns)) + def create_input_ns(hypercube): scope = hypercube2scope(hypercube) if input_layer_fn is not None: @@ -124,11 +136,28 @@ def create_input_ns(hypercube): else: input_nodes = [] for var in scope: - ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + if not tie_homogeneous_params: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + else: + if "input" in source_ns_dict: + ns = source_ns_dict["input"].duplicate(var, tie_params = True) + else: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + source_ns_dict["input"] = ns input_nodes.append(ns) edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1) - return summate(multiply(*input_nodes), num_node_blocks = num_node_blocks, edge_ids = edge_ids) + pns = multiply(*input_nodes) + if not tie_homogeneous_params: + return summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, pns) + if signature in source_ns_dict: + return source_ns_dict[signature].duplicate(pns, tie_params = True) + else: + ns = summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns + return ns def recursive_construct(hypercube, depth = 1): if hypercube in hypercube2ns: @@ -161,24 +190,35 @@ def recursive_construct(hypercube, depth = 1): ns = create_input_ns(hypercube) elif hypercube == root_hypercube: ns = summate(*pns, num_node_blocks = 1, block_size = 1) - elif not use_linear_mixing: + else: if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks) + source_ns_dict[signature] = ns else: block_ids = torch.topk(torch.rand([num_node_blocks, len(pns)]), k = max_prod_block_conns, dim = 1).indices par_ids = torch.arange(0, num_node_blocks)[:,None,None].repeat(1, max_prod_block_conns, num_node_blocks) chs_ids = block_ids[:,:,None] * num_node_blocks + torch.arange(0, num_node_blocks)[None,None,:] edge_ids = torch.stack((par_ids.reshape(-1), chs_ids.reshape(-1)), dim = 0) - ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) - else: - # Linear mixing as implemented in EiNet's Mixing layer - if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) - else: - ch_ns = [multiply(summate(pn, num_node_blocks = num_node_blocks)) for pn in pns] - ns = summate(*ch_ns, num_node_blocks = num_node_blocks, edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1)) + + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns hypercube2ns[hypercube] = ns + return ns with set_block_size(block_size = block_size): diff --git a/tests/structures/pd_test.py b/tests/structures/pd_test.py index 6fb83592..ac1b8b6e 100644 --- a/tests/structures/pd_test.py +++ b/tests/structures/pd_test.py @@ -135,6 +135,52 @@ def test_pd(): assert test_ll > -765.0 +def test_homogeneous_pd(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.PD( + data_shape = (28, 28), + num_latents = 256, + split_intervals = (4, 4), + structure_type = "sum_dominated", + tie_homogeneous_params = True + ) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001) + + mini_batch_em_epoch(10, pc, optimizer, None, train_loader, test_loader, device) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -780.0 + + if __name__ == "__main__": - torch.manual_seed(2391) + # torch.manual_seed(2391) test_pd() + test_homogeneous_pd() From fd415218971e7e8f515a78fa32c93c6130d4272c Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 18 Mar 2024 05:19:56 +0800 Subject: [PATCH 45/53] avoid nans in backward pass for zero-flow inner nodes --- src/pyjuice/layer/prod_layer.py | 34 ++++++++++++++++----------------- src/pyjuice/layer/sum_layer.py | 23 +++++++++++++++------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 5a8e106c..c47fa24a 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -370,14 +370,14 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: @@ -429,7 +429,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals = tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max + nvals = tl.where(evals_max != -float("inf"), tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max, -float("inf")) else: # Take the sum of the child nodes' values nvals = tl.sum(evals, axis = 0) @@ -440,14 +440,14 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: @@ -510,7 +510,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt if prop_logsumexp: # Take the logsumexp of the child nodes' values evals_max = tl.max(evals, axis = 0) - nvals_sub = tl.sum(tl.exp(evals - evals_max[None,:]), axis = 2) + nvals_sub = tl.where(evals_max != -float("inf"), tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0), 0.0) nvals = tl.where(evals_max > nvals, tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max, tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals @@ -532,14 +532,14 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt if prop_logsumexp: # logaddexp - diff = nvals - node_vals + diff = node_vals - nvals nvals = tl.where( - diff == 0, - nvals + 0.69314718055994530942, # log(2) + nvals == -float("inf"), + node_vals, tl.where( diff > 0, - nvals + tlmath.log1p(tl.exp(-diff)), - node_vals + tlmath.log1p(tl.exp(diff)) + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) ) ) else: diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index d3ce78db..286e193e 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -1613,9 +1613,12 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele if logspace_flows: partial_flows_max = emars + log_n_fdm_max - acc = tl.where(partial_flows_max > acc, - tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, - tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + acc = tl.where(log_n_fdm_max == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) ) else: acc += partial_flows * tl.exp(emars + log_n_fdm_max) @@ -1757,9 +1760,12 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar if logspace_flows: partial_flows_max = emars + log_n_fdm_max[None,:] - acc = tl.where(partial_flows_max > acc, - tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, - tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + acc = tl.where(log_n_fdm_max[None,:] == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) ) else: acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) @@ -2692,7 +2698,10 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa if logspace_flows: plflows = nflows[None,:] + emars - nmars[None,:] plflows_max = tl.max(plflows, axis = 1) - pflows = tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1) * tl.exp(plflows_max) + pflows = tl.where(plflows_max != -float("inf"), + tl.exp(tl.log(tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1)) + plflows_max), + 0.0 + ) else: pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) From e483917669a8d9c8b2851d7f5601fd8a0b164341 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 16:37:39 +0800 Subject: [PATCH 46/53] remove num_vars assertion in forward pass --- src/pyjuice/model/tensorcircuit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index c3a75bbc..6eeff652 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -153,7 +153,7 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla B = inputs.size(0) if input_layer_fn is None: - assert inputs.dim() == 2 and inputs.size(1) == self.num_vars + assert inputs.dim() == 2 inputs = inputs.permute(1, 0) From d6ab5216681d533f216fa0eda0a15c0ef5c77873 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 16:38:02 +0800 Subject: [PATCH 47/53] add `len(ns)` function --- src/pyjuice/nodes/nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/pyjuice/nodes/nodes.py b/src/pyjuice/nodes/nodes.py index 00f31e01..b8741f79 100644 --- a/src/pyjuice/nodes/nodes.py +++ b/src/pyjuice/nodes/nodes.py @@ -226,6 +226,27 @@ def clear_hooks(ns): else: clear_hooks(self) + def __len__(self): + count = 0 + + def dfs(ns: CircuitNodes, visited: set = set()): + nonlocal count + + if ns in visited: + return + + visited.add(ns) + + # Recursively traverse children + if ns.is_sum() or ns.is_prod(): + for cs in ns.chs: + dfs(cs, visited = visited) + + count += 1 + + dfs(self) + return count + def __iter__(self): return node_iterator(self, self._reverse_iter) From 77fa53fde010c0ae540c7440cc44c7d38b90f31b Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 16:38:49 +0800 Subject: [PATCH 48/53] fix sum node construction when input `edge_ids` is a list --- src/pyjuice/nodes/sum_nodes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pyjuice/nodes/sum_nodes.py b/src/pyjuice/nodes/sum_nodes.py index 117a35bb..3864f8ae 100644 --- a/src/pyjuice/nodes/sum_nodes.py +++ b/src/pyjuice/nodes/sum_nodes.py @@ -338,8 +338,6 @@ def _construct_edges(self, edge_ids: Optional[Union[Tensor,Sequence[Tensor]]], r curr_edge_ids[1,:] += ch_gid_start edge_ids.append(curr_edge_ids) - ch_nid_start += self.chs[cs_id].num_node_blocks - edge_ids = torch.cat(edge_ids, dim = 1) if reorder: From ff7d05ea15bcc16deedbe61a4c1dd27dd408fda4 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 16:39:17 +0800 Subject: [PATCH 49/53] provide `max_block_size` option for `deepcopy` --- src/pyjuice/transformations/copy.py | 83 +++++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 10 deletions(-) diff --git a/src/pyjuice/transformations/copy.py b/src/pyjuice/transformations/copy.py index 5ba4f2a4..73481923 100644 --- a/src/pyjuice/transformations/copy.py +++ b/src/pyjuice/transformations/copy.py @@ -1,5 +1,6 @@ from __future__ import annotations +import torch from copy import deepcopy as pydeepcopy from typing import Optional, Dict @@ -7,7 +8,7 @@ from pyjuice.utils import BitSet -def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, +def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, max_block_size: Optional[int] = None, var_mapping: Optional[Dict[int,int]] = None) -> CircuitNodes: """ Create a deepcopy of the input PC. @@ -18,12 +19,19 @@ def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, :param tie_params: whether to tie the parameters between the original PC and the copied PC (if tied, their parameters will always be the same) :type tie_params: bool + :param max_block_size: the maximum block size of the copied PC + :type max_block_size: Optional[int] + :param var_mapping: a mapping dictionary between the variables of the original PC and the copied PC :type var_mapping: Optional[Dict[int,int]] :returns: a copied PC """ + assert not (max_block_size is not None and tie_params), "Could not change block size when `tie_params=True`." + if max_block_size is not None: + assert max_block_size > 0 and (max_block_size & (max_block_size - 1)) == 0, f"`max_block_size` must be a power of 2, but got `max_block_size={max_block_size}`." + old2new = dict() tied_ns_pairs = [] @@ -43,24 +51,74 @@ def dfs(ns: CircuitNodes): if ns.is_sum(): if not tie_params: + if max_block_size is None: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + params = ns.get_params() + else: + old_ch_blk_size = ns.chs[0].block_size + old_blk_size = ns.block_size + + new_ch_blk_size = new_chs[0].block_size + new_blk_size = min(old_blk_size, max_block_size) + + blk_factor = old_blk_size // new_blk_size + ch_blk_factor = old_ch_blk_size // new_ch_blk_size + + edge_ids = torch.stack( + (ns.edge_ids[0,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * blk_factor + torch.arange(0, blk_factor)[None,:,None], + ns.edge_ids[1,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * ch_blk_factor + torch.arange(0, ch_blk_factor)[None,None,:]), + dim = 0 + ).flatten(1, 3) + block_size = new_blk_size + + params = ns.get_params() + if params is not None: + num_edges = params.size(0) + params = params.reshape(num_edges, blk_factor, new_blk_size, ch_blk_factor, new_ch_blk_size).permute(0, 1, 3, 2, 4).flatten(0, 2) + new_ns = SumNodes( - ns.num_node_blocks, + ns.num_nodes // block_size, new_chs, - ns.edge_ids.clone(), - block_size = ns.block_size + edge_ids, + block_size = block_size ) - params = ns.get_params() if params is not None: new_ns.set_params(params.clone(), normalize = False) else: new_ns = ns.duplicate(*new_chs, tie_params = True) elif ns.is_prod(): + if max_block_size is None: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + else: + old_ch_blk_size = ns.chs[0].block_size + old_blk_size = ns.block_size + + new_ch_blk_size = new_chs[0].block_size + new_blk_size = min(old_blk_size, max_block_size) + + if old_blk_size == new_blk_size and old_ch_blk_size == new_ch_blk_size: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + else: + blk_factor = old_blk_size // new_blk_size + ch_blk_factor = old_ch_blk_size // new_ch_blk_size + + if blk_factor == ch_blk_factor: + edge_ids = ns.edge_ids.clone() + edge_ids = edge_ids[:,None,:].repeat(1, blk_factor, 1) * blk_factor + torch.arange(0, blk_factor)[None,:,None] + edge_ids = edge_ids.flatten(0, 1) + block_size = new_blk_size + else: + raise NotImplementedError() + new_ns = ProdNodes( - ns.num_node_blocks, + ns.num_nodes // block_size, new_chs, - ns.edge_ids.clone(), - block_size = ns.block_size + edge_ids, + block_size = block_size ) else: @@ -76,12 +134,17 @@ def dfs(ns: CircuitNodes): else: scope = pydeepcopy(ns.scope) + if max_block_size is None: + block_size = ns.block_size + else: + block_size = min(ns.block_size, max_block_size) + if not tie_params: new_ns = InputNodes( - num_node_blocks = ns.num_node_blocks, + num_node_blocks = ns.num_nodes // block_size, scope = pydeepcopy(scope), dist = pydeepcopy(ns.dist), - block_size = ns.block_size + block_size = block_size ) params = ns.get_params() if params is not None: From 4e7180ef9e651c6f8dce5282965b20b36bc881ba Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 16:39:46 +0800 Subject: [PATCH 50/53] SGD update function --- src/pyjuice/model/backend/par_update.py | 81 +++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/src/pyjuice/model/backend/par_update.py b/src/pyjuice/model/backend/par_update.py index ea6f8cb8..924441fc 100644 --- a/src/pyjuice/model/backend/par_update.py +++ b/src/pyjuice/model/backend/par_update.py @@ -209,9 +209,9 @@ def cum_pflow_kernel(cum_pflows, params, param_flows, nchs, par_start_ids, pflow @triton.jit -def par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, - global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, - BLOCK_SIZE: tl.constexpr): +def em_par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) @@ -253,6 +253,42 @@ def par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflo tl.store(params + offs_par, updated_param, mask = mask_pflow) +@triton.jit +def sgd_par_update_kernel(params, param_grads, par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr): + + pid = tl.program_id(axis = 0) + + # Retrieve the constants + lr = tl.load(constexprs) + + offs_m = pid * BLOCK_ID + tl.arange(0, BLOCK_ID) + mask_m = offs_m < num_blocks + + offs_blk = tl.arange(0, BLOCK_SIZE) + + par_start = tl.load(par_start_ids + offs_m, mask = mask_m, other = 0) + pgrad_start = tl.load(pgrad_start_ids + offs_m, mask = mask_m, other = 0) + blk_size = tl.load(blk_sizes + offs_m, mask = mask_m, other = 0) + blk_interval = tl.load(blk_intervals + offs_m, mask = mask_m, other = 0) + global_nid = tl.load(global_nids + offs_m, mask = mask_m, other = 0) + + offs_pgrad = pgrad_start[:,None] + offs_blk[None,:] * blk_interval[:,None] + mask_pgrad = mask_m[:,None] & (offs_blk[None,:] < blk_size[:,None]) + pgrads = tl.load(param_grads + offs_pgrad, mask = mask_pgrad, other = 0) + + offs_par = par_start[:,None] + offs_blk[None,:] * blk_interval[:,None] + old_param = tl.load(params + offs_par, mask = mask_pflow, other = 0) + + if keep_zero_params: + updated_params = tl.where(old_param < 1e-12, 0.0, tl.exp(tl.log(old_param) + lr * pgrads)) + else: + updated_param = tl.exp(tl.log(old_param) + lr * pgrads) + + tl.store(params + offs_par, updated_param, mask = mask_pgrad) + + def em_par_update(params: torch.Tensor, param_flows: torch.Tensor, par_update_kwargs: Sequence, step_size: float, pseudocount: float = 0.0, keep_zero_params: bool = True): @@ -280,7 +316,44 @@ def em_par_update(params: torch.Tensor, param_flows: torch.Tensor, par_update_kw global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE ) - par_update_kernel[grid]( + em_par_update_kernel[grid]( params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE ) + + return None + + +def sgd_par_update(params: torch.Tensor, param_grads: torch.Tensor, par_update_kwargs: Sequence, + lr: float, keep_zero_params: bool = True): + """ + Apply one-step SGD parameter update. + + :param params: the parameter tensor + :type params: torch.Tensor + + :param param_grads: gradients of the log-parameters + :type param_grads: torch.Tensor + + :param lr: learning rate + :type lr: float + + :param keep_zero_params: whether to freeze zero parameters + :type keep_zero_params: bool + """ + + par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, global_nids, nchs, cum_pflows, metadata = par_update_kwargs + + num_blocks = par_start_ids.size(0) + BLOCK_ID = 2048 // BLOCK_SIZE + + grid = (triton.cdiv(num_blocks, BLOCK_ID),) + + constexprs = torch.tensor([lr]).to(params.device) + + sgd_par_update_kernel[grid]( + params, param_grads, par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE + ) + + return None From 7c03d181952b0aa2dcc3707a821d8afb86c0768e Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 22 Mar 2024 17:09:18 +0800 Subject: [PATCH 51/53] add option to accumulate negative parameter flows in the backward pass --- src/pyjuice/layer/sum_layer.py | 72 +++++++++++++++++++++--------- src/pyjuice/model/tensorcircuit.py | 4 +- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index 286e193e..d4dcf977 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -256,7 +256,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, param_flows: Optional[torch.Tensor] = None, allow_modify_flows: bool = False, propagation_alg: str = "LL", - logspace_flows: bool = False, **kwargs) -> None: + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Computes the forward pass of a sum layer: ``` @@ -312,6 +312,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) @@ -333,6 +334,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) @@ -352,6 +354,7 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) @@ -1218,7 +1221,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, partition_id: int = -1, mode: Optional[str] = None, allow_modify_flows: bool = False, propagation_alg: str = "LL", - logspace_flows: bool = False, **kwargs) -> None: + logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers. @@ -1263,7 +1267,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) elif mode == self.SPARSE: @@ -1271,7 +1276,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, partition_id = partition_id, allow_modify_flows = allow_modify_flows, - propagation_alg = propagation_alg, logspace_flows = logspace_flows, **kwargs + propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) elif mode == self.PYTORCH: @@ -1281,7 +1287,8 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, - propagation_alg = propagation_alg, **kwargs + propagation_alg = propagation_alg, + negate_pflows = negate_pflows, **kwargs ) else: raise ValueError(f"Not supported mode `{mode}`.") @@ -1452,7 +1459,7 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", - logspace_flows: bool = False, **kwargs) -> None: + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers with block-sparse processing kernel. @@ -1486,7 +1493,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. nids = nids, cids = cids, pids = pids, pfids = pfids, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, - logspace_flows = logspace_flows, **kwargs + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) return None @@ -1939,7 +1947,8 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, negate_pflows: tl.constexpr, + alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -2041,7 +2050,10 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - tl.atomic_add(param_flows + eparflows_offsets, pflows) + if negate_pflows: + tl.atomic_add(param_flows + eparflows_offsets, -1.0 * pflows) + else: + tl.atomic_add(param_flows + eparflows_offsets, pflows) @staticmethod # @triton.jit @@ -2050,7 +2062,8 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, negate_pflows: tl.constexpr, + alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -2146,13 +2159,16 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - tl.atomic_add(param_flows + eparflows_offsets, pflows) + if negate_pflows: + tl.atomic_add(param_flows + eparflows_offsets, -1.0 * pflows) + else: + tl.atomic_add(param_flows + eparflows_offsets, pflows) def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, allow_modify_flows: bool = False, propagation_alg: str = "LL", - logspace_flows: bool = False, **kwargs) -> None: + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2233,6 +2249,7 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor BLOCK_SIZE_M = self.block_size, TL_DOT = TL_DOT, propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, **propagation_alg_kwargs ) else: @@ -2257,6 +2274,7 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor BLOCK_SIZE_M = self.block_size, TL_DOT = TL_DOT, propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, **propagation_alg_kwargs ) @@ -2269,7 +2287,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, allow_modify_flows: bool = False, - propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: + propagation_alg: str = "LL", logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers with sparse processing kernel. @@ -2303,7 +2322,8 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor nids = nids, cids = cids, pids = pids, pfids = pfids, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg, - logspace_flows = logspace_flows, **kwargs + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) return None @@ -2629,7 +2649,8 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, pid_m_offset, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): + TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, propagation_alg_id: tl.constexpr, + negate_pflows: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples pid_e = tl.program_id(1) # ID of size-`BLOCK_K` edges @@ -2725,13 +2746,16 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa else: curr_pflows = acc - tl.atomic_add(eparflows_ptr, curr_pflows) + if negate_pflows: + tl.atomic_add(eparflows_ptr, -1.0 * curr_pflows) + else: + tl.atomic_add(eparflows_ptr, curr_pflows) def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, allow_modify_flows: bool = False, propagation_alg: str = "LL", - logspace_flows: bool = False, **kwargs) -> None: + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2806,6 +2830,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten TILE_SIZE_B = TILE_SIZE_B, B_NUM_BLOCKS = B_NUM_BLOCKS, propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, **propagation_alg_kwargs ) @@ -2837,6 +2862,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten TILE_SIZE_B = TILE_SIZE_B, B_NUM_BLOCKS = B_NUM_BLOCKS, propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, **propagation_alg_kwargs ) @@ -2845,7 +2871,7 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten def _backward_pytorch(self, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, propagation_alg: str = "LL", - logspace_flows: bool = False): + logspace_flows: bool = False, negate_pflows: bool = False): """ Back pass of sum layers with native pytorch. @@ -2874,7 +2900,8 @@ def _backward_pytorch(self, node_flows, element_flows, params, node_mars, if param_flows is not None and nids is not None: self._backward_pytorch_par_kernel( node_flows, params, node_mars, element_mars, param_flows, - nids, cids, pids, pfids, self.block_size, logspace_flows + nids, cids, pids, pfids, self.block_size, logspace_flows, + negate_pflows ) @torch.compile @@ -2909,7 +2936,7 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int, - logspace_flows: bool): + logspace_flows: bool, negate_pflows: bool): num_nblocks = nids.size(0) num_edges = cids.size(1) @@ -2928,7 +2955,10 @@ def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.T for i in range(num_nblocks): sid, eid = ns_block_size * i, ns_block_size * (i + 1) - param_flows[pfids[sid:eid,:]] += parflows[sid:eid,:] + if negate_pflows: + param_flows[pfids[sid:eid,:]] -= parflows[sid:eid,:] + else: + param_flows[pfids[sid:eid,:]] += parflows[sid:eid,:] return None diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 6eeff652..9a1994bf 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -276,6 +276,7 @@ def backward(self, inputs: Optional[torch.Tensor] = None, allow_modify_flows: bool = True, propagation_alg: Union[str,Sequence[str]] = "LL", logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs): """ Backward evaluation of the PC that computes node flows as well as parameter flows. @@ -355,8 +356,7 @@ def _run_inner_layers(): param_flows = self.param_flows if compute_param_flows else None, allow_modify_flows = allow_modify_flows, propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], - logspace_flows = logspace_flows, - **kwargs) + logspace_flows = logspace_flows, negate_pflows = negate_pflows, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") From ca8b6419e13a6d37751e8fb9f4d0726f975305f5 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sat, 30 Mar 2024 20:28:47 +0800 Subject: [PATCH 52/53] fix typo in SGD update kernel --- src/pyjuice/model/backend/__init__.py | 2 +- src/pyjuice/model/backend/par_update.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pyjuice/model/backend/__init__.py b/src/pyjuice/model/backend/__init__.py index 6580dde5..46a4cc63 100644 --- a/src/pyjuice/model/backend/__init__.py +++ b/src/pyjuice/model/backend/__init__.py @@ -1,3 +1,3 @@ from .parflow_fusing import compile_cum_par_flows_fn, compute_cum_par_flows, cum_par_flows_to_device -from .par_update import compile_par_update_fn, em_par_update, par_update_to_device +from .par_update import compile_par_update_fn, em_par_update, par_update_to_device, sgd_par_update from .normalize import normalize_parameters diff --git a/src/pyjuice/model/backend/par_update.py b/src/pyjuice/model/backend/par_update.py index 924441fc..f9d4e803 100644 --- a/src/pyjuice/model/backend/par_update.py +++ b/src/pyjuice/model/backend/par_update.py @@ -279,7 +279,7 @@ def sgd_par_update_kernel(params, param_grads, par_start_ids, pgrad_start_ids, b pgrads = tl.load(param_grads + offs_pgrad, mask = mask_pgrad, other = 0) offs_par = par_start[:,None] + offs_blk[None,:] * blk_interval[:,None] - old_param = tl.load(params + offs_par, mask = mask_pflow, other = 0) + old_param = tl.load(params + offs_par, mask = mask_pgrad, other = 0) if keep_zero_params: updated_params = tl.where(old_param < 1e-12, 0.0, tl.exp(tl.log(old_param) + lr * pgrads)) @@ -344,6 +344,9 @@ def sgd_par_update(params: torch.Tensor, param_grads: torch.Tensor, par_update_k par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, global_nids, nchs, cum_pflows, metadata = par_update_kwargs + tot_num_nodes = metadata["tot_num_nodes"] + BLOCK_SIZE = metadata["BLOCK_SIZE"] + num_blocks = par_start_ids.size(0) BLOCK_ID = 2048 // BLOCK_SIZE From 5373ddc461b122789751037e36f136b9f3e5b662 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sat, 30 Mar 2024 20:29:46 +0800 Subject: [PATCH 53/53] change kernel allocation for sum layers --- src/pyjuice/layer/sum_layer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index d4dcf977..61aa2209 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -504,9 +504,9 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c if use_fp16 == 1: # Built-in matmul kernel of triton + float16 - epars_fp16 = (epars * (2**12)).to(tl.float16) + epars_fp16 = (epars * (2**4)).to(tl.float16) emars_fp16 = emars_sub.to(tl.float16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**4) else: # Built-in matmul kernel of triton + float32 nmars = tl.dot(epars, emars_sub) @@ -1260,7 +1260,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, elif num_edges <= 32768: mode = self.BLOCK_SPARSE else: - mode = self.SPARSE + mode = self.BLOCK_SPARSE if mode == self.BLOCK_SPARSE: self._backward_block_sparse(