diff --git a/pyproject.toml b/pyproject.toml index a8a612a8..6aea82d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,14 +21,26 @@ authors = [ {name="StarAI", email="guyvdb@cs.ucla.edu"}, ] +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-xdist", + "pytest-skip-slow", + "torchvision", + "torchtext", + "matplotlib" +] + [options.packages.find] where = "src" [tool.setuptools.dynamic] readme = {file = "README.md"} - [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", ] +testpaths = [ + "tests/" +] \ No newline at end of file diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index 7994206f..fae2c44c 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -30,6 +30,8 @@ def __init__(self, nodes: Sequence[InputNodes], cum_nodes: int = 0, pc_num_vars: in gradient accumulation. """ + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." + nn.Module.__init__(self) Layer.__init__(self, nodes, disable_block_size_check = True) @@ -213,7 +215,7 @@ def init_param_flows(self, flows_memory: float = 1.0): def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[Dict] = None, missing_mask: Optional[torch.Tensor] = None, _batch_first: bool = True, - _apply_missing_mask_only: bool = False): + _apply_missing_mask_only: bool = False, **kwargs): self._used_external_params = (params is not None) if params is None: @@ -300,7 +302,8 @@ def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[ raise NotImplementedError("CPU forward fn for input nodes is not implemented.") def backward(self, data: torch.Tensor, node_flows: torch.Tensor, - node_mars: torch.Tensor, params: Optional[Dict] = None): + node_mars: torch.Tensor, params: Optional[Dict] = None, + logspace_flows: bool = False, **kwargs): """ data: [num_vars, B] node_flows: [num_nodes, B] @@ -355,6 +358,7 @@ def backward(self, data: torch.Tensor, node_flows: torch.Tensor, node_offset = node_offset, BLOCK_SIZE = BLOCK_SIZE, partial_eval = 1 if bk_local_ids is not None else 0, + logspace_flows = logspace_flows, num_warps = 8 ) @@ -681,7 +685,7 @@ def _fw_missing_mask_kernel(missing_mask_ptr, node_mars_ptr, vids_ptr, fw_local_ @staticmethod def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr, node_mars_ptr, data_ptr, vids_ptr, s_pids_ptr, s_pfids_ptr, - metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, layer_num_nodes: tl.constexpr, + metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, logspace_flows: tl.constexpr, layer_num_nodes: tl.constexpr, batch_size: tl.constexpr, num_vars_per_node: tl.constexpr, nv_block_size: tl.constexpr, node_offset: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) @@ -720,6 +724,9 @@ def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr, ns_offsets = (local_offsets + node_offset) * batch_size + batch_offsets flows = tl.load(node_flows_ptr + ns_offsets, mask = mask, other = 0) + if logspace_flows: + flows = tl.exp(flows) + flow_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr, s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE) diff --git a/src/pyjuice/layer/layer.py b/src/pyjuice/layer/layer.py index a8585fbf..f218cb12 100644 --- a/src/pyjuice/layer/layer.py +++ b/src/pyjuice/layer/layer.py @@ -7,6 +7,13 @@ class Layer(): + + propagation_alg_mapping = { + "LL": 0, + "MPE": 1, + "GeneralLL": 2 
+ } + def __init__(self, nodes: Sequence[CircuitNodes], disable_block_size_check: bool = False) -> None: if disable_block_size_check: @@ -60,3 +67,13 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True def provided(self, var_name): return hasattr(self, var_name) and getattr(self, var_name) is not None + + def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + return {"alpha": 0.0} + elif propagation_alg == "MPE": + return {"alpha": 0.0} + elif propagation_alg == "GeneralLL": + return {"alpha": kwargs["alpha"]} + else: + raise ValueError(f"Unknown propagation algorithm {propagation_alg}.") diff --git a/src/pyjuice/layer/prod_layer.py b/src/pyjuice/layer/prod_layer.py index 1236e043..c47fa24a 100644 --- a/src/pyjuice/layer/prod_layer.py +++ b/src/pyjuice/layer/prod_layer.py @@ -8,6 +8,13 @@ import time from typing import Sequence, Optional +# In the latest triton, math functions were shuffled around into different modules: +# https://github.com/openai/triton/pull/3172 +if hasattr(tl.extra.cuda, "libdevice"): + tlmath = tl.extra.cuda.libdevice +else: + tlmath = tl.math + from pyjuice.nodes import ProdNodes from pyjuice.utils.parameter_list import FastParamList from pyjuice.utils.kernel_launcher import FastJITFunction @@ -27,6 +34,7 @@ def __init__(self, nodes: Sequence[ProdNodes], global_nid_start: Optional[int] = nn.Module.__init__(self) assert len(nodes) > 0, "No input node." + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." use_block_sparse_edges = True for nid in range(0, len(nodes)): @@ -149,7 +157,7 @@ def __init__(self, nodes: Sequence[ProdNodes], global_nid_start: Optional[int] = self.partitioned_u_cids = FastParamList([nn.Parameter(tensor, requires_grad = False) for tensor in u_cids]) self.partitioned_parids = FastParamList([nn.Parameter(tensor, requires_grad = False) for tensor in parids]) - def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_backward: bool = False) -> None: + def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_backward: bool = False, **kwargs) -> None: """ Computes the forward pass of a product layer. 
If `block_size == 1`, it is equivalent to the following: ``` @@ -195,7 +203,7 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, _for_back return None - def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor) -> None: + def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, logspace_flows: bool = False, **kwargs) -> None: """ Computes the backward pass of a product layer: ``` @@ -214,7 +222,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor) -> Non parids = self.partitioned_parids[partition_id] local_ids = self.bk_partition_local_ids[partition_id] - self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True) + self._forward_backward(node_flows, element_flows, u_cids, parids, local_ids = local_ids, accum = True, + prop_logsumexp = logspace_flows) else: # Evaluate the whole layer @@ -222,7 +231,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor) -> Non u_cids = self.partitioned_u_cids[partition_id] parids = self.partitioned_parids[partition_id] - self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True) + self._forward_backward(node_flows, element_flows, u_cids, parids, accum = True, + prop_logsumexp = logspace_flows) return None @@ -248,7 +258,7 @@ def enable_partial_evaluation(self, fw_scopes: Optional[Sequence[BitSet]] = None @FastJITFunction def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr, - block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr): """ This kernel implements the function with 3d tensors. However, it only works with `triton==2.0.0`. 
""" @@ -263,7 +273,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, ntile_id = pid_m % (block_size // BLOCK_M) # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -281,17 +291,37 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, offs_evals = offs_egstart + block_nids[:,None] evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = mask_batch[:,None,None]) - # Take the sum of the child nodes' log-probabilities - nvals = tl.sum(evals, axis = 2) + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 2) + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 2) # Node ids to `node_vals_ptr` nblock_start = tl.load(nids_ptr + nblock_id) offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = nvals - node_vals + nvals = tl.where( + diff == 0, + nvals + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + nvals + tlmath.log1p(tl.exp(-diff)), + node_vals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None]) @@ -305,7 +335,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nblock_ids = offs_node // block_size # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_ids = tl.load(local_ids_ptr + nblock_ids, mask = mask_node) # Batch offsets and mask @@ -322,17 +352,37 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, offs_evals = offs_egstart + block_nids[:,None] evals = tl.load(element_vals_ptr + offs_evals[None,:,:] * batch_size + offs_batch[:,None,None], mask = (mask_batch[:,None,None] & mask_node[None,:,None])) - # Take the sum of the child nodes' log-probabilities - nvals = tl.sum(evals, axis = 2) + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 2) + nvals = tl.log(tl.sum(tl.exp(evals - evals_max[:,:,None]), axis = 2)) + evals_max + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 2) # Node ids to `node_vals_ptr` nblock_start = tl.load(nids_ptr + nblock_ids[None,:]) offs_nvals = (nblock_start + block_nids[None,:]) * batch_size + offs_batch[:,None] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch[:,None], other = 0) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = node_vals - nvals + nvals = tl.where( + nvals == -float("inf"), + node_vals, + tl.where( + diff > 0, + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch[:,None]) @@ -341,7 +391,7 @@ def _forward_backward_kernel_3d(node_vals_ptr, element_vals_ptr, local_ids_ptr, @FastJITFunction def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: 
tl.constexpr, batch_size, BLOCK_M: tl.constexpr, BLOCK_B: tl.constexpr, - block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, prop_logsumexp: tl.constexpr): """ This kernel implements the function with 2d tensors. It works for all `triton` versions. """ @@ -354,7 +404,7 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, ntile_id = pid_m % (block_size // BLOCK_M) # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -375,12 +425,34 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, # Inner loop for i in range(0, BLOCK_M): evals = tl.load(element_vals_ptr + offs_evals, mask = mask_batch[None,:], other = 0) - nvals = tl.sum(evals, axis = 0) + + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 0) + nvals = tl.where(evals_max != -float("inf"), tl.log(tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0)) + evals_max, -float("inf")) + else: + # Take the sum of the child nodes' values + nvals = tl.sum(evals, axis = 0) # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = node_vals - nvals + nvals = tl.where( + nvals == -float("inf"), + node_vals, + tl.where( + diff > 0, + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch) @@ -392,7 +464,8 @@ def _forward_backward_kernel_2d(node_vals_ptr, element_vals_ptr, local_ids_ptr, @FastJITFunction def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_ptr, nids_ptr, cids_ptr, tot_n_nodes, tot_n_eles, n_nblocks, num_edges: tl.constexpr, batch_size, BLOCK_N: tl.constexpr, BLOCK_B: tl.constexpr, - N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr): + N_NUM_BLKS: tl.constexpr, block_size: tl.constexpr, accum: tl.constexpr, partial_eval: tl.constexpr, + prop_logsumexp: tl.constexpr): """ This kernel implements the function with 2d tensors. It is designed for nodes with many edges. 
""" @@ -405,7 +478,7 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt ntile_id = pid_m % block_size # For partial evaluation - if partial_eval == 1: + if partial_eval: nblock_id = tl.load(local_ids_ptr + nblock_id) # Batch offsets and mask @@ -424,11 +497,27 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt nblock_start = tl.load(nids_ptr + nblock_id) offs_nvals = (nblock_start + ntile_id) * batch_size + offs_batch # [BLOCK_B] + # Prepare buffer + if prop_logsumexp: + nvals = tl.zeros([BLOCK_B], dtype = tl.float32) - float("inf") + else: + nvals = tl.zeros([BLOCK_B], dtype = tl.float32) + # Inner loop - nvals = tl.zeros([BLOCK_B], dtype = tl.float32) for i in range(0, N_NUM_BLKS): evals = tl.load(element_vals_ptr + offs_evals, mask = (mask_edge[:,None] & mask_batch[None,:]), other = 0) - nvals += tl.sum(evals, axis = 0) + + if prop_logsumexp: + # Take the logsumexp of the child nodes' values + evals_max = tl.max(evals, axis = 0) + nvals_sub = tl.where(evals_max != -float("inf"), tl.sum(tl.exp(evals - evals_max[None,:]), axis = 0), 0.0) + nvals = tl.where(evals_max > nvals, + tl.log(nvals_sub + tl.exp(nvals - evals_max) + 1e-24) + evals_max, + tl.log(tl.exp(evals_max - nvals) * nvals_sub + 1.0) + nvals + ) + else: + # Take the sum of the child nodes' values + nvals += tl.sum(evals, axis = 0) offs_edge += BLOCK_N mask_edge = (offs_edge < num_edges) @@ -438,15 +527,30 @@ def _forward_backward_kernel_large(node_vals_ptr, element_vals_ptr, local_ids_pt offs_evals = (offs_egstart[:,None] + ntile_id) * batch_size + offs_batch[None,:] # [BLOCK_N, BLOCK_B] # Accumulate the `node_vals` if required - if accum == 1: + if accum: node_vals = tl.load(node_vals_ptr + offs_nvals, mask = mask_batch) - nvals += node_vals + + if prop_logsumexp: + # logaddexp + diff = node_vals - nvals + nvals = tl.where( + nvals == -float("inf"), + node_vals, + tl.where( + diff > 0, + node_vals + tlmath.log1p(tl.exp(-diff)), + nvals + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + nvals += node_vals tl.store(node_vals_ptr + offs_nvals, nvals, mask = mask_batch) def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, - accum: bool = False) -> None: + accum: bool = False, prop_logsumexp: bool = False) -> None: tot_n_nodes = node_vals.size(0) tot_n_eles = element_vals.size(0) n_nblocks = nids.size(0) if local_ids is None else local_ids.size(0) @@ -454,8 +558,7 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, batch_size = node_vals.size(1) block_size = self.block_size - accum = 1 if accum else 0 - partial_eval = 1 if local_ids is not None else 0 + partial_eval = local_ids is not None assert num_edges & (num_edges - 1) == 0, "`num_edges` must be a power of 2." 
@@ -483,7 +586,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, N_NUM_BLKS = triton.cdiv(num_edges, BLOCK_B), block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) return None @@ -510,7 +614,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, BLOCK_B = BLOCK_B, block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) else: @@ -535,7 +640,8 @@ def _forward_backward(self, node_vals: torch.Tensor, element_vals: torch.Tensor, BLOCK_B = BLOCK_B, block_size = block_size, accum = accum, - partial_eval = partial_eval + partial_eval = partial_eval, + prop_logsumexp = prop_logsumexp ) return None diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index f28597ed..61aa2209 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -40,6 +40,7 @@ def __init__(self, nodes: Sequence[SumNodes], global_nid_start: int, nn.Module.__init__(self) assert len(nodes) > 0, "No input node." + assert len(nodes) == len(set(nodes)), "Input node list contains duplicates." self.nodes = nodes @@ -208,7 +209,8 @@ def num_param_flows(self): return self._layer_pfid_range[1] - self._layer_pfid_range[0] def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Computes the forward pass of a sum layer. @@ -228,7 +230,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t self._forward( node_mars, element_mars, params, nids, cids, pids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) else: @@ -243,7 +246,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t node_mars, element_mars, params, nids, cids, pids, local_ids = local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) return None @@ -251,7 +255,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, param_flows: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Computes the forward pass of a sum layer: ``` @@ -272,6 +277,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, `params`: [num_params, B] or [num_params] """ + assert not (allow_modify_flows and logspace_flows), "`allow_modify_flows` should be set to `False` when using `logspace_flows`." 
+ # Disallow modifications of `node_flows` in case of partial evaluation if self.provided("bk_partition_local_ids") and allow_modify_flows: allow_modify_flows = False @@ -283,7 +290,8 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, nids = self.partitioned_nids[partition_id] self._bk_triton_modify_flow( - node_flows, node_mars, nids, local_ids = None + node_flows, node_mars, nids, local_ids = None, + propagation_alg = propagation_alg, **kwargs ) ## Compute flows w.r.t. elements (i.e., product nodes) ## @@ -301,7 +309,11 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, + **kwargs ) else: @@ -319,7 +331,11 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, + **kwargs ) ## Compute flows w.r.t. sum parameters ## @@ -335,7 +351,11 @@ def backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, partition_id = partition_id, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, + **kwargs ) return None @@ -344,7 +364,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers. 
@@ -380,18 +401,19 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, self._forward_block_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, propagation_alg = propagation_alg, **kwargs ) elif mode == self.SPARSE: self._forward_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, - partition_id = partition_id + partition_id = partition_id, propagation_alg = propagation_alg, **kwargs ) elif mode == self.PYTORCH: self._forward_pytorch( - node_mars, element_mars, params, nids, cids, pids, local_ids + node_mars, element_mars, params, nids, cids, pids, local_ids, + propagation_alg = propagation_alg, **kwargs ) else: @@ -403,7 +425,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -453,22 +476,45 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Built-in matmul kernel of triton + float16 - epars_fp16 = (epars * (2**12)).to(tl.float16) - emars_fp16 = emars_sub.to(tl.float16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) else: - # Built-in matmul kernel of triton + float32 - nmars = tl.dot(epars, emars_sub) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Built-in matmul kernel of triton + float16 + epars_fp16 = (epars * (2**4)).to(tl.float16) + emars_fp16 = emars_sub.to(tl.float16) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**4) + else: + # Built-in matmul kernel of triton + float32 + nmars = tl.dot(epars, emars_sub) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -480,6 +526,10 @@ def 
_fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -491,7 +541,8 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -541,22 +592,45 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Simulated matmul kernel + float16 - epars = (epars * (2**4)).to(tl.float16) - emars_sub = emars_sub.to(tl.float16) - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) else: - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Simulated matmul kernel + float16 + epars = (epars * (2**4)).to(tl.float16) + emars_sub = emars_sub.to(tl.float16) + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) + else: + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max) + 1e-24) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -568,6 +642,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -579,7 +657,8 @@ def 
_fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -629,16 +708,39 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[:,None]) - emars_max = tl.max(emars, axis = 1) - emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + tl.trans(emars)[None,:,:], axis = 1) - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + acc = tl.maximum(acc, nmars) - acc = tl.where(emars_max[None,:] > acc, - tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], - tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc - ) + else: + + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 1) + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 1) + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp((emars - emars_max[:,None]) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + + acc = tl.where(emars_max[None,:] > acc, + tl.log(nmars + tl.exp(acc - emars_max[None,:]) + 1e-24) + emars_max[None,:], + tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -650,6 +752,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[None,:] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -659,7 +765,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, force_use_fp16: bool = False, - force_use_fp32: bool = False) -> None: + force_use_fp32: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers with the block-sparse processing kernel. 
@@ -680,6 +786,10 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 128) if base_size >= 64: @@ -751,7 +861,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8: @@ -772,8 +884,11 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) + else: self._fw_triton_block_sparse_csmm2_kernel[grid]( node_mars, @@ -792,7 +907,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -802,7 +919,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten @FastJITFunction def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, local_ids, batch_size, partial_eval: tl.constexpr, num_edges: tl.constexpr, - BLOCK_B: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(axis = 0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(axis = 1) # ID of size-`BLOCK_SIZE_M` nodes @@ -830,9 +947,15 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, offs_batch[None,:] emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [num_edges, BLOCK_B] - # Compute max and subtract - emars_max = tl.max(emars, axis = 0) - emars = tl.exp(emars - emars_max[None,:]) + # Compute max and subtract (only when using LL or GeneralLL propagation method) + if propagation_alg_id == 0: + emars_max = tl.max(emars, axis = 0) + emars = tl.exp(emars - emars_max[None,:]) + + if propagation_alg_id == 2: + emars_max = tl.max(emars, axis = 0) + emars = tl.exp((emars - emars_max[None,:]) * alpha) + emars_max *= alpha # Initialize pointers to `node_mars` off_nids = tl.load(nids + nblock_id) @@ -844,7 +967,16 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, for i in range(0, BLOCK_SIZE_M): epars = tl.load(epars_ptr) - nmars = tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max + if propagation_alg_id == 0: + nmars = tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max + + if propagation_alg_id == 1: + nmars = tl.max(emars + tl.log(epars)[:,None], axis = 0) + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + nmars = (tl.log(tl.sum(emars * epars[:,None], axis = 0)) + emars_max) * (1.0 / alpha) tl.store(nmars_ptr, nmars, mask = mask_batch) @@ -857,10 +989,9 @@ def _fw_triton_sparse_kernel(node_mars, element_mars, params, 
nids, cids, pids, @staticmethod # @triton.jit @FastJITFunction - def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, - local_ids, batch_size, num_nodes, pid_m_offset, partial_eval: tl.constexpr, num_edges: tl.constexpr, - BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr): + def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, pids, local_ids, batch_size, + num_nodes, pid_m_offset, partial_eval: tl.constexpr, num_edges: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(axis = 0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(axis = 1) + pid_m_offset # ID of size-`TILE_SIZE_M` nodes @@ -892,10 +1023,27 @@ def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, offs_batch[None,None,:] # [TILE_SIZE_M, num_edges, BLOCK_B] emars = tl.load(emars_ptr, mask = (mask_m[:,None,None] & mask_batch[None,None,:]), other = 0.0) # [TILE_SIZE_M, num_edges, BLOCK_B] - # Compute max and subtract - emars_max = tl.max(emars, axis = 1) - emars = tl.exp(emars - emars_max[:,None,:]) - nmars = tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max + # Compute max and subtract (only when using LL or GeneralLL propagation method) + if propagation_alg_id == 0: + emars_max = tl.max(emars, axis = 1) + emars = tl.exp(emars - emars_max[:,None,:]) + + if propagation_alg_id == 2: + emars_max = tl.max(emars, axis = 1) + emars = tl.exp((emars - emars_max[:,None,:]) * alpha) + emars_max *= alpha + + # Compute sum node marginals + if propagation_alg_id == 0: + nmars = tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max + + if propagation_alg_id == 1: + nmars = tl.max(emars + tl.log(epars)[:,:,None], axis = 1) + + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + nmars = (tl.log(tl.sum(emars * epars[:,:,None], axis = 1)) + emars_max) * (1.0 / alpha) # Initialize pointers to `node_mars` off_nids = tl.load(nids + nblock_ids) # [TILE_SIZE_M] @@ -908,7 +1056,7 @@ def _fw_triton_large_sparse_kernel(node_mars, element_mars, params, nids, cids, def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1) -> None: + partition_id: int = -1, propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers with the sparse processing kernel. @@ -927,7 +1075,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) - # assert num_edges <= 16384, "The sparse forward kernel only support nodes with # edges smaller than 16384." 
+ # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) if triton.cdiv(layer_n_nodes, self.block_size) <= 2048: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -949,7 +1099,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, partial_eval = partial_eval, num_edges = num_edges, BLOCK_B = BLOCK_B, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -977,7 +1129,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, num_edges = num_edges, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: for pid_m_start in range(0, grid[1], 32768): @@ -1000,7 +1154,9 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, num_edges = num_edges, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -1009,7 +1165,7 @@ def _forward_sparse(self, node_mars: torch.Tensor, element_mars: torch.Tensor, @torch.compile def _forward_pytorch_kernel(node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, - local_ids: torch.Tensor): + local_ids: torch.Tensor, propagation_alg_id: int, alpha: float = 0.0): if local_ids is not None: nids = nids[local_ids] @@ -1025,18 +1181,33 @@ def _forward_pytorch_kernel(node_mars: torch.Tensor, element_mars: torch.Tensor, torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges) ch_mars = element_mars[cids] - maxval = ch_mars.max(dim = 1, keepdim = True).values - node_mars[nids] = (((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)).sum( - dim = 1).clamp(min = 1e-10)).log() + maxval.squeeze(1) + + if propagation_alg_id == 0: + maxval = ch_mars.max(dim = 1, keepdim = True).values + node_mars[nids] = (((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)).sum( + dim = 1).clamp(min = 1e-10)).log() + maxval.squeeze(1) + + elif propagation_alg_id == 1: + node_mars[nids] = (ch_mars + params[pids].log().unsqueeze(-1)).max(dim = 1).values + + elif propagation_alg_id == 2: + maxval = ch_mars.max(dim = 1, keepdim = True).values + node_mars[nids] = ((((ch_mars - maxval).exp() * params[pids].unsqueeze(-1)) ** alpha).sum( + dim = 1).clamp(min = 1e-10)).log() * (1.0 / alpha) + maxval.squeeze(1) return None def _forward_pytorch(node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, - local_ids: torch.Tensor): + local_ids: torch.Tensor, propagation_alg: str = "LL", **kwargs): + + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) self._forward_pytorch_kernel( - node_mars, element_mars, params, nids, cids, pids, local_ids + node_mars, element_mars, params, nids, cids, pids, local_ids, + propagation_alg_id = propagation_alg_id, **propagation_alg_kwargs ) def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, @@ -1048,7 +1219,10 @@ def _backward(self, node_flows: 
torch.Tensor, element_flows: torch.Tensor, parpids: Optional[torch.Tensor] = None, cs_block_size: int = 0, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, + propagation_alg: str = "LL", + logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers. @@ -1086,20 +1260,24 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, elif num_edges <= 32768: mode = self.BLOCK_SPARSE else: - mode = self.SPARSE + mode = self.BLOCK_SPARSE if mode == self.BLOCK_SPARSE: self._backward_block_sparse( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) elif mode == self.SPARSE: self._backward_sparse( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, chids, parids, parpids, cs_block_size, local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) elif mode == self.PYTORCH: @@ -1108,7 +1286,9 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, self._backward_pytorch( node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, - chids, parids, parpids, cs_block_size + chids, parids, parpids, cs_block_size, + propagation_alg = propagation_alg, + negate_pflows = negate_pflows, **kwargs ) else: raise ValueError(f"Not supported mode `{mode}`.") @@ -1119,7 +1299,7 @@ def _backward(self, node_flows: torch.Tensor, element_flows: torch.Tensor, # @triton.jit @FastJITFunction def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_size: tl.constexpr, partial_eval: tl.constexpr, - BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` examples pid_m = tl.program_id(1) # ID of size-`BLOCK_M` nodes @@ -1144,7 +1324,16 @@ def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_ nmars = tl.load(node_mars + offs_nmfs, mask = mask_batch[None,:]) nflows = tl.load(node_flows + offs_nmfs, mask = mask_batch[None,:]) - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + if propagation_alg_id == 0: + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars, -float("inf")) + + if propagation_alg_id == 1: + uflows = nflows + + if propagation_alg_id == 2: + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = mask_batch[None,:]) @@ -1152,7 +1341,7 @@ def _bk_triton_modify_flow_kernel(node_flows, node_mars, local_ids, nids, batch_ # @triton.jit @FastJITFunction def _bk_triton_large_modify_flow_kernel(node_flows, node_mars, local_ids, nids, num_nodes, batch_size: tl.constexpr, partial_eval: 
tl.constexpr, - BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + BLOCK_B: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` examples pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1179,12 +1368,22 @@ def _bk_triton_large_modify_flow_kernel(node_flows, node_mars, local_ids, nids, nmars = tl.load(node_mars + offs_nmfs, mask = (mask_m[:,None] & mask_batch[None,:])) nflows = tl.load(node_flows + offs_nmfs, mask = (mask_m[:,None] & mask_batch[None,:])) - uflows = tl.where(nmars != -float("inf"), tl.log(nflows) - nmars, -float("inf")) + if propagation_alg_id == 0: + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars, -float("inf")) + + if propagation_alg_id == 1: + uflows = nflows + + if propagation_alg_id == 2: + lflows = tl.log(nflows) + uflows = tl.where(nmars != -float("inf"), lflows - nmars * alpha, -float("inf")) tl.store(node_flows + offs_nmfs, uflows, mask = (mask_m[:,None] & mask_batch[None,:])) def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tensor, - nids: torch.Tensor, local_ids: Optional[torch.Tensor] = None): + nids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, + propagation_alg: str = "LL", **kwargs): """ Replace `node_flows[nids]` with `node_flows[nids].log() - node_mars[nids]` """ @@ -1194,6 +1393,10 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + if triton.cdiv(layer_n_nodes, self.block_size) <= 4096: if BATCH_SIZE_NP2 >= 64 and self.block_size >= 64: @@ -1217,7 +1420,9 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens partial_eval = partial_eval, BLOCK_B = BLOCK_B, BLOCK_M = BLOCK_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -1240,7 +1445,9 @@ def _bk_triton_modify_flow(self, node_flows: torch.Tensor, node_mars: torch.Tens partial_eval = partial_eval, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, - BLOCK_SIZE_M = BLOCK_SIZE_M + BLOCK_SIZE_M = BLOCK_SIZE_M, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -1251,7 +1458,8 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers with block-sparse processing kernel. @@ -1273,7 +1481,9 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. 
node_flows, element_flows, params, node_mars, element_mars, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, - partition_id = partition_id, allow_modify_flows = allow_modify_flows + partition_id = partition_id, allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) # Flows w.r.t. parameters @@ -1281,7 +1491,10 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. self._backward_block_sparse_par_flows( node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) return None @@ -1292,9 +1505,10 @@ def _backward_block_sparse(self, node_flows: torch.Tensor, element_flows: torch. def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids_start, parids_increment, parpids_start, parpids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, - allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, - K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr): + allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1334,40 +1548,88 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele parids_inc_ptr = parids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid + # Initialize pointers to `element_mars` (only when using MPE propagation method) + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + + if propagation_alg_id == 2: + emars *= alpha + # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + if logspace_flows: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - log_n_fdm_max = tl.max(log_n_fdm, axis = 0)[None,:] - n_fdm_sub = tl.where(log_n_fdm_max != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max), 0.0) - - if TL_DOT == 1: - partial_flows = tl.dot(epars, n_fdm_sub) + elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, 
TILE_SIZE_M] + + eflows = tl.sum(tl.where(tl.abs(elpars[:,:,None] + emars[None,:,:] - nmars[:,None,:]) < 1e-6, nflows[:,None,:], 0.0), axis = 0) + + if logspace_flows: + # logaddexp + diff = acc - eflows + acc = tl.where( + diff == 0, + acc + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + acc + tlmath.log1p(tl.exp(-diff)), + eflows + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + acc += eflows else: - partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) - acc = tl.where(log_n_fdm_max == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc - ) - ) - # neginf_flag = (log_n_fdm_max == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max)) + log_n_fdm_max, - # tl.log(tl.exp(log_n_fdm_max - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] + + if logspace_flows: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars * alpha) + else: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 0)[None,:] + n_fdm_sub = tl.where(log_n_fdm_max != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max), 0.0) + + if TL_DOT == 1: + partial_flows = tl.dot(epars, n_fdm_sub) + else: + partial_flows = tl.sum(epars[:,:,None] * n_fdm_sub[None,:,:], axis = 1) + + if logspace_flows: + partial_flows_max = emars + log_n_fdm_max + acc = tl.where(log_n_fdm_max == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) + ) + else: + acc += partial_flows * tl.exp(emars + log_n_fdm_max) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1381,16 +1643,9 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele nflows_ptr += parids_inc[:,None] * batch_size parids_inc_ptr += ptr_inc_step - # Initialize pointers to `element_mars` - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - - eflows = tl.exp(acc + emars) - # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - tl.store(element_flows + offs_elemfs, eflows, mask = mask_batch[None,:]) + tl.store(element_flows + offs_elemfs, acc, mask = mask_batch[None,:]) @staticmethod # @triton.jit @@ -1398,9 +1653,10 @@ def _bk_triton_block_sparse_ele_kernel(node_flows, element_flows, node_mars, ele def 
_bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids_start, parids_increment, parpids_start, parpids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, ptr_inc_step: tl.constexpr, - allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, - K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr): + allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, TILE_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, TL_DOT: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1440,37 +1696,87 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar parids_inc_ptr = parids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid parpids_inc_ptr = parpids_increment + eleblock_id * (K_NUM_TILES * ptr_inc_step) + offs_edge_gid + # Initialize pointers to `element_mars` + off_eleids = tl.load(chids + eleblock_id) + emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] + emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_M, BLOCK_B] + + if propagation_alg_id == 2: + emars *= alpha + # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + if logspace_flows: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) - float("inf") + else: + acc = tl.zeros([TILE_SIZE_M, BLOCK_B], dtype = tl.float32) for k in range(0, K_NUM_TILES): epars = tl.load(epars_ptr) # [TILE_SIZE_M, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - - log_n_fdm_max = tl.max(log_n_fdm, axis = 1) - n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + elpars = tl.log(tl.trans(epars)) # [TILE_SIZE_K, TILE_SIZE_M] + + eflows = tl.sum(tl.where(tl.abs(elpars[None,:,:] + tl.trans(emars)[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 1) + eflows = tl.trans(eflows) + + if logspace_flows: + # logaddexp + diff = acc - eflows + acc = tl.where( + diff == 0, + acc + 0.69314718055994530942, # log(2) + tl.where( + diff > 0, + acc + tlmath.log1p(tl.exp(-diff)), + eflows + tlmath.log1p(tl.exp(diff)) + ) + ) + else: + # sum + acc += eflows - partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) + else: - acc = tl.where(log_n_fdm_max[None,:] == acc, - acc + 0.69314718056, # log(2) - tl.where(log_n_fdm_max[None,:] > acc, - tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc - ) - ) - # neginf_flag = (log_n_fdm_max[None,:] == -float("inf")) & (acc == -float("inf")) - # acc = tl.where(log_n_fdm_max[None,:] > acc, - # tl.log(partial_flows + tl.exp(acc - log_n_fdm_max[None,:])) + log_n_fdm_max[None,:], - # 
tl.log(tl.exp(log_n_fdm_max[None,:] - acc) * partial_flows + 1.0) + acc - # ) - # acc = tl.where(neginf_flag, -float("inf"), acc) + if propagation_alg_id == 2: + epars = tl.exp(tl.log(epars) * alpha) + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None]) # [BLOCK_B, TILE_SIZE_K] + + if logspace_flows: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars * alpha) + else: + if propagation_alg_id == 0: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + if propagation_alg_id == 2: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars * alpha) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 1) + n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + + partial_flows = tl.sum(epars[:,:,None] * tl.trans(n_fdm_sub)[None,:,:], axis = 1) + + if logspace_flows: + partial_flows_max = emars + log_n_fdm_max[None,:] + acc = tl.where(log_n_fdm_max[None,:] == -float("inf"), + acc, + tl.where(partial_flows_max > acc, + tl.log(partial_flows + tl.exp(acc - partial_flows_max) + 1e-24) + partial_flows_max, + tl.log(tl.exp(partial_flows_max - acc) * partial_flows + 1.0) + acc + ) + ) + else: + acc += partial_flows * tl.exp(emars + log_n_fdm_max[None,:]) # Increment `epars_ptr` parpids_inc = tl.load(parpids_inc_ptr) @@ -1484,22 +1790,16 @@ def _bk_triton_block_sparse_ele_csmm2_kernel(node_flows, element_flows, node_mar nflows_ptr += parids_inc[None,:] * batch_size parids_inc_ptr += ptr_inc_step - # Initialize pointers to `element_mars` - off_eleids = tl.load(chids + eleblock_id) - emars_ptr = element_mars + (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - emars = tl.load(emars_ptr, mask = mask_batch[None,:]) # [TILE_SIZE_K, BLOCK_B] - - eflows = tl.exp(acc + emars) - # Write back offs_elemfs = (off_eleids + offs_ele[:,None]) * batch_size + offs_batch[None,:] - tl.store(element_flows + offs_elemfs, eflows, mask = mask_batch[None,:]) + tl.store(element_flows + offs_elemfs, acc, mask = mask_batch[None,:]) def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, + propagation_alg: str = "LL", logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." 
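Note: the `logspace_flows` updates introduced in the kernels above accumulate log-flows with a branch-wise, numerically stable logaddexp. A minimal PyTorch sketch of that update (illustrative only, not part of the patch):

    import torch

    def stable_logaddexp(acc: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        # Equal operands reduce to acc + log(2); otherwise anchor on the larger
        # operand so the exp() never overflows. As in the kernels, the case
        # acc == new == -inf is not special-cased.
        diff = acc - new
        return torch.where(
            diff == 0,
            acc + 0.69314718055994530942,  # log(2)
            torch.where(
                diff > 0,
                acc + torch.log1p(torch.exp(-diff)),
                new + torch.log1p(torch.exp(diff)),
            ),
        )

    # The kernels' log-space accumulation is equivalent to: acc = stable_logaddexp(acc, eflows)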
@@ -1509,6 +1809,10 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo batch_size = node_flows.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 64) if base_size >= 64: @@ -1590,6 +1894,7 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo partial_eval = partial_eval, ptr_inc_step = ptr_inc_step, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_K = TILE_SIZE_K, K_NUM_TILES = K_NUM_TILES, @@ -1598,7 +1903,9 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, num_warps = 2, # TODO: test for different devices - num_stages = 1 + num_stages = 1, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: self._bk_triton_block_sparse_ele_csmm2_kernel[grid]( @@ -1617,6 +1924,7 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo partial_eval = partial_eval, ptr_inc_step = ptr_inc_step, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_K = TILE_SIZE_K, K_NUM_TILES = K_NUM_TILES, @@ -1625,7 +1933,9 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo BLOCK_SIZE_K = BLOCK_SIZE_K, TL_DOT = TL_DOT, num_warps = 2, # TODO: test for different devices - num_stages = 1 + num_stages = 1, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -1635,8 +1945,10 @@ def _backward_block_sparse_ele_flows(self, node_flows: torch.Tensor, element_flo @FastJITFunction def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr): + logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, + TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, negate_pflows: tl.constexpr, + alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1662,30 +1974,57 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para nmars_ptr = node_mars + (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] nflows_ptr = node_flows + (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] + # Initialize `params` (only when using MPE propagation method) + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32) - + for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] - if allow_modify_flows == 1: - log_n_fdm = 
tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - log_n_fdm_max = tl.max(log_n_fdm, axis = 0) - n_fdm_sub = tl.where(log_n_fdm_max[None,:] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[None,:]), 0.0) + cond = tl.abs(elpars[:,None,:] + emars[None,:,:] - nmars[:,:,None]) < 1e-6 + if logspace_flows: + acc += tl.sum(tl.where(cond, tl.exp(nflows[:,:,None]), 0.0), axis = 1) + else: + acc += tl.sum(tl.where(cond, nflows[:,:,None], 0.0), axis = 1) - scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) - - if TL_DOT == 1: - partial_flows = tl.dot(n_fdm_sub, scaled_emars) else: - partial_flows = tl.sum(n_fdm_sub[:,:,None] * scaled_emars[None,:,:], axis = 1) - acc += partial_flows + nmars = tl.load(nmars_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[None,:], other = -float("inf")) # [TILE_SIZE_M, TILE_SIZE_B] + + if propagation_alg_id == 2: + log_n_fdm += (alpha - 1.0) * nmars + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[None,:], other = 0.0) # [TILE_SIZE_M, TILE_SIZE_B] + + if logspace_flows: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + else: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) + + log_n_fdm_max = tl.max(log_n_fdm, axis = 0) + n_fdm_sub = tl.where(log_n_fdm_max[None,:] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[None,:]), 0.0) + + scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + + if TL_DOT == 1: + partial_flows = tl.dot(n_fdm_sub, scaled_emars) + else: + partial_flows = tl.sum(n_fdm_sub[:,:,None] * scaled_emars[None,:,:], axis = 1) + + acc += partial_flows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += TILE_SIZE_B @@ -1697,25 +2036,34 @@ def _bk_triton_block_sparse_par_kernel(node_flows, node_mars, element_mars, para offs_batch += TILE_SIZE_B mask_batch = offs_batch < batch_size - # Initialize `params` - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - epars = tl.load(params + epars_offsets) + # Initialize `params` (only when NOT using MPE propagation method) + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) - pflows = acc * epars + if propagation_alg_id != 1: + pflows = acc * epars + else: + pflows = acc parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - tl.atomic_add(param_flows + eparflows_offsets, pflows) + if negate_pflows: + tl.atomic_add(param_flows + eparflows_offsets, -1.0 * pflows) + else: + tl.atomic_add(param_flows + eparflows_offsets, pflows) @staticmethod # @triton.jit @FastJITFunction def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, batch_size: tl.constexpr, num_edges: tl.constexpr, allow_modify_flows: tl.constexpr, - 
TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, TILE_SIZE_K: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, TL_DOT: tl.constexpr): + logspace_flows: tl.constexpr, TILE_SIZE_B: tl.constexpr, B_NUM_TILES: tl.constexpr, + TILE_SIZE_K: tl.constexpr, TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + TL_DOT: tl.constexpr, propagation_alg_id: tl.constexpr, negate_pflows: tl.constexpr, + alpha = 0.0): pid_k = tl.program_id(0) # ID of size-`TILE_SIZE_K` edges pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -1741,27 +2089,51 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, node_mars, element_mars nmars_ptr = node_mars + (off_nids + offs_node[None,:]) * batch_size + offs_batch[:,None] nflows_ptr = node_flows + (off_nids + offs_node[None,:]) * batch_size + offs_batch[:,None] + # Initialize `params` (only when using MPE propagation method) + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([TILE_SIZE_M, TILE_SIZE_K], dtype = tl.float32) for b in range(0, B_NUM_TILES): emars = tl.load(emars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_K] + nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] - else: + if propagation_alg_id == 1: nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - nmars = tl.load(nmars_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - log_n_fdm_max = tl.max(log_n_fdm, axis = 1) - n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + if logspace_flows: + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, tl.exp(nflows[:,:,None]), 0.0), axis = 0) + else: + acc += tl.sum(tl.where(tl.abs(elpars[None,:,:] + emars[:,None,:] - nmars[:,:,None]) < 1e-6, nflows[:,:,None], 0.0), axis = 0) + + else: + + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch[:,None], other = -float("inf")) # [TILE_SIZE_B, TILE_SIZE_M] - scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + if propagation_alg_id == 2: + log_n_fdm += (alpha - 1.0) * nmars + else: + nflows = tl.load(nflows_ptr, mask = mask_batch[:,None], other = 0.0) # [TILE_SIZE_B, TILE_SIZE_M] - partial_flows = tl.sum(tl.trans(n_fdm_sub)[:,:,None] * scaled_emars[None,:,:], axis = 1) + if logspace_flows: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), nflows - nmars) + else: + log_n_fdm = tl.where(nmars == -float("inf"), -float("inf"), tl.log(nflows) - nmars) - acc += partial_flows + log_n_fdm_max = tl.max(log_n_fdm, axis = 1) + n_fdm_sub = tl.where(log_n_fdm_max[:,None] != -float("inf"), tl.exp(log_n_fdm - log_n_fdm_max[:,None]), 0.0) + + scaled_emars = tl.exp(emars + log_n_fdm_max[:,None]) + + partial_flows = tl.sum(tl.trans(n_fdm_sub)[:,:,None] * scaled_emars[None,:,:], axis = 1) + + acc += partial_flows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += TILE_SIZE_B @@ -1773,22 +2145,30 @@ def _bk_triton_block_sparse_par_csmm2_kernel(node_flows, 
node_mars, element_mars offs_batch += TILE_SIZE_B mask_batch = offs_batch < batch_size - # Initialize `params` - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - epars = tl.load(params + epars_offsets) + # Initialize `params` (only when NOT using MPE propagation method) + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_offsets = offs_node[:,None] + par_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] + epars = tl.load(params + epars_offsets) - pflows = acc * epars + if propagation_alg_id != 1: + pflows = acc * epars + else: + pflows = acc parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) eparflows_offsets = offs_node[:,None] + parflow_start[None,:] # [TILE_SIZE_M, TILE_SIZE_K] - tl.atomic_add(param_flows + eparflows_offsets, pflows) + if negate_pflows: + tl.atomic_add(param_flows + eparflows_offsets, -1.0 * pflows) + else: + tl.atomic_add(param_flows + eparflows_offsets, pflows) def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -1812,6 +2192,10 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2) if base_size >= 64: @@ -1821,6 +2205,13 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor TILE_SIZE_B = min(2048 // remainder, base_size * remainder, BATCH_SIZE_NP2) TILE_SIZE_M = min(2048 // TILE_SIZE_B, self.block_size) TILE_SIZE_K = min(2048 // TILE_SIZE_B, num_edges) + + if propagation_alg_id == 1: + # The kernel will stall if the tile sizes are too large + TILE_SIZE_M = min(TILE_SIZE_M, 16) + TILE_SIZE_K = min(TILE_SIZE_K, 16) + TILE_SIZE_B = min(TILE_SIZE_B, 16) + B_NUM_TILES = batch_size // TILE_SIZE_B allow_modify_flows = 1 if allow_modify_flows else 0 @@ -1850,12 +2241,16 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor batch_size = batch_size, num_edges = num_edges, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, TILE_SIZE_B = TILE_SIZE_B, B_NUM_TILES = B_NUM_TILES, TILE_SIZE_K = TILE_SIZE_K, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = self.block_size, - TL_DOT = TL_DOT + TL_DOT = TL_DOT, + propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, + **propagation_alg_kwargs ) else: self._bk_triton_block_sparse_par_csmm2_kernel[grid]( @@ -1870,22 +2265,30 @@ def _backward_block_sparse_par_flows(self, node_flows: torch.Tensor, params: tor pfids = pfids, batch_size = batch_size, num_edges = num_edges, - allow_modify_flows = allow_modify_flows, + allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, 
TILE_SIZE_B = TILE_SIZE_B, B_NUM_TILES = B_NUM_TILES, TILE_SIZE_K = TILE_SIZE_K, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = self.block_size, - TL_DOT = TL_DOT + TL_DOT = TL_DOT, + propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, + **propagation_alg_kwargs ) + return None + def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: Optional[torch.Tensor], cids: Optional[torch.Tensor], pids: Optional[torch.Tensor], pfids: Optional[torch.Tensor], chids: Optional[torch.Tensor], parids: Optional[torch.Tensor], parpids: Optional[torch.Tensor], cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - partition_id: int = -1, allow_modify_flows: bool = False) -> None: + partition_id: int = -1, allow_modify_flows: bool = False, + propagation_alg: str = "LL", logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs) -> None: """ Back pass of sum layers with sparse processing kernel. @@ -1907,7 +2310,9 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor node_flows, element_flows, params, node_mars, element_mars, chids = chids, parids = parids, parpids = parpids, cs_block_size = cs_block_size, local_ids = local_ids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, **kwargs ) # Flows w.r.t. parameters @@ -1915,7 +2320,10 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor self._backward_sparse_par_flows( node_flows, params, node_mars, element_mars, param_flows, nids = nids, cids = cids, pids = pids, pfids = pfids, - allow_modify_flows = allow_modify_flows + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg, + logspace_flows = logspace_flows, + negate_pflows = negate_pflows, **kwargs ) return None @@ -1925,8 +2333,9 @@ def _backward_sparse(self, node_flows: torch.Tensor, element_flows: torch.Tensor @FastJITFunction def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids, parpids, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, - n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, logspace_flows: tl.constexpr, + BLOCK_B: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`BLOCK_M` nodes @@ -1973,10 +2382,41 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m epars = tl.load(epars_ptr) # [num_edges] emars = tl.load(emars_ptr, mask = mask_batch) # [BLOCK_B] - if allow_modify_flows == 1: - eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] + log_n_fdm), axis = 0) + if propagation_alg_id == 1: + if allow_modify_flows: + nflows = log_n_fdm + + lpars = tl.log(epars) + eflows = tl.sum(tl.where(tl.abs(lpars[:,None] + emars[None,:] - nmars) < 1e-6, nflows, 0.0), axis = 0) + else: - eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + lpars = tl.log(epars) + if propagation_alg_id == 2: + lpars *= alpha + epars = tl.exp(lpars) + + if allow_modify_flows == 1: + if propagation_alg_id == 0: + eflows = 
tl.sum(epars[:,None] * tl.exp(emars[None,:] + log_n_fdm), axis = 0) + + if propagation_alg_id == 2: + eflows = tl.sum(epars[:,None] * tl.exp(emars[None,:] * alpha + log_n_fdm), axis = 0) + else: + if logspace_flows: + if propagation_alg_id == 0: + elflows = nflows + lpars[:,None] + emars[None,:] - nmars + + if propagation_alg_id == 2: + elflows = nflows + lpars[:,None] + (emars[None,:] - nmars) * alpha + + elflows_max = tl.max(elflows, axis = 0) + eflows = tl.log(tl.sum(tl.exp(elflows - elflows_max[None,:]), axis = 0)) + elflows_max + else: + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,None] * tl.exp(emars[None,:] - nmars), axis = 0) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,None] * tl.exp((emars[None,:] - nmars) * alpha), axis = 0) tl.store(eflows_ptr, eflows, mask = mask_batch) @@ -1993,8 +2433,10 @@ def _bk_triton_sparse_ele_kernel(node_flows, element_flows, node_mars, element_m def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, element_mars, params, chids, parids, parpids, local_ids, num_eles, pid_m_offset, batch_size: tl.constexpr, partial_eval: tl.constexpr, - n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + n_edge_blocks: tl.constexpr, allow_modify_flows: tl.constexpr, + logspace_flows: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) + pid_m_offset # ID of size-`TILE_SIZE_M` nodes @@ -2042,10 +2484,41 @@ def _bk_triton_large_sparse_ele_kernel(node_flows, element_flows, node_mars, ele mask = (mask_m[:,None] & mask_batch[None,:])) # [TILE_SIZE_M, BLOCK_B] # Compute eflows - if allow_modify_flows == 1: - eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] + log_n_fdm), axis = 1) + if propagation_alg_id == 1: + if allow_modify_flows: + nflows = log_n_fdm + + lpars = tl.log(epars) + eflows = tl.sum(tl.where(tl.abs(lpars[:,:,None] + emars[:,None,:] - nmars) < 1e-6, nflows, 0.0), axis = 1) + else: - eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + lpars = tl.log(epars) + if propagation_alg_id == 2: + lpars *= alpha + epars = tl.exp(lpars) + + if allow_modify_flows == 1: + if propagation_alg_id == 0: + eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] + log_n_fdm), axis = 1) + + if propagation_alg_id == 2: + eflows = tl.sum(epars[:,:,None] * tl.exp(emars[:,None,:] * alpha + log_n_fdm), axis = 1) + else: + if logspace_flows: + if propagation_alg_id == 0: + elflows = nflows + lpars[:,:,None] + emars[:,None,:] - nmars + + if propagation_alg_id == 2: + elflows = nflows + lpars[:,:,None] + (emars[:,None,:] - nmars) * alpha + + elflows_max = tl.max(elflows, axis = 1) + eflows = tl.log(tl.sum(tl.exp(elflows - elflows_max[:,None,:]), axis = 1)) + elflows_max + else: + if propagation_alg_id == 0: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp(emars[:,None,:] - nmars), axis = 1) + + if propagation_alg_id == 2: + eflows = tl.sum(nflows * epars[:,:,None] * tl.exp((emars[:,None,:] - nmars) * alpha), axis = 1) tl.store(eflows_ptr, eflows, mask = (mask_m[:,None] & mask_batch[None,:])) @@ -2053,7 +2526,8 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to params: torch.Tensor, node_mars: torch.Tensor, element_mars:
torch.Tensor, chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, cs_block_size: int, local_ids: Optional[torch.Tensor] = None, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, **kwargs) -> None: assert params.dim() == 1, "Expecting a 1D `params`." @@ -2064,6 +2538,10 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to batch_size = node_flows.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384." if triton.cdiv(layer_n_nodes, cs_block_size) <= 32768: @@ -2089,9 +2567,12 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, BLOCK_M = BLOCK_M, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -2120,10 +2601,13 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) else: @@ -2148,10 +2632,13 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to partial_eval = 1 if local_ids is not None else 0, n_edge_blocks = n_edge_blocks, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_B = BLOCK_B, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = cs_block_size, - BLOCK_SIZE_K = self.block_size + BLOCK_SIZE_K = self.block_size, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None @@ -2161,8 +2648,9 @@ def _backward_sparse_ele_flows(self, node_flows: torch.Tensor, element_flows: to @FastJITFunction def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, pid_m_offset, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr): + logspace_flows: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, + TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, propagation_alg_id: tl.constexpr, + negate_pflows: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples pid_e = tl.program_id(1) # ID of size-`BLOCK_K` edges @@ -2188,6 +2676,12 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa nmars_ptr = node_mars + (off_nids + tile_id) * batch_size + offs_batch # [BLOCK_B] nflows_ptr = node_flows + (off_nids + tile_id) * batch_size + offs_batch # [BLOCK_B] + if propagation_alg_id == 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_ptr = params + par_start + tile_id + epars = 
tl.load(epars_ptr) # [BLOCK_K] + elpars = tl.log(epars) + # Inner loop acc = tl.zeros([BLOCK_K], dtype = tl.float32) @@ -2198,107 +2692,70 @@ def _bk_triton_sparse_par_kernel(node_flows, node_mars, element_mars, params, pa emars = tl.load(emars_ptr, mask = mask_batch[None,:], other = -float("inf")) # [BLOCK_K, BLOCK_B] - if allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] - pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) - else: + if propagation_alg_id == 1: nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] - pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) - - acc += pflows - # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` - emars_ptr += BLOCK_B - nmars_ptr += BLOCK_B - nflows_ptr += BLOCK_B + if logspace_flows: + acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, tl.exp(nflows[None,:]), 0.0), axis = 1) + else: + acc += tl.sum(tl.where(tl.abs(elpars[:,None] + emars - nmars[None,:]) < 1e-6, nflows[None,:], 0.0), axis = 1) - par_start = tl.load(pids + nblock_id * num_edges + offs_edge) - epars_ptr = params + par_start + tile_id - epars = tl.load(epars_ptr) # [BLOCK_K] - - parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) - eparflows_ptr = param_flows + parflow_start + tile_id - - curr_pflows = acc * epars - - tl.atomic_add(eparflows_ptr, curr_pflows) - - @staticmethod - # @triton.jit - @FastJITFunction - def _bk_triton_large_sparse_par_kernel(node_flows, node_mars, element_mars, params, param_flows, nids, cids, pids, pfids, - num_nodes, num_edges: tl.constexpr, batch_size: tl.constexpr, allow_modify_flows: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_B: tl.constexpr, - TILE_SIZE_B: tl.constexpr, B_NUM_BLOCKS: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): - - pid_b = tl.program_id(0) # ID of size-`BLOCK_B` samples - pid_e = tl.program_id(1) # ID of size-`BLOCK_K` edges - pid_m = tl.program_id(2) # ID of size-`TILE_SIZE_M` nodes - - offs_m = tl.arange(0, TILE_SIZE_M) + pid_m * TILE_SIZE_M - mask_m = offs_m < num_nodes - - # Get inferred node block id from `pid_m` - nblock_ids = offs_m // BLOCK_SIZE_M - tile_ids = offs_m % BLOCK_SIZE_M - - # Batch offsets and mask - offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B - mask_batch = offs_batch < batch_size - - # Initialize pointers to `element_mars` - offs_edge = tl.arange(0, BLOCK_K) + pid_e * BLOCK_K - edge_start = tl.load(cids + nblock_ids[:,None] * num_edges + offs_edge[None,:], mask = mask_m[:,None]) - emars_ptr = element_mars + \ - edge_start[:,:,None] * batch_size + \ - offs_batch[None,None,:] # [TILE_SIZE_M, BLOCK_K, BLOCK_B] - - # Initialize pointers to `node_flows` and `node_mars` - off_nids = tl.load(nids + nblock_ids, mask = mask_m) - nmars_ptr = node_mars + (off_nids + tile_ids)[:,None] * batch_size + offs_batch[None,:] # [TILE_SIZE_M, BLOCK_B] - nflows_ptr = node_flows + (off_nids + tile_ids)[:,None] * batch_size + offs_batch[None,:] # [TILE_SIZE_M, BLOCK_B] - - # Inner loop - acc = tl.zeros([TILE_SIZE_M, BLOCK_K], dtype = tl.float32) + 0.1 - - for b in range(0, B_NUM_BLOCKS): - # Batch offsets and mask - offs_batch = tl.arange(0, BLOCK_B) + pid_b * TILE_SIZE_B + b * BLOCK_B - mask_batch = offs_batch < batch_size - - emars = tl.load(emars_ptr, mask = (mask_m[:,None,None] & mask_batch[None,None,:]), other = -float("inf")) # [TILE_SIZE_M, BLOCK_K, BLOCK_B] - - if 
allow_modify_flows == 1: - log_n_fdm = tl.load(nflows_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = -float("inf")) # [TILE_SIZE_M, BLOCK_B] - pflows = tl.sum(tl.exp(emars + log_n_fdm[:,None,:]), axis = 2) else: - nmars = tl.load(nmars_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = 0.0) # [TILE_SIZE_M, BLOCK_B] - nflows = tl.load(nflows_ptr, mask = (mask_m[:,None] & mask_batch[None,:]), other = 0.0) # [TILE_SIZE_M, BLOCK_B] - pflows = tl.sum(nflows[:,None,:] * tl.exp(emars - nmars[:,None,:]), axis = 2) - acc += pflows + if allow_modify_flows == 1: + log_n_fdm = tl.load(nflows_ptr, mask = mask_batch, other = -float("inf")) # [BLOCK_B] + + if propagation_alg_id == 0: + pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:]), axis = 1) + + if propagation_alg_id == 2: + nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] + pflows = tl.sum(tl.exp(emars + log_n_fdm[None,:] + (alpha - 1.0) * nmars[None,:]), axis = 1) + else: + nmars = tl.load(nmars_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] + nflows = tl.load(nflows_ptr, mask = mask_batch, other = 0.0) # [BLOCK_B] + + if logspace_flows: + plflows = nflows[None,:] + emars - nmars[None,:] + plflows_max = tl.max(plflows, axis = 1) + pflows = tl.where(plflows_max != -float("inf"), + tl.exp(tl.log(tl.sum(tl.exp(plflows - plflows_max[:,None]), axis = 1)) + plflows_max), + 0.0 + ) + else: + pflows = tl.sum(nflows[None,:] * tl.exp(emars - nmars[None,:]), axis = 1) + + acc += pflows # Increment `emars_ptr`, `nmars_ptr`, and `nmars_ptr` emars_ptr += BLOCK_B nmars_ptr += BLOCK_B nflows_ptr += BLOCK_B - par_start = tl.load(pids + nblock_ids[:,None] * num_edges + offs_edge[None,:]) - epars_ptr = params + par_start + tile_ids[:,None] - epars = tl.load(epars_ptr, mask = mask_m[:,None]) # [TILE_SIZE_M, BLOCK_K] + if propagation_alg_id != 1: + par_start = tl.load(pids + nblock_id * num_edges + offs_edge) + epars_ptr = params + par_start + tile_id + epars = tl.load(epars_ptr) # [BLOCK_K] - parflow_start = tl.load(pfids + nblock_ids[:,None] * num_edges + offs_edge[None,:]) - eparflows_ptr = param_flows + parflow_start + tile_ids[:,None] + parflow_start = tl.load(pfids + nblock_id * num_edges + offs_edge) + eparflows_ptr = param_flows + parflow_start + tile_id - curr_pflows = acc * epars + if propagation_alg_id != 1: + curr_pflows = acc * epars + else: + curr_pflows = acc - tl.atomic_add(eparflows_ptr, curr_pflows, mask = mask_m[:,None]) + if negate_pflows: + tl.atomic_add(eparflows_ptr, -1.0 * curr_pflows) + else: + tl.atomic_add(eparflows_ptr, curr_pflows) def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, - allow_modify_flows: bool = False) -> None: + allow_modify_flows: bool = False, propagation_alg: str = "LL", + logspace_flows: bool = False, negate_pflows: bool = False, **kwargs) -> None: """ Backward pass of sum layers w.r.t. sum parameters with the block-sparse processing kernel. @@ -2322,7 +2779,9 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) - # assert num_edges <= 16384, "The sparse backward kernel only support nodes with # edges smaller than 16384." 
+ # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) if num_edges <= 1024: BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) @@ -2364,11 +2823,15 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten num_edges = num_edges, batch_size = batch_size, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_M = BLOCK_M, BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, TILE_SIZE_B = TILE_SIZE_B, - B_NUM_BLOCKS = B_NUM_BLOCKS + B_NUM_BLOCKS = B_NUM_BLOCKS, + propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, + **propagation_alg_kwargs ) else: @@ -2392,72 +2855,23 @@ def _backward_sparse_par_flows(self, node_flows: torch.Tensor, params: torch.Ten num_edges = num_edges, batch_size = batch_size, allow_modify_flows = allow_modify_flows, + logspace_flows = logspace_flows, BLOCK_M = BLOCK_M, BLOCK_K = BLOCK_K, BLOCK_B = BLOCK_B, TILE_SIZE_B = TILE_SIZE_B, - B_NUM_BLOCKS = B_NUM_BLOCKS + B_NUM_BLOCKS = B_NUM_BLOCKS, + propagation_alg_id = propagation_alg_id, + negate_pflows = negate_pflows, + **propagation_alg_kwargs ) - # else: - - # if num_edges <= 1024: - # BLOCK_B = max(min(2048 // num_edges, BATCH_SIZE_NP2), 1) - # BLOCK_K = num_edges - # TILE_SIZE_M = max(min(4096 // num_edges, triton.next_power_of_2(layer_n_nodes)), 1) - # else: - # BLOCK_B = min(512, BATCH_SIZE_NP2) - # BLOCK_K = min(2048 // BLOCK_B, num_edges) - # TILE_SIZE_M = max(min(2048 // num_edges, triton.next_power_of_2(layer_n_nodes)), 1) - # B_NUM_BLOCKS = triton.cdiv(batch_size, BLOCK_B) - # K_NUM_BLOCKS = triton.cdiv(num_edges, BLOCK_K) - - # # When a thread-block is allocated for too much work, the overhead - # # outweigh that incurred by `atomic_add`. Add more thread-blocks - # # for parallel processing in this case. - # if B_NUM_BLOCKS >= 4: - # TILE_SIZE_B = 4 * BLOCK_B - # B_NUM_BLOCKS = 4 - # else: - # TILE_SIZE_B = batch_size - # B_NUM_TILES = triton.cdiv(batch_size, TILE_SIZE_B) - - # allow_modify_flows = 1 if allow_modify_flows else 0 - - # grid = (B_NUM_TILES, K_NUM_BLOCKS, triton.cdiv(layer_n_nodes, TILE_SIZE_M)) - - # print(">>>G", grid, "in") - # # TODO: This kernel gets stuck for some input configurations. Fix it. - # if grid[0] == 2 and grid[1] == 1 and grid[2] == 308: - # import pdb; pdb.set_trace() - # self._bk_triton_large_sparse_par_kernel[grid]( - # node_flows = node_flows, - # node_mars = node_mars, - # element_mars = element_mars, - # params = params, - # param_flows = param_flows, - # nids = nids, - # cids = cids, - # pids = pids, - # pfids = pfids, - # num_nodes = layer_n_nodes, - # num_edges = num_edges, - # batch_size = batch_size, - # allow_modify_flows = allow_modify_flows, - # TILE_SIZE_M = TILE_SIZE_M, - # BLOCK_K = BLOCK_K, - # BLOCK_B = BLOCK_B, - # TILE_SIZE_B = TILE_SIZE_B, - # B_NUM_BLOCKS = B_NUM_BLOCKS, - # BLOCK_SIZE_M = self.block_size - # ) - # print(">>>G", grid, "out") - return None def _backward_pytorch(self, node_flows, element_flows, params, node_mars, element_mars, param_flows, nids, cids, pids, pfids, - chids, parids, parpids, cs_block_size): + chids, parids, parpids, cs_block_size, propagation_alg: str = "LL", + logspace_flows: bool = False, negate_pflows: bool = False): """ Back pass of sum layers with native pytorch. @@ -2473,18 +2887,21 @@ def _backward_pytorch(self, node_flows, element_flows, params, node_mars, `parpids`: [ng, c] """ + assert propagation_alg == "LL" + # Flows w.r.t. 
input elements (product nodes) if chids is not None: self._backward_pytorch_ele_kernel( node_flows, element_flows, params, node_mars, element_mars, - param_flows, chids, parids, parpids, cs_block_size + param_flows, chids, parids, parpids, cs_block_size, logspace_flows ) # Flows w.r.t. parameters if param_flows is not None and nids is not None: self._backward_pytorch_par_kernel( node_flows, params, node_mars, element_mars, param_flows, - nids, cids, pids, pfids, self.block_size + nids, cids, pids, pfids, self.block_size, logspace_flows, + negate_pflows ) @torch.compile @@ -2492,7 +2909,7 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: Optional[torch.Tensor], chids: torch.Tensor, parids: torch.Tensor, parpids: torch.Tensor, - cs_block_size: int): + cs_block_size: int, logspace_flows: bool): num_nblocks = chids.size(0) num_eblocks = parids.size(1) @@ -2506,15 +2923,20 @@ def _backward_pytorch_ele_kernel(self, node_flows: torch.Tensor, element_flows: num_nblocks * cs_block_size, num_eblocks * self.block_size ) - element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \ - (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1) + if logspace_flows: + element_flows[chids] = (node_flows[parids] + params[parpids].log().unsqueeze(-1) + \ + element_mars[chids].unsqueeze(1) - node_mars[parids]).logsumexp(dim = 1) + else: + element_flows[chids] = (node_flows[parids] * params[parpids].unsqueeze(-1) * \ + (element_mars[chids].unsqueeze(1) - node_mars[parids]).exp()).sum(dim = 1) return None @torch.compile def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.Tensor, node_mars: torch.Tensor, element_mars: torch.Tensor, param_flows: torch.Tensor, nids: torch.Tensor, - cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int): + cids: torch.Tensor, pids: torch.Tensor, pfids: torch.Tensor, ns_block_size: int, + logspace_flows: bool, negate_pflows: bool): num_nblocks = nids.size(0) num_edges = cids.size(1) @@ -2526,11 +2948,17 @@ def _backward_pytorch_par_kernel(self, node_flows: torch.Tensor, params: torch.T pfids = (pfids[:,None,:].repeat(1, self.block_size, 1) + \ torch.arange(0, self.block_size, device = cids.device)[None,:,None]).reshape(num_nblocks * self.block_size, num_edges) - parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) + if logspace_flows: + parflows = (node_flows[nids].exp().unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) + else: + parflows = (node_flows[nids].unsqueeze(1) * params[pids].unsqueeze(-1) * (element_mars[cids] - node_mars[nids].unsqueeze(1)).exp()).sum(dim = 2) for i in range(num_nblocks): sid, eid = ns_block_size * i, ns_block_size * (i + 1) - param_flows[pfids[sid:eid,:]] += parflows[sid:eid,:] + if negate_pflows: + param_flows[pfids[sid:eid,:]] -= parflows[sid:eid,:] + else: + param_flows[pfids[sid:eid,:]] += parflows[sid:eid,:] return None diff --git a/src/pyjuice/model/backend/__init__.py b/src/pyjuice/model/backend/__init__.py index 6580dde5..46a4cc63 100644 --- a/src/pyjuice/model/backend/__init__.py +++ b/src/pyjuice/model/backend/__init__.py @@ -1,3 +1,3 @@ from .parflow_fusing import compile_cum_par_flows_fn, compute_cum_par_flows, cum_par_flows_to_device -from .par_update import compile_par_update_fn, 
em_par_update, par_update_to_device +from .par_update import compile_par_update_fn, em_par_update, par_update_to_device, sgd_par_update from .normalize import normalize_parameters diff --git a/src/pyjuice/model/backend/par_update.py b/src/pyjuice/model/backend/par_update.py index ea6f8cb8..f9d4e803 100644 --- a/src/pyjuice/model/backend/par_update.py +++ b/src/pyjuice/model/backend/par_update.py @@ -209,9 +209,9 @@ def cum_pflow_kernel(cum_pflows, params, param_flows, nchs, par_start_ids, pflow @triton.jit -def par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, - global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, - BLOCK_SIZE: tl.constexpr): +def em_par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) @@ -253,6 +253,42 @@ def par_update_kernel(params, param_flows, cum_pflows, nchs, par_start_ids, pflo tl.store(params + offs_par, updated_param, mask = mask_pflow) +@triton.jit +def sgd_par_update_kernel(params, param_grads, par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params: tl.constexpr, BLOCK_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr): + + pid = tl.program_id(axis = 0) + + # Retrieve the constants + lr = tl.load(constexprs) + + offs_m = pid * BLOCK_ID + tl.arange(0, BLOCK_ID) + mask_m = offs_m < num_blocks + + offs_blk = tl.arange(0, BLOCK_SIZE) + + par_start = tl.load(par_start_ids + offs_m, mask = mask_m, other = 0) + pgrad_start = tl.load(pgrad_start_ids + offs_m, mask = mask_m, other = 0) + blk_size = tl.load(blk_sizes + offs_m, mask = mask_m, other = 0) + blk_interval = tl.load(blk_intervals + offs_m, mask = mask_m, other = 0) + global_nid = tl.load(global_nids + offs_m, mask = mask_m, other = 0) + + offs_pgrad = pgrad_start[:,None] + offs_blk[None,:] * blk_interval[:,None] + mask_pgrad = mask_m[:,None] & (offs_blk[None,:] < blk_size[:,None]) + pgrads = tl.load(param_grads + offs_pgrad, mask = mask_pgrad, other = 0) + + offs_par = par_start[:,None] + offs_blk[None,:] * blk_interval[:,None] + old_param = tl.load(params + offs_par, mask = mask_pgrad, other = 0) + + if keep_zero_params: + updated_param = tl.where(old_param < 1e-12, 0.0, tl.exp(tl.log(old_param) + lr * pgrads)) + else: + updated_param = tl.exp(tl.log(old_param) + lr * pgrads) + + tl.store(params + offs_par, updated_param, mask = mask_pgrad) + + def em_par_update(params: torch.Tensor, param_flows: torch.Tensor, par_update_kwargs: Sequence, step_size: float, pseudocount: float = 0.0, keep_zero_params: bool = True): @@ -280,7 +316,47 @@ def em_par_update(params: torch.Tensor, param_flows: torch.Tensor, par_update_kw global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE ) - par_update_kernel[grid]( + em_par_update_kernel[grid]( params, param_flows, cum_pflows, nchs, par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE ) + + return None + + +def sgd_par_update(params: torch.Tensor, param_grads: torch.Tensor, par_update_kwargs: Sequence, + lr: float, keep_zero_params: bool = True): + """ + Apply one-step SGD parameter update.
+ + :param params: the parameter tensor + :type params: torch.Tensor + + :param param_grads: gradients of the log-parameters + :type param_grads: torch.Tensor + + :param lr: learning rate + :type lr: float + + :param keep_zero_params: whether to freeze zero parameters + :type keep_zero_params: bool + """ + + par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, global_nids, nchs, cum_pflows, metadata = par_update_kwargs + + tot_num_nodes = metadata["tot_num_nodes"] + BLOCK_SIZE = metadata["BLOCK_SIZE"] + + num_blocks = par_start_ids.size(0) + BLOCK_ID = 2048 // BLOCK_SIZE + + grid = (triton.cdiv(num_blocks, BLOCK_ID),) + + constexprs = torch.tensor([lr]).to(params.device) + + sgd_par_update_kernel[grid]( + params, param_grads, par_start_ids, pgrad_start_ids, blk_sizes, blk_intervals, + global_nids, constexprs, num_blocks, keep_zero_params, BLOCK_ID, BLOCK_SIZE + ) + + return None diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 2fa826e3..9a1994bf 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -19,7 +19,7 @@ normalize_parameters -def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, **kwargs): +def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, propagation_alg, **kwargs): grad = grad.permute(1, 0) pc.backward( inputs = inputs, @@ -28,6 +28,7 @@ def _pc_model_backward_hook(grad, pc, inputs, record_cudagraph, apply_cudagraph, flows_memory = pc._optim_hyperparams["flows_memory"], record_cudagraph = record_cudagraph, apply_cudagraph = apply_cudagraph, + propagation_alg = propagation_alg, **kwargs ) @@ -101,6 +102,10 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5, # CudaGraph options self._recorded_cuda_graphs = dict() + # Mode for forward and backward pass + self.default_propagation_alg = "LL" # Could be "LL", "MPE", or "GeneralLL" + self.propagation_alg_kwargs = dict() + def to(self, device): super(TensorCircuit, self).to(device) @@ -115,10 +120,26 @@ def to(self, device): self.par_update_kwargs = par_update_to_device(self.par_update_kwargs, device) return self + + def set_propagation_alg(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + self.default_propagation_alg = "LL" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "MPE": + self.default_propagation_alg = "MPE" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "GeneralLL": + assert "alpha" in kwargs, "Argument `alpha` should be provided for the `GeneralLL` propagation algorithm." + self.default_propagation_alg = "GeneralLL" + self.propagation_alg_kwargs.clear() + self.propagation_alg_kwargs["alpha"] = kwargs["alpha"] + else: + raise NotImplementedError(f"Unknown propagation algorithm {propagation_alg}.") def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None, cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False, - apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, **kwargs): + apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: Optional[Union[str,Sequence[str]]] = None, **kwargs): """ Forward evaluation of the PC. 
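Usage sketch for the propagation-algorithm switches added above (`pc` and `x` are hypothetical — a compiled `TensorCircuit` and a data batch on the same device):

    pc.set_propagation_alg("GeneralLL", alpha = 0.9)  # becomes the default for later calls
    lls = pc(x)                                       # forward pass under GeneralLL with alpha = 0.9
    mpe_lls = pc(x, propagation_alg = "MPE")          # one-off override; the stored default is unchanged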
@@ -132,9 +153,14 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla B = inputs.size(0) if input_layer_fn is None: - assert inputs.dim() == 2 and inputs.size(1) == self.num_vars + assert inputs.dim() == 2 inputs = inputs.permute(1, 0) + + # Set propagation algorithm + if propagation_alg is None: + propagation_alg = self.default_propagation_alg + kwargs.update(self.propagation_alg_kwargs) ## Initialize buffers for forward pass ## @@ -165,7 +191,7 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla # Inner layers def _run_inner_layers(): - for layer_group in self.inner_layer_groups: + for layer_id, layer_group in enumerate(self.inner_layer_groups): if layer_group.is_prod(): # Prod layer layer_group(self.node_mars, self.element_mars) @@ -173,8 +199,10 @@ def _run_inner_layers(): elif layer_group.is_sum(): # Sum layer layer_group(self.node_mars, self.element_mars, self.params, - force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32) + force_use_fp16 = force_use_fp16, + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], + **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") @@ -226,6 +254,7 @@ def _run_inner_layers(): inputs = inputs, record_cudagraph = record_cudagraph, apply_cudagraph = apply_cudagraph, + propagation_alg = propagation_alg, **kwargs ) ) @@ -245,6 +274,9 @@ def backward(self, inputs: Optional[torch.Tensor] = None, record_cudagraph: bool = False, apply_cudagraph: bool = True, allow_modify_flows: bool = True, + propagation_alg: Union[str,Sequence[str]] = "LL", + logspace_flows: bool = False, + negate_pflows: bool = False, **kwargs): """ Backward evaluation of the PC that computes node flows as well as parameter flows. 
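With `logspace_flows = True`, `node_flows`/`element_flows` hold log-flows, which is why the hunks below initialize the buffers to `-inf` (log 0) and the root flows to `0` (log 1). For one child element `c` with parent sum nodes `n`, the LL flow recursion F(c) = sum_n F(n) * theta_{n,c} * exp(m_c - m_n) turns into a logsumexp; a reference sketch with illustrative names and shapes (not part of the patch):

    import torch

    def ele_flow_logspace(log_nflows, log_theta, emar, nmars):
        # log_nflows, nmars: [num_parents, B]; log_theta: [num_parents]; emar: [B]
        # log F(c) = logsumexp_n( log F(n) + log theta_{n,c} + m_c - m_n )
        return torch.logsumexp(log_nflows + log_theta[:, None] + emar[None, :] - nmars, dim = 0)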
@@ -269,21 +301,24 @@ def backward(self, inputs: Optional[torch.Tensor] = None, ## Initialize buffers for backward pass ## - self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0) - self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0) + self._init_buffer(name = "node_flows", shape = (self.num_nodes, B), set_value = 0.0 if not logspace_flows else -float("inf")) + self._init_buffer(name = "element_flows", shape = (self.num_elements, B), set_value = 0.0 if not logspace_flows else -float("inf")) # Set root node flows def _set_root_node_flows(): nonlocal ll_weights + nonlocal logspace_flows if ll_weights is None: - self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = 1.0 + root_flows = 1.0 if not logspace_flows else 0.0 + self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = root_flows else: if ll_weights.dim() == 1: ll_weights = ll_weights.unsqueeze(1) assert ll_weights.size(0) == self.num_root_nodes - self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = ll_weights + root_flows = ll_weights if not logspace_flows else ll_weights.log() + self.node_flows[self._root_node_range[0]:self._root_node_range[1],:] = root_flows _set_root_node_flows() @@ -308,7 +343,7 @@ def _run_inner_layers(): if layer_group.is_prod(): # Prod layer - layer_group.backward(self.node_flows, self.element_flows) + layer_group.backward(self.node_flows, self.element_flows, logspace_flows = logspace_flows) elif layer_group.is_sum(): # Sum layer @@ -319,12 +354,14 @@ def _run_inner_layers(): # Backward sum layer layer_group.backward(self.node_flows, self.element_flows, self.node_mars, self.element_mars, self.params, param_flows = self.param_flows if compute_param_flows else None, - allow_modify_flows = allow_modify_flows) + allow_modify_flows = allow_modify_flows, + propagation_alg = propagation_alg if isinstance(propagation_alg, str) else propagation_alg[layer_id], + logspace_flows = logspace_flows, negate_pflows = negate_pflows, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") - signature = (1, id(self.node_flows), id(self.element_flows), id(self.node_mars), id(self.element_mars), id(self.params), id(self.param_flows), B) + signature = (1, id(self.node_flows), id(self.element_flows), id(self.node_mars), id(self.element_mars), id(self.params), id(self.param_flows), B, allow_modify_flows, logspace_flows) if record_cudagraph and signature not in self._recorded_cuda_graphs: # Warmup s = torch.cuda.Stream() @@ -355,14 +392,14 @@ def _run_inner_layers(): # Compute backward pass for all input layers for idx, layer in enumerate(self.input_layer_group): if input_layer_fn is None: - layer.backward(inputs, self.node_flows, self.node_mars, **kwargs) + layer.backward(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) elif isinstance(input_layer_fn, str): assert hasattr(layer, input_layer_fn), f"Custom input function `{input_layer_fn}` not found for layer type {type(layer)}." 
- getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, **kwargs) + getattr(layer, input_layer_fn)(inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) elif isinstance(input_layer_fn, Callable): - input_layer_fn(layer, inputs, self.node_flows, self.node_mars, **kwargs) + input_layer_fn(layer, inputs, self.node_flows, self.node_mars, logspace_flows = logspace_flows, **kwargs) else: raise ValueError(f"Custom input function should be either a `str` or a `Callable`. Found {type(input_layer_fn)} instead.") @@ -378,6 +415,15 @@ def _run_inner_layers(): else: return None + def forward_ll(self, *args, **kwargs): + return self.forward(*args, propagation_alg = "LL", **kwargs) + + def forward_mpe(self, *args, **kwargs): + return self.forward(*args, propagation_alg = "MPE", **kwargs) + + def forward_general_ll(self, *args, alpha: float = 1.0, **kwargs): + return self.forward(*args, propagation_alg = "GeneralLL", alpha = alpha, **kwargs) + def mini_batch_em(self, step_size: float, pseudocount: float = 0.0, keep_zero_params: bool = False): """ Perform an EM parameter update step using the accumulated parameter flows. @@ -809,12 +855,12 @@ def dfs(ns: CircuitNodes): ns.chs[idx] = pass_prod_ns depth2nodes[cs_depth]["sum"].append(pass_sum_ns) - depth2nodes[depth]["prod"].append(pass_prod_ns) nodes2depth[pass_sum_ns] = cs_depth nodes2depth[pass_prod_ns] = depth depth2nodes[depth]["sum"].append(ns) + if ns.block_size > max_node_block_size: max_node_block_size = ns.block_size elif ns.is_prod(): @@ -834,6 +880,7 @@ def dfs(ns: CircuitNodes): assert pns2layer[id(cs)] == layer, "Disallowed circumstance: a product node requested by sum nodes at different layers." else: depth2nodes[layer]["prod"].append(cs) + pns2layer[id(cs)] = layer return depth2nodes, num_layers, max_node_block_size, max_ele_block_size diff --git a/src/pyjuice/nodes/nodes.py b/src/pyjuice/nodes/nodes.py index 00f31e01..b8741f79 100644 --- a/src/pyjuice/nodes/nodes.py +++ b/src/pyjuice/nodes/nodes.py @@ -226,6 +226,27 @@ def clear_hooks(ns): else: clear_hooks(self) + def __len__(self): + count = 0 + + def dfs(ns: CircuitNodes, visited: set = set()): + nonlocal count + + if ns in visited: + return + + visited.add(ns) + + # Recursively traverse children + if ns.is_sum() or ns.is_prod(): + for cs in ns.chs: + dfs(cs, visited = visited) + + count += 1 + + dfs(self) + return count + def __iter__(self): return node_iterator(self, self._reverse_iter) diff --git a/src/pyjuice/nodes/sum_nodes.py b/src/pyjuice/nodes/sum_nodes.py index eb64e4ab..3864f8ae 100644 --- a/src/pyjuice/nodes/sum_nodes.py +++ b/src/pyjuice/nodes/sum_nodes.py @@ -71,6 +71,10 @@ def num_edges(self): """ return self.edge_ids.size(1) * self.block_size * self.ch_block_size + @property + def num_ch_nodes(self): + return self.num_ch_node_blocks * self.ch_block_size + def duplicate(self, *args, tie_params: bool = False) -> SumNodes: """ Create a duplication of the current node with the same specification (i.e., number of nodes, block size).
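A quick sketch of the two node-level additions above, `CircuitNodes.__len__` and `SumNodes.num_ch_nodes` (the toy circuit is illustrative; the constructors follow the `inputs`/`multiply`/`summate` API used elsewhere in this patch):

    import pyjuice as juice
    from pyjuice.nodes.distributions import Categorical

    i0 = juice.inputs(0, num_node_blocks = 2, dist = Categorical(num_cats = 4))
    i1 = juice.inputs(1, num_node_blocks = 2, dist = Categorical(num_cats = 4))
    ns = juice.summate(juice.multiply(i0, i1), num_node_blocks = 1)

    print(len(ns))          # 4: the two input node vectors, the product, and the root sum
    print(ns.num_ch_nodes)  # = ns.num_ch_node_blocks * ns.ch_block_size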
@@ -334,8 +338,6 @@ def _construct_edges(self, edge_ids: Optional[Union[Tensor,Sequence[Tensor]]], r curr_edge_ids[1,:] += ch_gid_start edge_ids.append(curr_edge_ids) - ch_nid_start += self.chs[cs_id].num_node_blocks - edge_ids = torch.cat(edge_ids, dim = 1) if reorder: diff --git a/src/pyjuice/optim/optim.py b/src/pyjuice/optim/optim.py index 16258395..3c89cb19 100644 --- a/src/pyjuice/optim/optim.py +++ b/src/pyjuice/optim/optim.py @@ -8,10 +8,10 @@ class CircuitOptimizer(): - SUPPORTED_OPTIM_METHODS = ["EM"] + SUPPORTED_OPTIM_METHODS = ["EM", "Viterbi", "GeneralEM"] def __init__(self, pc: TensorCircuit, base_optimizer: Optional[Optimizer] = None, method: str = "EM", lr: float = 0.1, - pseudocount: float = 0.1): + pseudocount: float = 0.1, **kwargs): self.pc = pc diff --git a/src/pyjuice/queries/conditional.py b/src/pyjuice/queries/conditional.py index 43ae22f8..9b9a833a 100644 --- a/src/pyjuice/queries/conditional.py +++ b/src/pyjuice/queries/conditional.py @@ -114,7 +114,7 @@ def _categorical_forward(layer, inputs: torch.Tensor, node_mars: torch.Tensor, @triton.jit def _categorical_backward_kernel(cat_probs_ptr, node_flows_ptr, local_ids_ptr, rev_vars_mapping_ptr, vids_ptr, psids_ptr, node_nchs_ptr, params_ptr, sid, eid, num_target_nodes, batch_size: tl.constexpr, - num_cats: tl.constexpr, BLOCK_SIZE: tl.constexpr): + num_cats: tl.constexpr, partial_eval: tl.constexpr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis = 0) block_start = pid * BLOCK_SIZE @@ -123,7 +123,10 @@ def _categorical_backward_kernel(cat_probs_ptr, node_flows_ptr, local_ids_ptr, r # Get node offsets and batch offsets local_offsets = (offsets // batch_size) - local_node_offsets = tl.load(local_ids_ptr + local_offsets, mask = mask, other = 0) + if partial_eval == 1: + local_node_offsets = tl.load(local_ids_ptr + local_offsets, mask = mask, other = 0) + else: + local_node_offsets = local_offsets batch_offsets = (offsets % batch_size) global_node_offsets = local_node_offsets + sid @@ -182,20 +185,27 @@ def _categorical_backward(layer, inputs: torch.Tensor, node_flows: torch.Tensor, cat_probs = torch.zeros([num_target_vars * num_cats * batch_size], dtype = torch.float32, device = node_flows.device) - local_ids = layer.enable_partial_evaluation(bk_scopes = target_vars, return_ids = True).to(node_flows.device) - num_target_nodes = local_ids.size(0) + if len(target_vars) < num_vars: + local_ids = layer.enable_partial_evaluation(bk_scopes = target_vars, return_ids = True).to(node_flows.device) + num_target_nodes = local_ids.size(0) + partial_eval = 1 + else: + local_ids = None + num_target_nodes = eid - sid + partial_eval = 0 node_nchs = layer.metadata[layer.s_mids] grid = lambda meta: (triton.cdiv(num_target_nodes * batch_size, meta['BLOCK_SIZE']),) + _categorical_backward_kernel[grid]( cat_probs, node_flows, local_ids, rev_vars_mapping, layer.vids, layer.s_pids, node_nchs, layer.params, - sid, eid, num_target_nodes, batch_size, num_cats, BLOCK_SIZE = 512 + sid, eid, num_target_nodes, batch_size, num_cats, partial_eval = partial_eval, BLOCK_SIZE = 512 ) cat_probs = cat_probs.reshape(num_target_vars, num_cats, batch_size) - cat_probs /= cat_probs.sum(dim = 1, keepdim = True) + cat_probs /= (cat_probs.sum(dim = 1, keepdim = True) + 1e-12) cat_probs = cat_probs.permute(2, 0, 1) return cat_probs diff --git a/src/pyjuice/structures/hmm.py b/src/pyjuice/structures/hmm.py index d897d936..7351cac3 100644 --- a/src/pyjuice/structures/hmm.py +++ b/src/pyjuice/structures/hmm.py @@ -50,7 +50,7 @@ def HMM(seq_length: int, 
num_latents: int, num_emits: int, homogeneous: bool = T block_size = min(max_cdf_power_of_2(num_latents), 1024) num_node_blocks = num_latents // block_size - with juice.set_block_size(block_size = block_size): + with set_block_size(block_size = block_size): ns_input = inputs( seq_length - 1, num_node_blocks = num_node_blocks, diff --git a/src/pyjuice/structures/pd.py b/src/pyjuice/structures/pd.py index 541b672f..05ee7f6e 100644 --- a/src/pyjuice/structures/pd.py +++ b/src/pyjuice/structures/pd.py @@ -22,9 +22,9 @@ def PD(data_shape: Tuple, num_latents: int, structure_type: str = "sum_dominated", input_layer_fn: Optional[Callable] = None, input_dist: Optional[Distribution] = None, - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: Dict = {"num_cats": 256}, - use_linear_mixing: bool = False, + input_node_type: Type[Distribution] = Categorical, + input_node_params: Dict = {"num_cats": 256}, + tie_homogeneous_params: bool = False, block_size: Optional[int] = None): """ Generate PCs with the PD structure (https://arxiv.org/pdf/1202.3732.pdf). @@ -53,6 +53,9 @@ def PD(data_shape: Tuple, num_latents: int, :param input_dist: input distribution :type input_dist: Distribution + :param tie_homogeneous_params: whether to tie parameters of sum/input nodes with compatible structures + :type tie_homogeneous_params: bool + :param block_size: block size :type block_size: int """ @@ -72,6 +75,9 @@ def PD(data_shape: Tuple, num_latents: int, num_axes = len(data_shape) + # A dictionary of source nodes + source_ns_dict = dict() + # Construct split points if split_intervals is not None: if isinstance(split_intervals, int): @@ -117,6 +123,12 @@ def updated_hypercube(hypercube, axis, s = None, e = None): hypercube = (tuple(hypercube[0]), tuple(hypercube[1])) return hypercube + def hypercube2shape(hypercube): + return tuple(hypercube[1][i] - hypercube[0][i] for i in range(len(hypercube[0]))) + + def get_signature(hypercube, *ch_ns): + return (hypercube2shape(hypercube), tuple(len(ns.scope) for ns in ch_ns)) + def create_input_ns(hypercube): scope = hypercube2scope(hypercube) if input_layer_fn is not None: @@ -124,11 +136,28 @@ def create_input_ns(hypercube): else: input_nodes = [] for var in scope: - ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_layer_type(**input_layer_params)) + if not tie_homogeneous_params: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + else: + if "input" in source_ns_dict: + ns = source_ns_dict["input"].duplicate(var, tie_params = True) + else: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + source_ns_dict["input"] = ns input_nodes.append(ns) edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1) - return summate(multiply(*input_nodes), num_node_blocks = num_node_blocks, edge_ids = edge_ids) + pns = multiply(*input_nodes) + if not tie_homogeneous_params: + return summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, pns) + if signature in source_ns_dict: + return source_ns_dict[signature].duplicate(pns, tie_params = True) + else: + ns = summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns + return ns def recursive_construct(hypercube, depth = 1): if hypercube in hypercube2ns: @@ -161,24 +190,35 @@ def recursive_construct(hypercube, depth = 1): ns = create_input_ns(hypercube) elif hypercube == root_hypercube: ns = 
summate(*pns, num_node_blocks = 1, block_size = 1) - elif not use_linear_mixing: + else: if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks) + source_ns_dict[signature] = ns else: block_ids = torch.topk(torch.rand([num_node_blocks, len(pns)]), k = max_prod_block_conns, dim = 1).indices par_ids = torch.arange(0, num_node_blocks)[:,None,None].repeat(1, max_prod_block_conns, num_node_blocks) chs_ids = block_ids[:,:,None] * num_node_blocks + torch.arange(0, num_node_blocks)[None,None,:] edge_ids = torch.stack((par_ids.reshape(-1), chs_ids.reshape(-1)), dim = 0) - ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) - else: - # Linear mixing as implemented in EiNet's Mixing layer - if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) - else: - ch_ns = [multiply(summate(pn, num_node_blocks = num_node_blocks)) for pn in pns] - ns = summate(*ch_ns, num_node_blocks = num_node_blocks, edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1)) + + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns hypercube2ns[hypercube] = ns + return ns with set_block_size(block_size = block_size): @@ -194,21 +234,25 @@ def PDHCLT(data: torch.Tensor, data_shape: Tuple, num_latents: int, max_split_depth: Optional[int] = None, max_prod_block_conns: int = 4, structure_type: str = "sum_dominated", - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: Dict = {"num_cats": 256}, + input_dist: Optional[Distribution] = None, + input_node_type: Type[Distribution] = Categorical, + input_node_params: Dict = {"num_cats": 256}, hclt_kwargs: Dict = {"num_bins": 32, "sigma": 0.5 / 32, "chunk_size": 32}, block_size: Optional[int] = None): assert data.dim() == 2 assert data.size(1) == reduce(lambda x, y: x * y, data_shape) + if input_dist is not None: + input_node_type, input_node_params = input_dist._get_constructor() + def input_layer_fn(scope, num_latents, block_size): vars = torch.tensor(scope.to_list()).sort().values ns = HCLT( x = data[:,vars], num_latents = num_latents, - input_layer_type = input_layer_type, - input_layer_params = input_layer_params, + input_node_type = input_node_type, + input_node_params = input_node_params, num_root_ns = num_latents, block_size = block_size, **hclt_kwargs @@ -224,7 +268,7 @@ def input_layer_fn(scope, num_latents, block_size): split_intervals = split_intervals, split_points = split_points, max_split_depth = max_split_depth, max_prod_block_conns = max_prod_block_conns, structure_type = structure_type, input_layer_fn = input_layer_fn, - input_layer_type = input_layer_type, input_layer_params = input_layer_params, + input_node_type = input_node_type, input_node_params = input_node_params, block_size = block_size) if ns.num_node_blocks > 1: diff --git a/src/pyjuice/structures/rat_spn.py b/src/pyjuice/structures/rat_spn.py index 
7af44608..7d011e21 100644 --- a/src/pyjuice/structures/rat_spn.py +++ b/src/pyjuice/structures/rat_spn.py @@ -14,8 +14,8 @@ def RAT_SPN(num_vars: int, num_latents: int, depth: int, num_repetitions: int, num_pieces: int = 2, input_dist: Optional[Distribution] = None, - input_layer_type: Type[Distribution] = Categorical, - input_layer_params: dict = {"num_cats": 256}, + input_node_type: Type[Distribution] = Categorical, + input_node_params: dict = {"num_cats": 256}, block_size: Optional[int] = None): """ Generate Random and Tensorized SPNs (https://proceedings.mlr.press/v115/peharz20a/peharz20a.pdf) @@ -57,7 +57,7 @@ def RAT_SPN(num_vars: int, num_latents: int, depth: int, num_repetitions: int, n # Input nodes input_ns = [] for v in range(num_vars): - ns = inputs(v, num_node_blocks = num_node_blocks, dist = input_layer_type(**input_layer_params)) + ns = inputs(v, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) input_ns.append(ns) # Top-down partition diff --git a/src/pyjuice/transformations/copy.py b/src/pyjuice/transformations/copy.py index 5ba4f2a4..73481923 100644 --- a/src/pyjuice/transformations/copy.py +++ b/src/pyjuice/transformations/copy.py @@ -1,5 +1,6 @@ from __future__ import annotations +import torch from copy import deepcopy as pydeepcopy from typing import Optional, Dict @@ -7,7 +8,7 @@ from pyjuice.utils import BitSet -def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, +def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, max_block_size: Optional[int] = None, var_mapping: Optional[Dict[int,int]] = None) -> CircuitNodes: """ Create a deepcopy of the input PC. @@ -18,12 +19,19 @@ def deepcopy(root_ns: CircuitNodes, tie_params: bool = False, :param tie_params: whether to tie the parameters between the original PC and the copied PC (if tied, their parameters will always be the same) :type tie_params: bool + :param max_block_size: the maximum block size of the copied PC + :type max_block_size: Optional[int] + :param var_mapping: a mapping dictionary between the variables of the original PC and the copied PC :type var_mapping: Optional[Dict[int,int]] :returns: a copied PC """ + assert not (max_block_size is not None and tie_params), "Cannot change the block size when `tie_params=True`." + if max_block_size is not None: + assert max_block_size > 0 and (max_block_size & (max_block_size - 1)) == 0, f"`max_block_size` must be a power of 2, but got `max_block_size={max_block_size}`."
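# Two standalone miniatures (illustrative values only) for the arithmetic used in
# this function. The power-of-2 assertion above relies on the standard bit trick:
# for x > 0, x & (x - 1) clears the lowest set bit, so it is 0 iff x is a power of 2.
def is_power_of_2(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

assert is_power_of_2(16) and not is_power_of_2(24)

# And a worked miniature of the sum-node edge-id expansion defined below: one
# block-edge (parent block 0 -> child block 1) split with
# blk_factor = ch_blk_factor = 2 becomes the full set of 2 x 2 sub-block edges.
import torch

blk_factor, ch_blk_factor = 2, 2
edge_ids = torch.tensor([[0], [1]])

new_edge_ids = torch.stack(
    (edge_ids[0,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * blk_factor + torch.arange(0, blk_factor)[None,:,None],
     edge_ids[1,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * ch_blk_factor + torch.arange(0, ch_blk_factor)[None,None,:]),
    dim = 0
).flatten(1, 3)

assert new_edge_ids.tolist() == [[0, 0, 1, 1], [2, 3, 2, 3]]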
+ old2new = dict() tied_ns_pairs = [] @@ -43,24 +51,74 @@ def dfs(ns: CircuitNodes): if ns.is_sum(): if not tie_params: + if max_block_size is None: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + params = ns.get_params() + else: + old_ch_blk_size = ns.chs[0].block_size + old_blk_size = ns.block_size + + new_ch_blk_size = new_chs[0].block_size + new_blk_size = min(old_blk_size, max_block_size) + + blk_factor = old_blk_size // new_blk_size + ch_blk_factor = old_ch_blk_size // new_ch_blk_size + + edge_ids = torch.stack( + (ns.edge_ids[0,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * blk_factor + torch.arange(0, blk_factor)[None,:,None], + ns.edge_ids[1,:][:,None,None].repeat(1, blk_factor, ch_blk_factor) * ch_blk_factor + torch.arange(0, ch_blk_factor)[None,None,:]), + dim = 0 + ).flatten(1, 3) + block_size = new_blk_size + + params = ns.get_params() + if params is not None: + num_edges = params.size(0) + params = params.reshape(num_edges, blk_factor, new_blk_size, ch_blk_factor, new_ch_blk_size).permute(0, 1, 3, 2, 4).flatten(0, 2) + new_ns = SumNodes( - ns.num_node_blocks, + ns.num_nodes // block_size, new_chs, - ns.edge_ids.clone(), - block_size = ns.block_size + edge_ids, + block_size = block_size ) - params = ns.get_params() if params is not None: new_ns.set_params(params.clone(), normalize = False) else: new_ns = ns.duplicate(*new_chs, tie_params = True) elif ns.is_prod(): + if max_block_size is None: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + else: + old_ch_blk_size = ns.chs[0].block_size + old_blk_size = ns.block_size + + new_ch_blk_size = new_chs[0].block_size + new_blk_size = min(old_blk_size, max_block_size) + + if old_blk_size == new_blk_size and old_ch_blk_size == new_ch_blk_size: + edge_ids = ns.edge_ids.clone() + block_size = ns.block_size + else: + blk_factor = old_blk_size // new_blk_size + ch_blk_factor = old_ch_blk_size // new_ch_blk_size + + if blk_factor == ch_blk_factor: + edge_ids = ns.edge_ids.clone() + edge_ids = edge_ids[:,None,:].repeat(1, blk_factor, 1) * blk_factor + torch.arange(0, blk_factor)[None,:,None] + edge_ids = edge_ids.flatten(0, 1) + block_size = new_blk_size + else: + raise NotImplementedError() + new_ns = ProdNodes( - ns.num_node_blocks, + ns.num_nodes // block_size, new_chs, - ns.edge_ids.clone(), - block_size = ns.block_size + edge_ids, + block_size = block_size ) else: @@ -76,12 +134,17 @@ def dfs(ns: CircuitNodes): else: scope = pydeepcopy(ns.scope) + if max_block_size is None: + block_size = ns.block_size + else: + block_size = min(ns.block_size, max_block_size) + if not tie_params: new_ns = InputNodes( - num_node_blocks = ns.num_node_blocks, + num_node_blocks = ns.num_nodes // block_size, scope = pydeepcopy(scope), dist = pydeepcopy(ns.dist), - block_size = ns.block_size + block_size = block_size ) params = ns.get_params() if params is not None: diff --git a/src/pyjuice/visualize/visualize.py b/src/pyjuice/visualize/visualize.py index e88a4cae..ca90cd39 100644 --- a/src/pyjuice/visualize/visualize.py +++ b/src/pyjuice/visualize/visualize.py @@ -33,6 +33,8 @@ def plot_pc(ns, """ G = nx.DiGraph() node_list = serialize_nodes(ns) + for item in node_list: + item["num_nodes"] = item["num_node_blocks"] * item["block_size"] pos = {} nx.set_node_attributes(G, [], "node_type") nx.set_node_attributes(G, [], "num_nodes") @@ -102,19 +104,13 @@ def plot_pc(ns, def plot_tensor_node_connection(ns, node_id : int = 0): G = nx.DiGraph() node_list = serialize_nodes(ns) + for item in node_list: + item["num_nodes"] = 
item["num_node_blocks"] * item["block_size"] node_target = node_list[node_id] pos = {} if node_target['type']=='Input': - # print(f'The target node {node_id} is a Input node...') - # input_node_list = [] - # for i in range(node_target['num_nodes']): - # G.add_node(f'i{i}') - # input_node_list.append(f'i{i}') - # pos = nx.circular_layout(G) - # nx.draw(G, pos, node_color="#A5CE9D", with_labels=True, font_size=6) # green is sum nodes - # return input_node_list #return input node list - print(f"\n>>The target node {node_id} is a Input node, it has {node_target['num_nodes']} nodes, & no connection among them.<<\n") + print(f"\n>>The target node {node_id} is an Input node, it has {node_target['num_nodes']} nodes, & no connection among them.<<\n") return elif node_target['type']=='Product': @@ -158,7 +154,7 @@ def plot_tensor_node_connection(ns, node_id : int = 0): # nx.draw_networkx_nodes(G, pos, nodelist=prod_node_list, node_color="#A2C8DD", **options) # blue is product nodes # nx.draw_networkx_edges(G, pos, width=0.1) - #adjacency matrix heatmap plot + # Adjacency matrix heatmap plot fig, ax = plt.subplots() ax.spy(adjacency_manual) ax.set_aspect('auto') diff --git a/tests/io/io_test.py b/tests/io/io_test.py index 59313a83..6a50b3d6 100644 --- a/tests/io/io_test.py +++ b/tests/io/io_test.py @@ -11,7 +11,7 @@ import pytest -def io_test(): +def test_io(): num_node_blocks = 2 block_size = 4 @@ -46,7 +46,7 @@ def io_test(): assert n0.chs[1].chs[1].dist.num_cats == n0_dup.chs[1].chs[1].dist.num_cats -def io_param_test(): +def test_io_param(): num_node_blocks = 2 block_size = 4 @@ -79,5 +79,5 @@ def io_param_test(): if __name__ == "__main__": - io_test() - io_param_test() + test_io() + test_io_param() diff --git a/tests/layer/input_layer_test.py b/tests/layer/input_layer_test.py index e53200ff..d0e4b47b 100644 --- a/tests/layer/input_layer_test.py +++ b/tests/layer/input_layer_test.py @@ -13,7 +13,7 @@ import pytest -def input_layer_test(): +def test_input_layer(): device = torch.device("cuda:0") @@ -111,7 +111,7 @@ def input_layer_test(): assert torch.all(torch.abs(new_params - layer.params) < 1e-4) -def tied_bp_test(): +def test_tied_bp(): device = torch.device("cuda:0") @@ -166,7 +166,8 @@ def tied_bp_test(): assert torch.all(torch.abs(param_flows - layer.param_flows) < 1e-4) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -283,6 +284,6 @@ def speed_test(): if __name__ == "__main__": - input_layer_test() - tied_bp_test() - speed_test() \ No newline at end of file + test_input_layer() + test_tied_bp() + test_speed() \ No newline at end of file diff --git a/tests/layer/layer_compilation_test.py b/tests/layer/layer_compilation_test.py index f3c4eed6..01072d7c 100644 --- a/tests/layer/layer_compilation_test.py +++ b/tests/layer/layer_compilation_test.py @@ -14,7 +14,7 @@ import pytest -def prod_layer_compilation_test(): +def test_prod_layer_compilation(): for block_size in [1, 8, 16]: @@ -38,7 +38,7 @@ def prod_layer_compilation_test(): prod_layer_cpu = ProdLayer([np0, np1, np2, np3, np4, np5], layer_sparsity_tol = 0.1, disable_gpu_compilation = True) prod_layer_gpu = ProdLayer([np0, np1, np2, np3, np4, np5], layer_sparsity_tol = 0.1, force_gpu_compilation = True) - for i in range(3): + for i in range(2): assert torch.all(prod_layer_cpu.partitioned_nids[i] == prod_layer_gpu.partitioned_nids[i]) assert torch.all(prod_layer_cpu.partitioned_cids[i] == prod_layer_gpu.partitioned_cids[i]) @@ -47,7 +47,7 @@ def prod_layer_compilation_test(): assert 
torch.all(prod_layer_cpu.partitioned_parids[i] == prod_layer_gpu.partitioned_parids[i]) -def sum_layer_compilation_test(): +def test_sum_layer_compilation(): for block_size in [1, 8, 16]: @@ -112,5 +112,5 @@ def sum_layer_compilation_test(): if __name__ == "__main__": - # prod_layer_compilation_test() - sum_layer_compilation_test() + test_prod_layer_compilation() + test_sum_layer_compilation() diff --git a/tests/layer/prod_layer_test.py b/tests/layer/prod_layer_test.py index 9203e464..0daba678 100644 --- a/tests/layer/prod_layer_test.py +++ b/tests/layer/prod_layer_test.py @@ -14,7 +14,7 @@ import pytest -def prod_layer_test(): +def test_prod_layer(): device = torch.device("cuda:0") @@ -91,7 +91,8 @@ def prod_layer_test(): assert torch.all(torch.abs(node_flows[8*block_size+i,:] - element_flows[4*block_size+i,:]) < 1e-4) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -166,5 +167,5 @@ def speed_test(): if __name__ == "__main__": torch.manual_seed(2390) - prod_layer_test() - speed_test() + test_prod_layer() + test_speed() diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py new file mode 100644 index 00000000..db90f9bf --- /dev/null +++ b/tests/layer/propagation_algs_test.py @@ -0,0 +1,344 @@ +import pyjuice as juice +import torch +import numpy as np +import time +import random + +import pyjuice.nodes.distributions as dists +from pyjuice.utils import BitSet +from pyjuice.nodes import multiply, summate, inputs +from pyjuice.model import TensorCircuit + +from pyjuice.layer import InputLayer, ProdLayer, SumLayer + +import pytest + + +def test_ll_prop(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [1, 4, 8, 16]: + + for allow_modify_flows in [True, False]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "LL") + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None] * cmars).sum(dim = 0).log() + + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 2e-3) + + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, 
batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + origin_node_flows = node_flows.clone() + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "LL") + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (epars[:,None] * emars[None,:]) / nmars).sum(dim = 0) + + if allow_modify_flows: + uflows1 = node_flows[parids[j,:]] + uflows2 = origin_node_flows[parids[j,:]].log() - nmars.log() + + assert torch.all(torch.abs(uflows1 - uflows2) < 1e-3) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = epars * (nflows[None,:] * emars / nmars[None,:]).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + + +def test_general_ll_prop(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [1, 4, 8, 16]: + + for allow_modify_flows in [True, False]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + alphas = [1.2, 2.0, 3.0] + + for alpha in alphas: + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "GeneralLL", alpha = alpha) + + for i in range(block_size): + for j in range(6): + cmars = 
element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) + + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 2e-3) + + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + origin_node_flows = node_flows.clone() + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "GeneralLL", alpha = alpha) + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (epars[:,None] * emars[None,:]) ** alpha / nmars ** alpha).sum(dim = 0) + + if allow_modify_flows: + uflows1 = node_flows[parids[j,:]] + uflows2 = origin_node_flows[parids[j,:]].log() - nmars.log() * alpha + assert torch.all(torch.abs(uflows1 - uflows2) < 1e-3) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = epars * (nflows[None,:] * emars / nmars[None,:]).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 4e-3) + + +def test_mpe_prop(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [1, 4, 8, 16]: + + for allow_modify_flows in [True, False]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, 
batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "MPE") + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + ## Backward pass ## + + node_flows = torch.rand([block_size + block_size * 2 * 3, batch_size]).to(device) + element_flows = torch.zeros([block_size + 3 * 2 * 2 * block_size, batch_size]).to(device) + + param_flows = torch.zeros([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + origin_node_flows = node_flows.clone() + + layer.backward(node_flows, element_flows, node_mars, element_mars, params, param_flows, + allow_modify_flows = allow_modify_flows, propagation_alg = "MPE") + + chids = layer.partitioned_chids[0] + parids = layer.partitioned_parids[0] + parpids = layer.partitioned_parpids[0] + + num_nblocks = chids.size(0) + num_eblocks = parids.size(1) + parids = (parids[:,:,None].repeat(1, 1, block_size) + torch.arange(0, block_size, device = parids.device)).reshape(num_nblocks, num_eblocks * block_size) + parpids_start = (parpids[:,:,None] + torch.arange(0, block_size, device = parids.device)).reshape( + num_nblocks, num_eblocks * block_size) + + for j in range(6): + parpids = parpids_start.clone() + for i in range(block_size): + nmars = node_mars[parids[j,:]].exp() + nflows = origin_node_flows[parids[j,:]] + emars = element_mars[(j+1)*block_size+i,:].exp() + epars = params[parpids[j,:]] + eflows = (nflows * (((epars[:,None] * emars[None,:]) - nmars).abs() < 1e-6).float()).sum(dim = 0) + + assert torch.all(torch.abs(eflows - element_flows[(j+1)*block_size+i,:]) < 1e-2) + + parpids += block_size + + my_pflows = torch.zeros_like(param_flows) + + for i in range(block_size): + for j in range(6): + emars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + nmars = node_mars[(j+1)*block_size+i,:].exp() + nflows = origin_node_flows[(j+1)*block_size+i,:] + pflows = (nflows[None,:] * ((epars[:,None] * emars - nmars[None,:]).abs() < 1e-6).float()).sum(dim = 1) + + my_pflows[layer.partitioned_pfids[0][j,:]+i] = pflows + + assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) + + +if __name__ == "__main__": + torch.manual_seed(280) + test_ll_prop() + test_general_ll_prop() + test_mpe_prop() diff --git a/tests/layer/sparse_prod_layer_test.py b/tests/layer/sparse_prod_layer_test.py index d34c0343..4b02357b 100644 --- a/tests/layer/sparse_prod_layer_test.py +++ b/tests/layer/sparse_prod_layer_test.py @@ -14,7 +14,7 @@ import pytest -def sparse_prod_layer_test(): +def test_sparse_prod_layer(): device = torch.device("cuda:0") @@ -104,4 +104,4 @@ def sparse_prod_layer_test(): if __name__ == "__main__": torch.manual_seed(2390) - sparse_prod_layer_test() + test_sparse_prod_layer() diff --git a/tests/layer/sum_layer_test.py b/tests/layer/sum_layer_test.py index 683dacf2..85870d24 100644 --- a/tests/layer/sum_layer_test.py +++ b/tests/layer/sum_layer_test.py @@ -14,7 +14,7 @@ import pytest -def sum_layer_test(): +def test_sum_layer(): device = torch.device("cuda:0") @@ -141,7 +141,7 @@ def sum_layer_test(): assert 
torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def corner_case_test(): +def test_corner_case(): device = torch.device("cuda:0") @@ -280,7 +280,8 @@ def corner_case_test(): assert torch.all(torch.abs(my_pflows - param_flows) < 2e-3) -def speed_test(): +@pytest.mark.slow +def test_speed(): device = torch.device("cuda:0") @@ -360,7 +361,8 @@ def speed_test(): print("--------------------------------------------------------------") -def block_sparse_speed_test(): +@pytest.mark.slow +def test_block_sparse_speed(): device = torch.device("cuda:0") @@ -448,7 +450,7 @@ def block_sparse_speed_test(): if __name__ == "__main__": torch.manual_seed(3890) - sum_layer_test() - corner_case_test() - speed_test() - block_sparse_speed_test() \ No newline at end of file + test_sum_layer() + test_corner_case() + test_speed() + test_block_sparse_speed() \ No newline at end of file diff --git a/tests/lvd/counting_lvd_test.py b/tests/lvd/counting_lvd_test.py index 71aac6e6..16cd10e9 100644 --- a/tests/lvd/counting_lvd_test.py +++ b/tests/lvd/counting_lvd_test.py @@ -9,7 +9,7 @@ import pytest -def counting_lvd_test(): +def test_counting_lvd(): num_nodes = 2 with juice.LVDistiller(backend = "counting", pseudocount = 0.0): @@ -32,4 +32,4 @@ def counting_lvd_test(): if __name__ == "__main__": - counting_lvd_test() \ No newline at end of file + test_counting_lvd() \ No newline at end of file diff --git a/tests/model/backward_test.py b/tests/model/backward_test.py index 3fd59541..8da0ac33 100644 --- a/tests/model/backward_test.py +++ b/tests/model/backward_test.py @@ -10,7 +10,7 @@ import pytest -def backward_test(): +def test_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -119,7 +119,7 @@ def backward_test(): assert torch.abs(inner_param_flows[13] + inner_param_flows[14] - 1.0) < 1e-4 -def non_sd_pc_backward_test(): +def test_non_sd_pc_backward(): ni00 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni10 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni20 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -177,7 +177,7 @@ def non_sd_pc_backward_test(): assert torch.abs(pc.param_flows[11] - fp4) < 1e-3 -def sparse_pc_backward_test(): +def test_sparse_pc_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -242,6 +242,6 @@ def sparse_pc_backward_test(): if __name__ == "__main__": - backward_test() - non_sd_pc_backward_test() - sparse_pc_backward_test() \ No newline at end of file + test_backward() + test_non_sd_pc_backward() + test_sparse_pc_backward() \ No newline at end of file diff --git a/tests/model/block_sparse_pc_test.py b/tests/model/block_sparse_pc_test.py index 7f62d4df..83153640 100644 --- a/tests/model/block_sparse_pc_test.py +++ b/tests/model/block_sparse_pc_test.py @@ -14,7 +14,7 @@ import pytest -def block_sparse_pc_test(): +def test_block_sparse_pc(): device = torch.device("cuda:0") @@ -120,9 +120,9 @@ def block_sparse_pc_test(): else: curr_par_flows = torch.matmul(1.0 / ns_vals[ni], np2_vals[ci-num_node_blocks*2].permute(1, 0)) * params[i] - assert torch.all(torch.abs(param_flows[i] - curr_par_flows) < 1e-2) + assert torch.all(torch.abs(param_flows[i] - curr_par_flows) < 1e-3 * curr_par_flows) if __name__ == "__main__": torch.manual_seed(3890) - block_sparse_pc_test() + test_block_sparse_pc() diff --git 
a/tests/model/compilation_speed_test.py b/tests/model/compilation_speed_test.py index 3f2b48b3..5012f501 100644 --- a/tests/model/compilation_speed_test.py +++ b/tests/model/compilation_speed_test.py @@ -11,7 +11,8 @@ import pytest -def compile_dense_pc_test(): +@pytest.mark.slow +def test_compile_dense_pc(): num_latents = 2048 num_cats = 512 num_vars = 64 @@ -33,7 +34,8 @@ def compile_dense_pc_test(): assert t1 - t0 < 60 -def compile_sparse_pc_test(): +@pytest.mark.slow +def test_compile_sparse_pc(): num_latents = 4096 num_cats = 200 num_vars = 16 @@ -71,5 +73,5 @@ def compile_sparse_pc_test(): if __name__ == "__main__": - compile_dense_pc_test() - compile_sparse_pc_test() \ No newline at end of file + test_compile_dense_pc() + test_compile_sparse_pc() \ No newline at end of file diff --git a/tests/model/forward_test.py b/tests/model/forward_test.py index 3dfdaa96..32f9263e 100644 --- a/tests/model/forward_test.py +++ b/tests/model/forward_test.py @@ -10,7 +10,7 @@ import pytest -def forward_test(): +def test_forward(): ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) @@ -71,7 +71,7 @@ def forward_test(): assert torch.abs(pc.node_mars[13,0] - torch.log(s)) < 1e-4 -def non_sd_pc_forward_test(): +def test_non_sd_pc_forward(): ni00 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni10 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni20 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -138,7 +138,7 @@ def non_sd_pc_forward_test(): assert torch.abs(torch.exp(pc.node_mars[17,0]) - (f9 + f10 + f11 + f12)) < 1e-3 -def sparse_pc_forward_test(): +def test_sparse_pc_forward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -199,7 +199,7 @@ def sparse_pc_forward_test(): assert torch.abs(pc.node_mars[13,0] - torch.log(s)) < 1e-4 -def non_sd_pc2_forward_test(): +def test_non_sd_pc2_forward(): n0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) n1 = multiply(inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2))) n2 = multiply(inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2))) @@ -221,7 +221,7 @@ def non_sd_pc2_forward_test(): if __name__ == "__main__": - forward_test() - non_sd_pc_forward_test() - sparse_pc_forward_test() - non_sd_pc2_forward_test() \ No newline at end of file + test_forward() + test_non_sd_pc_forward() + test_sparse_pc_forward() + test_non_sd_pc2_forward() \ No newline at end of file diff --git a/tests/model/homogeneous_hmm_test.py b/tests/model/homogeneous_hmm_test.py index e6cb8ef9..deb9119e 100644 --- a/tests/model/homogeneous_hmm_test.py +++ b/tests/model/homogeneous_hmm_test.py @@ -11,7 +11,7 @@ import pytest -def homogeneous_hmm_test(): +def test_homogeneous_hmm(): block_size = 1 @@ -217,4 +217,4 @@ def homogeneous_hmm_test(): if __name__ == "__main__": torch.manual_seed(2390) - homogeneous_hmm_test() + test_homogeneous_hmm() diff --git a/tests/model/non_sd_pcs_test.py b/tests/model/non_sd_pcs_test.py index c7be5e37..03cac563 100644 --- a/tests/model/non_sd_pcs_test.py +++ b/tests/model/non_sd_pcs_test.py @@ -10,7 +10,7 @@ import pytest -def non_sd_test(): +def test_non_sd(): ni1 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni2 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni3 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats 
= 2)) @@ -101,6 +101,44 @@ def non_sd_test(): assert torch.all(torch.abs(bk78 - pc.node_flows[7:9,0].cpu()) < 1e-4) +def test_non_sd_generalized_em(): + ni1 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(2, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + ni4 = inputs(3, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) + + np12 = multiply(ni1, ni2) + np23 = multiply(ni2, ni3) + np34 = multiply(ni3, ni4) + + ns12 = summate(np12, num_nodes = 2) + ns23 = summate(np23, num_nodes = 2) + ns34 = summate(np34, num_nodes = 2) + + np1 = multiply(ns12, ns34) + np2 = multiply(ni1, ns23, ni4) + np3 = multiply(ni1, ni2, ns34) + + ns = summate(np1, np2, np3, num_nodes = 1) + + pc = TensorCircuit(ns) + + device = torch.device("cuda:0") + pc.to(device) + + data = torch.randint(0, 2, [16, 4]).to(device) + + alpha = 2.0 + + lls = pc(data, propagation_alg = "GeneralLL", alpha = alpha) + + pc.backward(data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_parameters() + + if __name__ == "__main__": torch.manual_seed(129) - non_sd_test() \ No newline at end of file + test_non_sd() + test_non_sd_generalized_em() \ No newline at end of file diff --git a/tests/model/parameter_tying_test.py b/tests/model/parameter_tying_test.py index 4a0f149f..02874cfb 100644 --- a/tests/model/parameter_tying_test.py +++ b/tests/model/parameter_tying_test.py @@ -11,7 +11,7 @@ import pytest -def simple_structure_test_block1(): +def test_simple_structure_block1(): block_size = 1 @@ -302,7 +302,7 @@ def simple_structure_test_block1(): assert torch.all(torch.abs(param_flows1.reshape(-1) - pc.params[5:13].cpu()) < 1e-4) -def simple_structure_test_block16(): +def test_simple_structure_block16(): block_size = 16 @@ -610,5 +610,5 @@ def simple_structure_test_block16(): if __name__ == "__main__": torch.manual_seed(2390) - simple_structure_test_block1() - simple_structure_test_block16() + test_simple_structure_block1() + test_simple_structure_block16() diff --git a/tests/model/partial_eval_test.py b/tests/model/partial_eval_test.py index e654c787..f6f3751e 100644 --- a/tests/model/partial_eval_test.py +++ b/tests/model/partial_eval_test.py @@ -12,7 +12,7 @@ import pytest -def partial_eval_forward_test(): +def test_partial_eval_forward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -68,7 +68,7 @@ def partial_eval_forward_test(): assert torch.all(torch.abs(lls - lls2) < 1e-4) -def partial_eval_backward_test(): +def test_partial_eval_backward(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -112,5 +112,5 @@ def partial_eval_backward_test(): if __name__ == "__main__": - partial_eval_forward_test() - partial_eval_backward_test() + test_partial_eval_forward() + test_partial_eval_backward() diff --git a/tests/model/simple_model_test.py b/tests/model/simple_model_test.py index 0afbdc82..df4d13b6 100644 --- a/tests/model/simple_model_test.py +++ b/tests/model/simple_model_test.py @@ -14,7 +14,7 @@ import pytest -def simple_model_test(): +def test_simple_model(): device = torch.device("cuda:0") @@ -190,7 +190,9 @@ def simple_model_test(): pc.to(device) - data = torch.randint(0, 4, [512, 4], device = device) + batch_size = 512 + + data = torch.randint(0, 4, [batch_size, 4], device = 
device) lls = pc(data) @@ -303,7 +305,7 @@ def simple_model_test(): ns_parflows = eflows.sum(dim = 1) ref_parflows = param_flows[4096:4192] - assert torch.all(torch.abs(ns_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns_parflows - ref_parflows) < 1e-4 * batch_size) sid, eid = ns0._output_ind_range ns0_flows = np4_flows @@ -339,7 +341,7 @@ def simple_model_test(): ns0_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[0:2048].reshape(2, 64, 16).permute(0, 2, 1).reshape(32, 64) - assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-4 * batch_size) ch_lls = np1_lls epars = ns1._params.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) @@ -355,7 +357,7 @@ def simple_model_test(): ns1_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[2048:3072].reshape(2, 32, 16).permute(0, 2, 1).reshape(32, 32) - assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-4 * batch_size) ch_lls = np2_lls epars = ns2._params.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) @@ -371,7 +373,7 @@ def simple_model_test(): ns2_parflows = epars * torch.matmul(n_fdm_sub, emars.permute(1, 0)) ref_parflows = param_flows[3072:4096].reshape(2, 32, 16).permute(0, 2, 1).reshape(32, 32) - assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-4 * batch_size) sid, eid = ni0._output_ind_range ni0_flows = np0_flows + np3_flows + np5_flows + np6_flows @@ -430,13 +432,13 @@ def simple_model_test(): pc.update_param_flows() ref_parflows = ns0._param_flows.reshape(2, 4, 16, 16).permute(0, 2, 1, 3).reshape(32, 64) - assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns0_parflows - ref_parflows) < 1e-4 * batch_size) ref_parflows = ns1._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns1_parflows - ref_parflows) < 1e-4 * batch_size) ref_parflows = ns2._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-3) + assert torch.all(torch.abs(ns2_parflows - ref_parflows) < 1e-4 * batch_size) par_start_ids, pflow_start_ids, blk_sizes, blk_intervals, global_nids, nchs, cum_pflows, metadata = pc.par_update_kwargs @@ -493,9 +495,9 @@ def simple_model_test(): ns1_parflows = ns1._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) ns2_parflows = ns2._param_flows.reshape(2, 2, 16, 16).permute(0, 2, 1, 3).reshape(32, 32) - assert torch.all(torch.abs(ns0_parflows.sum(dim = 1) - cum_pflows[0:32]) < 1e-3) - assert torch.all(torch.abs(ns1_parflows.sum(dim = 1) - cum_pflows[32:64]) < 1e-3) - assert torch.all(torch.abs(ns2_parflows.sum(dim = 1) - cum_pflows[64:96]) < 1e-3) + assert torch.all(torch.abs(ns0_parflows.sum(dim = 1) - cum_pflows[0:32]) < 1e-4 * batch_size) + assert torch.all(torch.abs(ns1_parflows.sum(dim = 1) - cum_pflows[32:64]) < 1e-4 * batch_size) + assert torch.all(torch.abs(ns2_parflows.sum(dim = 1) - cum_pflows[64:96]) < 1e-4 * batch_size) assert torch.abs(ns_parflows.sum() - cum_pflows[96]) < 1e-3 ns0_new_params = (ns0_parflows + pseudocount / 64) / (ns0_parflows.sum(dim = 1, keepdim = True) + pseudocount) @@ -524,5 +526,5 @@ def simple_model_test(): if __name__ == 
"__main__": - torch.manual_seed(23892) - simple_model_test() + # torch.manual_seed(23892) + test_simple_model() diff --git a/tests/model/structured_blk_sparse_pc_test.py b/tests/model/structured_blk_sparse_pc_test.py index 0a4041ba..f9a3ede4 100644 --- a/tests/model/structured_blk_sparse_pc_test.py +++ b/tests/model/structured_blk_sparse_pc_test.py @@ -11,7 +11,7 @@ import pytest -def structured_blk_sparse_pc_test(): +def test_structured_blk_sparse_pc(): device = torch.device("cuda:0") @@ -217,4 +217,4 @@ def structured_blk_sparse_pc_test(): if __name__ == "__main__": torch.manual_seed(89172) - structured_blk_sparse_pc_test() + test_structured_blk_sparse_pc() diff --git a/tests/nodes/input_dists_test.py b/tests/nodes/input_dists_test.py index a3040d60..bf60f0c5 100644 --- a/tests/nodes/input_dists_test.py +++ b/tests/nodes/input_dists_test.py @@ -11,7 +11,7 @@ import pytest -def categorical_nodes_test(): +def test_categorical_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) ni1 = inputs(1, num_nodes = 2, dist = dists.Categorical(num_cats = 2)) @@ -81,7 +81,7 @@ def categorical_nodes_test(): assert torch.all(torch.abs(new_params - pc.input_layer_group[0].params) < 1e-4) -def bernoulli_nodes_test(): +def test_bernoulli_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Bernoulli()) ni1 = inputs(1, num_nodes = 2, dist = dists.Bernoulli()) @@ -152,7 +152,7 @@ def bernoulli_nodes_test(): assert torch.all(torch.abs(new_params - pc.input_layer_group[0].params) < 1e-4) -def gaussian_nodes_test(): +def test_gaussian_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.Gaussian(mu = 0.0, sigma = 1.0)) ni1 = inputs(1, num_nodes = 2, dist = dists.Gaussian(mu = 0.0, sigma = 1.0)) @@ -229,7 +229,7 @@ def gaussian_nodes_test(): assert torch.all(torch.abs(updated_sigma.clamp(min = 0.01) - pc.input_layer_group[0].params.reshape(8, 2)[:,1]) < 1e-4) -def discrete_logistic_nodes_test(): +def test_discrete_logistic_nodes(): ni0 = inputs(0, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) ni1 = inputs(1, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) @@ -326,7 +326,7 @@ def discrete_logistic_nodes_test(): assert torch.all(torch.abs(updated_s - pc.input_layer_group[0].params.reshape(8, 2)[:,1]) < 1e-4) -def discrete_logistic_nodes_behavior_test(): +def test_discrete_logistic_nodes_behavior(): ni0 = inputs(0, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) ni1 = inputs(1, num_nodes = 2, dist = dists.DiscreteLogistic(val_range = [-1.0, 1.0], num_cats = 5)) @@ -361,7 +361,9 @@ def discrete_logistic_nodes_behavior_test(): pc.backward(data.permute(1, 0), flows_memory = 0.0) - pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.001) + + lls = pc(data) assert lls.mean() > -3.5 @@ -373,7 +375,7 @@ def discrete_logistic_nodes_behavior_test(): assert (ni3._params[0] > 0.65 and ni3._params[2] < 0.4) or (ni3._params[2] > 0.65 and ni3._params[0] < 0.4) -def masked_categorical_nodes_range_test(): +def test_masked_categorical_nodes_range(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "range"), mask = torch.tensor([[2, 4], [3, 5]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "range"), mask = torch.tensor([[2, 4], [3, 5]])) @@ -495,7 +497,7 @@ def masked_categorical_nodes_range_test(): assert torch.all(torch.abs(updated_params - 
pc.input_layer_group[0].params.reshape(8, 8)) < 1e-4) -def masked_categorical_nodes_full_mask_test(): +def test_masked_categorical_nodes_full_mask(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "full_mask"), mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "full_mask"), mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]])) @@ -617,7 +619,7 @@ def masked_categorical_nodes_full_mask_test(): assert torch.all(torch.abs(updated_params - pc.input_layer_group[0].params.reshape(8, 11)) < 1e-4) -def masked_categorical_nodes_rev_range_test(): +def test_masked_categorical_nodes_rev_range(): ni0 = inputs(0, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "rev_range"), mask = torch.tensor([[2, 4], [3, 5]])) ni1 = inputs(1, num_nodes = 2, dist = dists.MaskedCategorical(num_cats = 5, mask_mode = "rev_range"), mask = torch.tensor([[2, 4], [3, 5]])) @@ -732,12 +734,12 @@ def masked_categorical_nodes_rev_range_test(): if __name__ == "__main__": - torch.manual_seed(2390) - categorical_nodes_test() - bernoulli_nodes_test() - gaussian_nodes_test() - discrete_logistic_nodes_test() - discrete_logistic_nodes_behavior_test() - masked_categorical_nodes_range_test() - masked_categorical_nodes_full_mask_test() - masked_categorical_nodes_rev_range_test() + # torch.manual_seed(235) + test_categorical_nodes() + test_bernoulli_nodes() + test_gaussian_nodes() + test_discrete_logistic_nodes() + test_discrete_logistic_nodes_behavior() + test_masked_categorical_nodes_range() + test_masked_categorical_nodes_full_mask() + test_masked_categorical_nodes_rev_range() diff --git a/tests/nodes/nodes_test.py b/tests/nodes/nodes_test.py index 3ec1d88f..ba492494 100644 --- a/tests/nodes/nodes_test.py +++ b/tests/nodes/nodes_test.py @@ -9,7 +9,7 @@ import pytest -def nodes_test(): +def test_nodes(): device = torch.device("cuda:0") @@ -45,4 +45,4 @@ def nodes_test(): if __name__ == "__main__": - nodes_test() \ No newline at end of file + test_nodes() \ No newline at end of file diff --git a/tests/optim/hmm_em_test.py b/tests/optim/hmm_em_test.py new file mode 100644 index 00000000..dc6498fb --- /dev/null +++ b/tests/optim/hmm_em_test.py @@ -0,0 +1,181 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch tensors + train_data = 
torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +def train(pc, num_epochs, train_loader, valid_loader, device, propagation_alg, **kwargs): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = propagation_alg, **kwargs) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + +def test_hmm_em(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 20, train_loader, valid_loader, device, propagation_alg = "LL") + + assert best_valid_ll > -95.0 + + +@pytest.mark.slow +def test_hmm_em_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 200, train_loader, valid_loader, device, propagation_alg = "LL") + + assert best_valid_ll > -95.0 + + +if __name__ == "__main__": + test_hmm_em() + test_hmm_em_slow() \ No newline at end of file diff --git a/tests/optim/hmm_general_em_test.py b/tests/optim/hmm_general_em_test.py new file mode 
100644 index 00000000..210ceed5 --- /dev/null +++ b/tests/optim/hmm_general_em_test.py @@ -0,0 +1,224 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch tensors + train_data = torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +def train(pc, num_epochs, train_loader, valid_loader, device): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "GeneralLL", alpha = 1.2) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + +@pytest.mark.slow +def test_hmm_general_ll(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = 
vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 40, train_loader, valid_loader, device) + + assert best_valid_ll > -85.0 + + +@pytest.mark.slow +def test_hmm_general_ll_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 300, train_loader, valid_loader, device) + + assert best_valid_ll > -85.0 + + +def test_hmm_general_ll_fast(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 64, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 20, train_loader, valid_loader, device) + + assert best_valid_ll > -92.0 + + +if __name__ == "__main__": + # test_hmm_general_ll() + test_hmm_general_ll_fast() + test_hmm_general_ll_slow() \ No newline at end of file diff --git a/tests/optim/hmm_viterbi_test.py b/tests/optim/hmm_viterbi_test.py new file mode 100644 index 00000000..f265b82a --- /dev/null +++ b/tests/optim/hmm_viterbi_test.py @@ -0,0 +1,224 @@ +import pyjuice as juice +import torch +import torchvision +import time +from tqdm import tqdm +from torchtext.datasets import PennTreebank +from torchtext.data.utils import get_tokenizer +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import TensorDataset, DataLoader +import pyjuice.distributions as dists + +import pytest + + +def load_penn_treebank(seq_length = 32): + + CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .!?,;:-'\"()[]{}" + vocab = {char: idx for idx, char in enumerate(CHARS)} + + # Define a tokenizer + tokenizer = get_tokenizer("basic_english") + + # Load the Penn Treebank dataset + train_dataset, valid_dataset, test_dataset = PennTreebank(root = "./examples/data") + + train_data = [] + for sample in tqdm(train_dataset): + train_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + valid_data = [] + for sample in tqdm(valid_dataset): + valid_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + test_data = [] + for sample in tqdm(test_dataset): + test_data.extend([vocab[token] if token in vocab else len(CHARS) for token in sample]) + + # Convert to PyTorch 
tensors + train_data = torch.tensor(train_data) + valid_data = torch.tensor(valid_data) + test_data = torch.tensor(test_data) + + nsamples = train_data.size(0) // seq_length * seq_length + train_data = train_data[:nsamples].reshape(-1, seq_length) + + nsamples = valid_data.size(0) // seq_length * seq_length + valid_data = valid_data[:nsamples].reshape(-1, seq_length) + + nsamples = test_data.size(0) // seq_length * seq_length + test_data = test_data[:nsamples].reshape(-1, seq_length) + + return train_data, valid_data, test_data + + +def train(pc, num_epochs, train_loader, valid_loader, device, propagation_alg, **kwargs): + + best_valid_ll = -10000.0 + for epoch in range(1, num_epochs + 1): + t0 = time.time() + train_ll = 0.0 + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = propagation_alg, **kwargs) + lls.mean().backward() + + train_ll += lls.mean().detach().cpu().numpy().item() + + train_ll /= len(train_loader) + + pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01) + + t1 = time.time() + + with torch.no_grad(): + valid_ll = 0.0 + for batch in valid_loader: + x = batch[0].to(device) + + lls = pc(x, propagation_alg = "LL") + + valid_ll += lls.mean().detach().cpu().numpy().item() + + valid_ll /= len(valid_loader) + + t2 = time.time() + + print(f"[epoch {epoch:3d}][train LL: {train_ll:.2f}; valid LL: {valid_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}]") + + if valid_ll > best_valid_ll: + best_valid_ll = valid_ll + + return best_valid_ll + + +@pytest.mark.slow +def test_hmm_viterbi(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 20, train_loader, valid_loader, device, propagation_alg = "MPE") + + assert best_valid_ll > -90.0 + + +@pytest.mark.slow +def test_hmm_viterbi_slow(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 256, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 200, train_loader, valid_loader, device, propagation_alg = "MPE") + + assert best_valid_ll > -90.0 + + +def test_hmm_viterbi_fast(): + + device = torch.device("cuda:0") + + seq_length = 32 + + train_data, valid_data, test_data = 
load_penn_treebank(seq_length = seq_length) + + vocab_size = train_data.max().item() + 1 + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + valid_loader = DataLoader( + dataset = TensorDataset(valid_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + print(f"> Number of training samples: {train_data.size(0)}") + print(f"> Number of validation samples: {valid_data.size(0)}") + + root_ns = juice.structures.HMM( + seq_length = seq_length, + num_latents = 64, + num_emits = vocab_size, + homogeneous = True + ) + + pc = juice.compile(root_ns) + pc.to(device) + + best_valid_ll = train(pc, 5, train_loader, valid_loader, device, propagation_alg = "MPE") + + assert best_valid_ll > -90.0 + + +if __name__ == "__main__": + # test_hmm_viterbi() + test_hmm_viterbi_slow() + # test_hmm_viterbi_fast() \ No newline at end of file diff --git a/tests/queries/cond_test.py b/tests/queries/cond_test.py index 6bc31e8c..8298fe07 100644 --- a/tests/queries/cond_test.py +++ b/tests/queries/cond_test.py @@ -13,7 +13,7 @@ import pytest -def cat_soft_cond_test(): +def test_cat_soft_cond(): device = torch.device("cuda:0") @@ -54,4 +54,4 @@ def cat_soft_cond_test(): torch.manual_seed(123) torch.cuda.manual_seed(123) - cat_soft_cond_test() + test_cat_soft_cond() diff --git a/tests/queries/marginal_test.py b/tests/queries/marginal_test.py index e4ac1c11..ddb8035e 100644 --- a/tests/queries/marginal_test.py +++ b/tests/queries/marginal_test.py @@ -10,7 +10,7 @@ import pytest -def cat_hard_marginal_test(): +def test_cat_hard_marginal(): device = torch.device("cuda:0") @@ -38,7 +38,7 @@ def cat_hard_marginal_test(): assert torch.all((torch.log(p0 * pc.params[1] + p1 * pc.params[2]) - lls[:,0]).abs() < 1e-4) -def cat_soft_marginal_test(): +def test_cat_soft_marginal(): device = torch.device("cuda:0") @@ -71,5 +71,5 @@ def cat_soft_marginal_test(): torch.manual_seed(123) torch.cuda.manual_seed(123) - cat_hard_marginal_test() - cat_soft_marginal_test() \ No newline at end of file + test_cat_hard_marginal() + test_cat_soft_marginal() \ No newline at end of file diff --git a/tests/structures/hclt_correctness_test.py b/tests/structures/hclt_correctness_test.py index 45f58f60..ccc9df2f 100644 --- a/tests/structures/hclt_correctness_test.py +++ b/tests/structures/hclt_correctness_test.py @@ -7,7 +7,7 @@ import pyjuice.nodes.distributions as dists -def hclt_forward_test(): +def test_hclt_forward(): device = torch.device("cuda:0") @@ -97,7 +97,7 @@ def hclt_forward_test(): assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 4e-3) -def hclt_backward_test(): +def test_hclt_single_layer_backward(): device = torch.device("cuda:0") @@ -128,19 +128,228 @@ def hclt_backward_test(): data_cpu = batch_data.cpu().long() batch_size = batch_data.size(0) + pc.init_param_flows(flows_memory = 0.0) + lls = pc(batch_data) pc.backward(batch_data.permute(1, 0), allow_modify_flows = False) pc.update_param_flows() - node_mars = pc.node_mars.cpu() - node_flows = pc.node_flows.cpu() + for layer_id in range(1, len(pc.inner_layer_groups) - 2, 2): + + node_mars = pc.node_mars.clone() + node_flows = pc.node_flows.clone() + element_mars = pc.element_mars.clone() + element_flows = pc.element_flows.clone() + params = pc.params.clone() + param_flows = pc.param_flows.clone().zero_() + + my_layer = pc.inner_layer_groups[layer_id][0] + previous_layer = pc.inner_layer_groups[layer_id-1][0] + + previous_layer.forward(node_mars, element_mars, _for_backward = True) + 
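+ # Run this sum layer's backward in isolation on cloned buffers: the product layer
+ # below it is re-run to refresh `element_mars`, and the dense loops that follow
+ # re-derive its element and parameter flows block by block as
+ # flow = parent_flow * exp(log w + child_marginal - parent_marginal),
+ # summed over parents (element flows) or over the batch (parameter flows).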
+ my_layer.backward(node_flows, element_flows, node_mars, element_mars, params, + param_flows = param_flows, allow_modify_flows = False, propagation_alg = "LL") + + chids = my_layer.partitioned_chids[0] + parids = my_layer.partitioned_parids[0] + parpids = my_layer.partitioned_parpids[0] + + nids = my_layer.partitioned_nids[0] + cids = my_layer.partitioned_cids[0] + pids = my_layer.partitioned_pids[0] + pfids = my_layer.partitioned_pfids[0] + + for i in range(chids.size(0)): + eflows = torch.zeros([block_size, batch_size], dtype = torch.float32, device = device) + + for j in range(parids.size(1)): + nflows = node_flows[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[chids[i]:chids[i]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[parpids[i,j]:parpids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + curr_eflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + eflows += curr_eflows + + assert torch.all(torch.abs(eflows - element_flows[chids[i]:chids[i]+block_size,:]) < 1e-3) + + for i in range(nids.size(0)): + for j in range(0, cids.size(1), block_size): + nflows = node_flows[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[cids[i,j]:cids[i,j]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[pids[i,j]:pids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + pflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2) + + assert torch.all(torch.abs(fpars - pflows) < 3e-4 * batch_size) + + +def test_hclt_single_layer_backward_general_em(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.cpu().long() + batch_size = batch_data.size(0) + + alpha = 2.0 + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data, propagation_alg = "GeneralLL", alpha = alpha) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_param_flows() + + for layer_id in range(1, len(pc.inner_layer_groups) - 2, 2): + + node_mars = pc.node_mars.clone() + node_flows = pc.node_flows.clone() + element_mars = pc.element_mars.clone() + element_flows = pc.element_flows.clone() + params = pc.params.clone() + param_flows = pc.param_flows.clone().zero_() + + my_layer = pc.inner_layer_groups[layer_id][0] + previous_layer = pc.inner_layer_groups[layer_id-1][0] 
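+ # Same single-layer check as above, but under "GeneralLL" propagation: every edge
+ # term in the hand-computed element flows below is raised to the power alpha,
+ # i.e. exp(alpha * (log w + child_marginal - parent_marginal)), matching the
+ # alpha passed to `backward`.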
+ + previous_layer.forward(node_mars, element_mars, _for_backward = True) + + my_layer.backward(node_flows, element_flows, node_mars, element_mars, params, + param_flows = param_flows, allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + chids = my_layer.partitioned_chids[0] + parids = my_layer.partitioned_parids[0] + parpids = my_layer.partitioned_parpids[0] + + nids = my_layer.partitioned_nids[0] + cids = my_layer.partitioned_cids[0] + pids = my_layer.partitioned_pids[0] + pfids = my_layer.partitioned_pfids[0] + + for i in range(chids.size(0)): + eflows = torch.zeros([block_size, batch_size], dtype = torch.float32, device = device) + + for j in range(parids.size(1)): + nflows = node_flows[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[parids[i,j]:parids[i,j]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[chids[i]:chids[i]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[parpids[i,j]:parpids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + curr_eflows = (nflows[None,:,:] * ((epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]) * alpha).exp()).sum(dim = 1) + eflows += curr_eflows + + assert torch.all(torch.abs(eflows - element_flows[chids[i]:chids[i]+block_size,:]) < 1e-3) + + for i in range(nids.size(0)): + for j in range(0, cids.size(1), block_size): + nflows = node_flows[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + nmars = node_mars[nids[i]:nids[i]+block_size,:] # [num_par_nodes, batch_size] + emars = element_mars[cids[i,j]:cids[i,j]+block_size,:] # [num_ch_nodes, batch_size] + epars = params[pids[i,j]:pids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + fpars = param_flows[pfids[i,j]:pfids[i,j]+block_size**2].reshape(block_size, block_size) # [num_ch_nodes, num_par_nodes] + + pflows = (nflows[None,:,:] * (epars.log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2) + + assert torch.all(torch.abs(fpars - pflows) < 3e-4 * batch_size) + + +def test_hclt_backward(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.long() + batch_size = batch_data.size(0) + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False) + + pc.update_param_flows() + + node_mars = pc.node_mars + node_flows = pc.node_flows + + temp_node_mars = pc.node_mars.clone() + temp_node_flows = pc.node_flows.clone() + temp_element_mars = pc.element_mars.clone() + temp_element_flows = pc.element_flows.clone() + temp_params = pc.params + temp_param_flows = pc.param_flows.clone() ns2flows = dict() - ns2flows[root_ns] = torch.ones([1, batch_size]) + ns2flows[root_ns] = torch.ones([1, batch_size], device = 
device) + + gt_ch_flows = dict() + + ch2par = dict() + for ns in root_ns: + for cs in ns.chs: + if cs not in ch2par: + ch2par[cs] = set() + ch2par[cs].add(ns) + + visited = set() + with torch.no_grad(): + for ns in root_ns(reverse = True): + visited.add(ns) if ns == root_ns: sid, eid = ns._output_ind_range @@ -150,14 +359,14 @@ def hclt_backward_test(): nmars = node_mars[sid:eid,:] for i, cs in enumerate(ns.chs): - params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) - param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) if cs.is_prod(): - emars = torch.zeros([num_latents, batch_size]) + emars = torch.zeros([num_latents, batch_size], device = device) for cns in cs.chs: sid, eid = cns._output_ind_range emars += node_mars[sid:eid,:] @@ -170,58 +379,135 @@ def hclt_backward_test(): assert torch.all(torch.abs(pflows - param_flows[0,:]) < 6e-3) if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows + gt_ch_flows[cs] = eflows.detach().clone() + elif ns.is_prod(): nflows = ns2flows[ns] + gt_flows = gt_ch_flows[ns] + + assert torch.all(torch.abs(gt_flows - nflows) < 1e-4) + for cs in ns.chs: if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += nflows elif ns.is_sum(): + for par_cs in ch2par[ns]: + assert par_cs in visited + nflows = ns2flows[ns] sid, eid = ns._output_ind_range - assert (torch.abs(nflows - node_flows[sid:eid,:]) > 1e-3).float().mean() < 0.02 + + nflows = node_flows[sid:eid,:] nmars = node_mars[sid:eid,:] + ch_eflows = [] + ch_pflows = [] + for i, cs in enumerate(ns.chs): - params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) - param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) if cs.is_prod(): - emars = torch.zeros([num_latents, batch_size]) + emars = torch.zeros([num_latents, batch_size], device = device) for cns in cs.chs: sid, eid = cns._output_ind_range emars += node_mars[sid:eid,:] else: raise ValueError() + eflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + pflows = (nflows[None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) +
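+ # The same element/parameter flows, recomputed with the numerically stable
+ # log-space trick used by the actual kernels: factor out the per-sample max of
+ # (log parent_flow - parent_marginal) before exponentiating, then recover the
+ # flows with two matmuls.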
log_n_fdm = nflows.log() - nmars log_n_fdm_max = log_n_fdm.max(dim = 0).values n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() - eflows = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() + eflows_prim = torch.matmul(params.permute(1, 0), n_fdm_sub) * (emars + log_n_fdm_max[None,:]).exp() scaled_emars = (emars + log_n_fdm_max[None,:]).exp() - pflows = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params + pflows_prim = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params + + # From `pc` + pc_eflows = (node_flows[sid:eid,:][None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 1) + pc_pflows = (node_flows[sid:eid,:][None,:,:] * (params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).exp()).sum(dim = 2).permute(1, 0) + - assert torch.all(torch.abs(pflows - param_flows) < 0.5) + + ch_eflows.append(eflows) + ch_pflows.append(pflows) + + assert torch.all(torch.abs(pflows - param_flows) < 1e-3 * batch_size) if cs not in ns2flows: - ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) ns2flows[cs] += eflows + ## Run the actual layer ## + + curr_layer_id = -1 + curr_layer = None + for layer_id in range(1, len(pc.inner_layer_groups), 2): + layer = pc.inner_layer_groups[layer_id][0] + if ns in layer.nodes: + curr_layer_id = layer_id + curr_layer = layer + + assert curr_layer is not None + + nsid, neid = ns._output_ind_range + + temp_node_flows[nsid:neid,:] = nflows + temp_param_flows[:] = 0.0 + + pc.inner_layer_groups[curr_layer_id - 1].forward(temp_node_mars, temp_element_mars, _for_backward = True) + + curr_layer.backward(temp_node_flows, temp_element_flows, temp_node_mars, temp_element_mars, temp_params, + param_flows = temp_param_flows, allow_modify_flows = False, propagation_alg = "LL") + + pfsid, pfeid = ns._param_flow_range + + for i, cs in enumerate(ns.chs): + eflows = ch_eflows[i] + pflows = ch_pflows[i] + + csid, ceid = cs._output_ind_range + + assert torch.all(torch.abs(eflows - temp_element_flows[csid:ceid,:]) < 1e-3) + assert torch.all(torch.abs(temp_param_flows[pfsid:pfeid].reshape(num_latents, num_latents) - pflows.permute(1, 0)) < batch_size * 1e-4) + + assert cs not in gt_ch_flows + gt_ch_flows[cs] = eflows.detach().clone() + -def hclt_em_test(): +def test_hclt_em(): device = torch.device("cuda:0") @@ -303,7 +589,8 @@ def hclt_em_test(): if __name__ == "__main__": - torch.manual_seed(320942) - hclt_forward_test() - hclt_backward_test() - hclt_em_test() + test_hclt_forward() + test_hclt_single_layer_backward() + test_hclt_backward() + test_hclt_em() + test_hclt_single_layer_backward_general_em() diff --git a/tests/structures/hclt_test.py b/tests/structures/hclt_test.py index 6d93d22b..089fd0c8 100644 --- a/tests/structures/hclt_test.py +++ b/tests/structures/hclt_test.py @@ -5,6 +5,8 @@ from torch.utils.data
import TensorDataset, DataLoader import pyjuice.nodes.distributions as dists +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -17,7 +19,7 @@ def evaluate(pc, loader): return lls_total -def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device): +def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = False): for epoch in range(num_epochs): t0 = time.time() train_ll = 0.0 @@ -27,7 +29,10 @@ def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test optimizer.zero_grad() lls = pc(x) - lls.mean().backward() + if not logspace_flows: + lls.mean().backward() + else: + pc.backward(x.permute(1, 0), allow_modify_flows = False, logspace_flows = True) train_ll += lls.mean().detach().cpu().numpy().item() @@ -64,7 +69,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def hclt_test(): +def test_hclt(): device = torch.device("cuda:0") @@ -93,9 +98,10 @@ def hclt_test(): train_data.float().to(device), num_bins = 32, sigma = 0.5 / 32, - num_latents = 256, + num_latents = 128, chunk_size = 32 ) + ns.init_parameters(perturbation = 2.0) pc = juice.TensorCircuit(ns) pc.to(device) @@ -115,38 +121,195 @@ def hclt_test(): lls.mean().backward() break - # for i, batch in enumerate(train_loader): - # x = batch[0].to(device) + mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device) - # lls = pc(x, record_cudagraph = False) - # lls.mean().backward() - # if i > 5: - # break + test_ll = evaluate(pc, test_loader) - # from torch.profiler import profile, record_function, ProfilerActivity - # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack = True) as prof: - # for i, batch in enumerate(train_loader): - # x = batch[0].to(device) + assert test_ll > -785 - # lls = pc(x, record_cudagraph = False) - # lls.mean().backward() - # if i > 5: - # break - # prof.export_chrome_trace("trace3.json") - # # torch.autograd.profiler.tensorboard_trace_to_flame_graph('trace.json', 'flamegraph.svg') - # # prof.export_stacks("trace.txt", "cpu_time_total") - # import pdb; pdb.set_trace() - # exit() +def test_hclt_logspace_flows(): - mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device) + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 128, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) 
* 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = True) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -785 + + +@pytest.mark.slow +def test_small_hclt_full(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 32, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(350, pc, optimizer, scheduler, train_loader, test_loader, device) + full_batch_em_epoch(pc, train_loader, test_loader, device) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -660 + + +@pytest.mark.slow +def test_large_hclt_full(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = 128, + chunk_size = 32 + ) + ns.init_parameters(perturbation = 2.0) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1) + scheduler = juice.optim.CircuitScheduler( + optimizer, + method = "multi_linear", + lrs = [0.9, 0.1, 0.05], + milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350] + ) + + for batch in train_loader: + x = batch[0].to(device) + + lls = pc(x, record_cudagraph = True) + lls.mean().backward() + break + + mini_batch_em_epoch(350, pc, optimizer, scheduler, train_loader, test_loader, device) + full_batch_em_epoch(pc, train_loader, test_loader, device) test_ll = evaluate(pc, test_loader) - assert test_ll > -770 + assert test_ll > -640 -def hclt_logistic_test(): +def test_hclt_logistic(): device = torch.device("cuda:0") @@ 
-175,12 +338,10 @@ def hclt_logistic_test(): train_data.float().to(device), num_bins = 32, sigma = 0.5 / 32, - num_latents = 64, + num_latents = 128, chunk_size = 32, - # input_layer_type = dists.Gaussian, - # input_layer_params = {"mu": 0.0, "sigma": 1.0 / 3,"min_sigma": 0.01} - input_layer_type = dists.DiscreteLogistic, - input_layer_params = {"val_range": (-1.0, 1.0), "num_cats": 256} + input_node_type = dists.DiscreteLogistic, + input_node_params = {"val_range": (-1.0, 1.0), "num_cats": 256} ) ns.init_parameters(perturbation = 4.0) pc = juice.TensorCircuit(ns) @@ -203,6 +364,9 @@ def hclt_logistic_test(): if __name__ == "__main__": - torch.manual_seed(3289) - hclt_test() - hclt_logistic_test() + # torch.manual_seed(3289) + # test_hclt() + test_hclt_logspace_flows() + # test_small_hclt_full() + # test_large_hclt_full() + # test_hclt_logistic() diff --git a/tests/structures/hmm_correctness_test.py b/tests/structures/hmm_correctness_test.py index be551b36..4a9df879 100644 --- a/tests/structures/hmm_correctness_test.py +++ b/tests/structures/hmm_correctness_test.py @@ -7,7 +7,7 @@ import pyjuice.nodes.distributions as dists -def hmm_forward_backward_test(): +def test_hmm_forward_backward(): device = torch.device("cuda:0") @@ -204,6 +204,218 @@ def hmm_forward_backward_test(): assert torch.all(torch.abs(pflows - cum_pflows) < 1e-4) +def test_hmm_forward_backward_with_generalized_em(): + + device = torch.device("cuda:0") + + seq_length = 16 + vocab_size = 1023 + batch_size = 32 + + num_node_blocks = 4 # 4096 // 32 # 4 + block_size = 1024 # 32 # 1024 + num_latents = block_size * num_node_blocks + + with juice.set_block_size(block_size = block_size): + ns_input = juice.inputs(seq_length - 1, num_node_blocks = num_node_blocks, + dist = dists.Categorical(num_cats = vocab_size)) + + ns_sum = None + curr_zs = ns_input + for var in range(seq_length - 2, -1, -1): + curr_xs = ns_input.duplicate(var, tie_params = True) + + if ns_sum is None: + ns = juice.summate( + curr_zs, num_node_blocks = num_node_blocks) + ns_sum = ns + else: + ns = ns_sum.duplicate(curr_zs, tie_params=True) + + curr_zs = juice.multiply(curr_xs, ns) + + root_ns = juice.summate(curr_zs, num_node_blocks = 1, block_size = 1) + + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + pc.to(device) + + data = torch.randint(0, vocab_size, (batch_size, seq_length)).to(device) + data_cpu = data.cpu() + + ## Forward tests ## + + alpha = 1.2 + + lls = pc(data, propagation_alg = "GeneralLL", alpha = alpha) + + ns2mars = dict() + + node_mars = pc.node_mars.detach().cpu() + + with torch.no_grad(): + for ns in root_ns: + if ns.is_input(): + v = ns.scope.to_list()[0] + params = ns.get_source_ns()._params.reshape(num_latents, vocab_size) + + mars = params[:,data_cpu[:,v]].log() + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(mars - node_mars[sid:eid,:]) < 1e-4) + + ns2mars[ns] = mars + + elif ns.is_prod(): + mars = torch.zeros([num_latents, batch_size]) + for cs in ns.chs: + mars += ns2mars[cs] + + ns2mars[ns] = mars + + elif ns.is_sum() and ns != root_ns: + emars = torch.cat([ns2mars[cs] for cs in ns.chs], dim = 0) + params = ns.get_source_ns()._params.reshape(num_node_blocks, num_node_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = params.reshape(num_latents, num_latents * ns.num_chs) + + emars = emars * alpha + params = (params.log() * alpha).exp() + + emars_max = torch.max(emars, dim = 0).values[None,:] + emars = (emars - emars_max).exp() + + nmars = torch.matmul(params, emars) + nmars = 
nmars.log() + emars_max + + nmars *= (1.0 / alpha) + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 1e-3) + + ns2mars[ns] = nmars + + else: + assert ns == root_ns + + emars = torch.cat([ns2mars[cs] for cs in ns.chs], dim = 0) + params = ns._params.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = params.reshape(1, num_latents * ns.num_chs) + + emars = emars * alpha + params = (params.log() * alpha).exp() + + emars_max = torch.max(emars, dim = 0).values[None,:] + emars = (emars - emars_max).exp() + + nmars = torch.matmul(params, emars) + nmars = nmars.log() + emars_max + + nmars *= (1.0 / alpha) + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(nmars - node_mars[sid:eid,:]) < 1e-3) + + ## Backward tests ## + + pc.backward(data.permute(1, 0), allow_modify_flows = False, + propagation_alg = "GeneralLL", alpha = alpha) + + pc.update_param_flows() + + node_mars = pc.node_mars.cpu() + node_flows = pc.node_flows.cpu() + + ns2flows = dict() + ns2flows[root_ns] = torch.ones([1, batch_size]) + + cum_pflows = 0.0 + + with torch.no_grad(): + for ns in root_ns(reverse = True): + if ns == root_ns: + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(node_flows[sid:eid,:] - 1.0) < 1e-4) + + nflows = ns2flows[ns] + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + params = params[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(1, num_latents) + + param_flows = ns._param_flows.reshape(1, num_node_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3) + param_flows = param_flows[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(1, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size]) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = nflows * (params.log() * alpha).exp().permute(1, 0) * ((emars - nmars) * alpha).exp() + pflows = (nflows * params.permute(1, 0) * (emars - nmars).exp()).sum(dim = 1) + + assert torch.all(torch.abs(pflows - param_flows[0,:]) < 1e-4 * batch_size) + + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += eflows + + elif ns.is_prod(): + nflows = ns2flows[ns] + for cs in ns.chs: + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += nflows + + elif ns.is_sum(): + + nflows = ns2flows[ns] + + sid, eid = ns._output_ind_range + + assert torch.all(torch.abs(nflows - node_flows[sid:eid,:]) < 1e-5) + + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns.get_source_ns()._params.reshape(num_node_blocks, num_node_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3) + params = params[:,:,i*num_node_blocks:(i+1)*num_node_blocks,:].reshape(num_latents, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size]) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + log_n_fdm = nflows.log() - nmars * alpha + log_n_fdm_max = log_n_fdm.max(dim = 0).values + n_fdm_sub = (log_n_fdm - log_n_fdm_max[None,:]).exp() + + eflows = torch.matmul((params.log() * alpha).exp().permute(1, 0), n_fdm_sub) * (emars * alpha + log_n_fdm_max[None,:]).exp() + + log_n_fdm = nflows.log() - nmars + log_n_fdm_max = log_n_fdm.max(dim = 0).values + scaled_emars = (emars + 
log_n_fdm_max[None,:]).exp() + pflows = torch.matmul(n_fdm_sub, scaled_emars.permute(1, 0)) * params + + cum_pflows = cum_pflows + pflows + + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size]) + ns2flows[cs] += eflows + + pflows = ns_sum._param_flows.reshape(num_node_blocks, num_node_blocks, block_size, block_size).permute( + 0, 2, 1, 3).flatten(2, 3).flatten(0, 1) + assert torch.all(torch.abs(pflows - cum_pflows) < 1e-4) + + if __name__ == "__main__": - torch.manual_seed(23289) - hmm_forward_backward_test() + test_hmm_forward_backward() + test_hmm_forward_backward_with_generalized_em() diff --git a/tests/structures/hmm_speed_test.py b/tests/structures/hmm_speed_test.py index 3f6ea6cb..2c9b58bc 100644 --- a/tests/structures/hmm_speed_test.py +++ b/tests/structures/hmm_speed_test.py @@ -3,8 +3,11 @@ import torch import time +import pytest -def hmm_speed_test(): + +@pytest.mark.slow +def test_hmm_speed(): device = torch.device("cuda:0") @@ -82,5 +85,5 @@ def hmm_speed_test(): if __name__ == "__main__": - hmm_speed_test() + test_hmm_speed() diff --git a/tests/structures/logspace_flows_test.py b/tests/structures/logspace_flows_test.py new file mode 100644 index 00000000..4f095d27 --- /dev/null +++ b/tests/structures/logspace_flows_test.py @@ -0,0 +1,170 @@ +import torch +import torchvision +import time +from torch.utils.data import TensorDataset, DataLoader + +import pyjuice as juice +import pyjuice.distributions as dists + + +def logsubexp(x, y): + """ + Compute log(exp(x) - exp(y)) in a numerically stable way. + """ + x, y = torch.maximum(x, y), torch.minimum(x, y) + + # Compute the maximum value between x and y element-wise + max_val = torch.max(x, y) + + # Compute the result using logsumexp trick + result = max_val + torch.log(torch.exp(x - max_val) - torch.exp(y - max_val)) + + return result + + +def test_logspace_hclt_backward(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:] + + num_features = train_data.size(1) + num_latents = 128 + + root_ns = juice.structures.HCLT( + train_data.float().to(device), + num_bins = 32, + sigma = 0.5 / 32, + num_latents = num_latents, + chunk_size = 32 + ) + root_ns.init_parameters() + + pc = juice.TensorCircuit(root_ns) + + pc.to(device) + + block_size = root_ns.chs[0].block_size + num_blocks = num_latents // block_size + + batch_data = train_data[:512,:].contiguous().to(device) + data_cpu = batch_data.long() + batch_size = batch_data.size(0) + + pc.init_param_flows(flows_memory = 0.0) + + lls = pc(batch_data) + pc.backward(batch_data.permute(1, 0), allow_modify_flows = False, logspace_flows = True) + + pc.update_param_flows() + + node_mars = pc.node_mars + node_flows = pc.node_flows + + temp_node_mars = pc.node_mars.clone() + temp_node_flows = pc.node_flows.clone() + temp_element_mars = pc.element_mars.clone() + temp_element_flows = pc.element_flows.clone() + temp_params = pc.params + temp_param_flows = pc.param_flows.clone() + + ns2flows = dict() + ns2flows[root_ns] = torch.ones([1, batch_size], device = device) + + ch2par = dict() + for ns in root_ns: + for cs in ns.chs: + if cs not in ch2par: + ch2par[cs] = set() + ch2par[cs].add(ns) + + visited = set() + + with torch.no_grad(): + for ns in root_ns(reverse = True): + visited.add(ns) + if ns == root_ns: + + sid, eid = ns._output_ind_range + assert torch.all(torch.abs(node_flows[sid:eid,:] - 0.0) < 1e-4) + + nflows = 
ns2flows[ns] + nmars = node_mars[sid:eid,:] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) + params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) + + param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device) + param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size], device = device) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = nflows.log() + params.log().permute(1, 0) + emars - nmars + pflows = eflows.exp().sum(dim = 1) + + assert torch.all(torch.abs(pflows - param_flows[0,:]) < 1e-4 * batch_size) + + ns2flows[cs] = eflows + + elif ns.is_prod(): + nflows = ns2flows[ns] + + for cs in ns.chs: + if cs not in ns2flows: + ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) - float("inf") + ns2flows[cs] = torch.logaddexp(ns2flows[cs], nflows) + + elif ns.is_sum(): + + for par_cs in ch2par[ns]: + assert par_cs in visited + + nflows = ns2flows[ns] + + sid, eid = ns._output_ind_range + + assert torch.all(logsubexp(nflows, node_flows[sid:eid,:]).exp() < 1e-3) + assert (logsubexp(nflows, node_flows[sid:eid,:]).exp() > 1e-5).float().mean() < 0.2 + + nflows = node_flows[sid:eid,:] + + nmars = node_mars[sid:eid,:] + + ch_eflows = [] + + for i, cs in enumerate(ns.chs): + params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) + params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) + + param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device) + param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents) + + if cs.is_prod(): + emars = torch.zeros([num_latents, batch_size], device = device) + for cns in cs.chs: + sid, eid = cns._output_ind_range + emars += node_mars[sid:eid,:] + else: + raise ValueError() + + eflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 1) + pflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 2).permute(1, 0).exp() + + ch_eflows.append(eflows) + + assert torch.all(torch.abs(pflows - param_flows) < 1e-4 * batch_size) + + ns2flows[cs] = eflows + + +if __name__ == "__main__": + test_logspace_hclt_backward() diff --git a/tests/structures/pd_hclt_test.py b/tests/structures/pd_hclt_test.py index e9fcc11a..84a18e9c 100644 --- a/tests/structures/pd_hclt_test.py +++ b/tests/structures/pd_hclt_test.py @@ -4,6 +4,8 @@ import time from torch.utils.data import TensorDataset, DataLoader +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -65,7 +67,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def pd_hclt_degenerative_case_test(): +def test_pd_hclt_degenerative_case(): device = torch.device("cuda:0") @@ -118,7 +120,7 @@ def pd_hclt_degenerative_case_test(): assert test_ll > -690.0 -def pd_hclt_test(): +def test_pd_hclt(): device = torch.device("cuda:0") @@ -146,7 +148,7 @@ def pd_hclt_test(): ns = juice.structures.PDHCLT( 
train_data.cuda(), data_shape = (28, 28), - num_latents = 128, + num_latents = 64, split_intervals = (4, 4), structure_type = "sum_dominated" ) @@ -175,10 +177,10 @@ def pd_hclt_test(): test_ll = evaluate(pc, test_loader) - assert test_ll > -680.0 + assert test_ll > -692.0 if __name__ == "__main__": torch.manual_seed(2391) - pd_hclt_degenerative_case_test() - pd_hclt_test() + test_pd_hclt_degenerative_case() + test_pd_hclt() diff --git a/tests/structures/pd_test.py b/tests/structures/pd_test.py index 4cefa80b..ac1b8b6e 100644 --- a/tests/structures/pd_test.py +++ b/tests/structures/pd_test.py @@ -4,6 +4,8 @@ import time from torch.utils.data import TensorDataset, DataLoader +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -65,7 +67,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def pd_test(): +def test_pd(): device = torch.device("cuda:0") @@ -133,6 +135,52 @@ def pd_test(): assert test_ll > -765.0 +def test_homogeneous_pd(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.PD( + data_shape = (28, 28), + num_latents = 256, + split_intervals = (4, 4), + structure_type = "sum_dominated", + tie_homogeneous_params = True + ) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001) + + mini_batch_em_epoch(10, pc, optimizer, None, train_loader, test_loader, device) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -780.0 + + if __name__ == "__main__": - torch.manual_seed(2391) - pd_test() + # torch.manual_seed(2391) + test_pd() + test_homogeneous_pd() diff --git a/tests/structures/rat_spn_test.py b/tests/structures/rat_spn_test.py index f616a973..dc220c17 100644 --- a/tests/structures/rat_spn_test.py +++ b/tests/structures/rat_spn_test.py @@ -5,6 +5,8 @@ from torch.utils.data import TensorDataset, DataLoader import pyjuice.nodes.distributions as dists +import pytest + def evaluate(pc, loader): lls_total = 0.0 @@ -64,7 +66,7 @@ def full_batch_em_epoch(pc, train_loader, test_loader, device): print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ") -def rat_spn_test(): +def test_rat_spn(): device = torch.device("cuda:0") @@ -91,7 +93,7 @@ def rat_spn_test(): ns = juice.structures.RAT_SPN( num_vars = 28 * 28, - num_latents = 256, + num_latents = 64, depth = 5, num_repetitions = 4, num_pieces = 2 @@ -119,9 +121,9 @@ def rat_spn_test(): test_ll = evaluate(pc, test_loader) - assert test_ll > -1015 + assert test_ll > -1020 if __name__ == "__main__": torch.manual_seed(3289) - rat_spn_test() + test_rat_spn() diff --git a/tests/transformations/blockify_test.py b/tests/transformations/blockify_test.py index 97e141ac..aced75af 100644 --- 
a/tests/transformations/blockify_test.py +++ b/tests/transformations/blockify_test.py @@ -11,7 +11,7 @@ import pytest -def block_test(): +def test_block(): with set_block_size(block_size = 2): @@ -78,7 +78,7 @@ def block_test(): assert torch.all(new_n2._params[0][2:4,2:4] == n2._params[3]) -def block_sparse_block_test(): +def test_block_sparse_block(): with set_block_size(block_size = 4): @@ -159,5 +159,5 @@ def block_sparse_block_test(): if __name__ == "__main__": - block_test() - block_sparse_block_test() + test_block() + test_block_sparse_block() diff --git a/tests/transformations/copy_test.py b/tests/transformations/copy_test.py index cac99794..e3dc9578 100644 --- a/tests/transformations/copy_test.py +++ b/tests/transformations/copy_test.py @@ -10,7 +10,7 @@ import pytest -def copy_test(): +def test_copy(): num_node_blocks_candidates = [2, 4, 7] block_size_candidates = [1, 4, 8] @@ -66,4 +66,4 @@ def copy_test(): if __name__ == "__main__": - copy_test() \ No newline at end of file + test_copy() \ No newline at end of file diff --git a/tests/transformations/merge_test.py b/tests/transformations/merge_test.py index fde18945..d7d95053 100644 --- a/tests/transformations/merge_test.py +++ b/tests/transformations/merge_test.py @@ -10,7 +10,7 @@ import pytest -def sum_nodes_merge_test(): +def test_sum_nodes_merge(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -41,7 +41,7 @@ def sum_nodes_merge_test(): assert n_new.chs[0] == m00 -def prod_nodes_merge_test(): +def test_prod_nodes_merge(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -70,7 +70,7 @@ def prod_nodes_merge_test(): assert m_new.chs[1] == i10 -def merge_by_region_node_test(): +def test_merge_by_region_node(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -121,6 +121,6 @@ def merge_by_region_node_test(): if __name__ == "__main__": - sum_nodes_merge_test() - prod_nodes_merge_test() - merge_by_region_node_test() \ No newline at end of file + test_sum_nodes_merge() + test_prod_nodes_merge() + test_merge_by_region_node() \ No newline at end of file diff --git a/tests/transformations/pruning_test.py b/tests/transformations/pruning_test.py index 51f51072..79e284e5 100644 --- a/tests/transformations/pruning_test.py +++ b/tests/transformations/pruning_test.py @@ -7,7 +7,7 @@ import pytest -def pruning_test(): +def test_pruning(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -50,7 +50,7 @@ def pruning_test(): assert torch.all(torch.abs(new_n2._params.sum(dim = 2) - 1.0) < 1e-4) -def pruning_with_param_tying_test(): +def test_pruning_with_param_tying(): num_node_blocks = 2 for block_size in [1, 2, 4, 8]: @@ -91,7 +91,7 @@ def pruning_with_param_tying_test(): assert torch.all(torch.abs(new_n2._source_node._params[0].sum(dim = 1) - 1.0) < 1e-4) -def pruning_by_flow_test(): +def test_pruning_by_flow(): num_nodes = 2 i0 = inputs(0, num_nodes, dists.Categorical(num_cats = 5)) @@ -130,6 +130,6 @@ def pruning_by_flow_test(): if __name__ == "__main__": - pruning_test() - pruning_with_param_tying_test() - pruning_by_flow_test() \ No newline at end of file + test_pruning() + test_pruning_with_param_tying() + test_pruning_by_flow() \ No newline at end of file diff --git a/tests/visualize/plots_test.py b/tests/visualize/plots_test.py index 9006896a..7db7e48f 100644 --- a/tests/visualize/plots_test.py +++ b/tests/visualize/plots_test.py @@ -5,6 +5,8 @@ import pyjuice.visualize as juice_vis +import pytest + def simple_pc_gen(): n0 = inputs(0, num_nodes=256, dist=dists.Categorical(num_cats=5)) @@ -26,19 +28,25 @@ def 
simple_pc_gen(): ns.init_parameters() return ns -ns = simple_pc_gen() -# case 1 -plt.figure() -juice_vis.plot_pc(ns, node_id=True, node_num_label=True) -plt.show() +def test_plots(): + ns = simple_pc_gen() + + # case 1 + plt.figure() + juice_vis.plot_pc(ns, node_id=True, node_num_label=True) + plt.show() + + # case 2 + juice_vis.plot_tensor_node_connection(ns, node_id=3) + + # case 3 + juice_vis.plot_tensor_node_connection(ns, node_id=4) + plt.show() -# case 2 -juice_vis.plot_tensor_node_connection(ns, node_id=3) + # case 4 + juice_vis.plot_tensor_node_connection(ns, node_id=0) -# case 3 -juice_vis.plot_tensor_node_connection(ns, node_id=4) -plt.show() -# case 4 -juice_vis.plot_tensor_node_connection(ns, node_id=0) +if __name__ == "__main__": + test_plots() \ No newline at end of file
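
For reference, the quantity the new "GeneralLL" tests re-derive by hand is an alpha-power logsumexp: a sum node's log-value is (1/alpha) * log(sum_c w_c^alpha * exp(alpha * m_c)), which reduces to the ordinary log-likelihood at alpha = 1. Below is a minimal, self-contained PyTorch sketch of that recursion with the same max-subtraction stabilization the tests use; the function and tensor names are illustrative, not part of the pyjuice API.

import torch

def general_ll_sum(log_w, child_mars, alpha):
    # log_w: [num_par, num_ch] log-weights of one sum-node block.
    # child_mars: [num_ch, batch] child log-marginals.
    scaled = alpha * child_mars
    # Subtract the per-sample max before exponentiating, as the tests do.
    m = scaled.max(dim = 0, keepdim = True).values
    acc = torch.matmul((alpha * log_w).exp(), (scaled - m).exp())  # [num_par, batch]
    return (acc.log() + m) / alpha

torch.manual_seed(0)
log_w = torch.log_softmax(torch.randn(4, 6), dim = 1)  # normalized sum-node weights
child_mars = -torch.rand(6, 8)                         # fake child log-marginals
# At alpha = 1 this must agree with the standard "LL" logsumexp recursion.
ref = torch.logsumexp(log_w[:,:,None] + child_mars[None,:,:], dim = 1)
assert torch.allclose(general_ll_sum(log_w, child_mars, 1.0), ref, atol = 1e-5)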