From edb67ecc3f77593db13899091786fbfad312a248 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sun, 10 Mar 2024 00:44:41 +0800 Subject: [PATCH] `MPE` and `GeneralLL` for forward pass (block sparse kernels) --- src/pyjuice/layer/layer.py | 17 +++ src/pyjuice/layer/sum_layer.py | 196 ++++++++++++++++++++------- src/pyjuice/model/tensorcircuit.py | 32 ++++- tests/layer/propagation_algs_test.py | 124 +++++++++++++++++ 4 files changed, 318 insertions(+), 51 deletions(-) create mode 100644 tests/layer/propagation_algs_test.py diff --git a/src/pyjuice/layer/layer.py b/src/pyjuice/layer/layer.py index a8585fbf..00677336 100644 --- a/src/pyjuice/layer/layer.py +++ b/src/pyjuice/layer/layer.py @@ -7,6 +7,13 @@ class Layer(): + + propagation_alg_mapping = { + "LL": 0, + "MPE": 1, + "GeneralLL": 2 + } + def __init__(self, nodes: Sequence[CircuitNodes], disable_block_size_check: bool = False) -> None: if disable_block_size_check: @@ -60,3 +67,13 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True def provided(self, var_name): return hasattr(self, var_name) and getattr(self, var_name) is not None + + def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + return {} + elif propagation_alg == "MPE": + return {} + elif propagation_alg == "GeneralLL": + return {"alpha": kwargs["alpha"]} + else: + raise ValueError(f"Unknown propagation algorithm {propagation_alg}.") diff --git a/src/pyjuice/layer/sum_layer.py b/src/pyjuice/layer/sum_layer.py index f28597ed..7bd58282 100644 --- a/src/pyjuice/layer/sum_layer.py +++ b/src/pyjuice/layer/sum_layer.py @@ -208,7 +208,8 @@ def num_param_flows(self): return self._layer_pfid_range[1] - self._layer_pfid_range[0] def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Computes the forward pass of a sum layer. @@ -228,7 +229,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t self._forward( node_mars, element_mars, params, nids, cids, pids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) else: @@ -243,7 +245,8 @@ def forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: t node_mars, element_mars, params, nids, cids, pids, local_ids = local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs ) return None @@ -344,7 +347,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, mode: Optional[str] = None, - force_use_fp16: bool = False, force_use_fp32: bool = False) -> None: + force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers. 
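For reference, the three propagation algorithms dispatched through `propagation_alg` compute the following quantities at each sum node. The dense PyTorch sketch below is illustrative only (the function and argument names are not part of the patch); it matches the reference computations used in the new tests:

    import torch

    def sum_node_forward_reference(log_w, child_lls, propagation_alg = "LL", alpha = 1.0):
        # log_w: [num_nodes, num_children] log-weights of the sum nodes
        # child_lls: [num_children, batch_size] child log-likelihoods
        joint = log_w[:, :, None] + child_lls[None, :, :]  # log(w_i * p_i)
        if propagation_alg == "LL":
            # log(sum_i w_i * p_i)
            return torch.logsumexp(joint, dim = 1)
        elif propagation_alg == "MPE":
            # log(max_i w_i * p_i)
            return joint.max(dim = 1).values
        elif propagation_alg == "GeneralLL":
            # (1 / alpha) * log(sum_i (w_i * p_i)^alpha); alpha = 1 recovers "LL",
            # and alpha -> inf approaches "MPE"
            return torch.logsumexp(alpha * joint, dim = 1) / alpha
        else:
            raise ValueError(f"Unknown propagation algorithm {propagation_alg}.")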
@@ -380,18 +384,19 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, self._forward_block_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, partition_id = partition_id, force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32 + force_use_fp32 = force_use_fp32, propagation_alg = propagation_alg, **kwargs ) elif mode == self.SPARSE: self._forward_sparse( node_mars, element_mars, params, nids, cids, pids, local_ids, - partition_id = partition_id + partition_id = partition_id, propagation_alg = propagation_alg, **kwargs ) elif mode == self.PYTORCH: self._forward_pytorch( - node_mars, element_mars, params, nids, cids, pids, local_ids + node_mars, element_mars, params, nids, cids, pids, local_ids, + propagation_alg = propagation_alg, **kwargs ) else: @@ -403,7 +408,8 @@ def _forward(self, node_mars: torch.Tensor, element_mars: torch.Tensor, def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -453,22 +459,45 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Built-in matmul kernel of triton + float16 - epars_fp16 = (epars * (2**12)).to(tl.float16) - emars_fp16 = emars_sub.to(tl.float16) - nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) else: - # Built-in matmul kernel of triton + float32 - nmars = tl.dot(epars, emars_sub) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Built-in matmul kernel of triton + float16 + epars_fp16 = (epars * (2**12)).to(tl.float16) + emars_fp16 = emars_sub.to(tl.float16) + nmars = tl.dot(epars_fp16, emars_fp16).to(tl.float32) / (2**12) + else: + # Built-in matmul kernel of triton + float32 + nmars = tl.dot(epars, emars_sub) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -480,6 +509,10 @@ def 
_fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -491,7 +524,8 @@ def _fw_triton_block_sparse_tlmm_kernel(node_mars, element_mars, params, nids, c def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -541,22 +575,45 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[None,:]) - emars_max = tl.max(emars, axis = 0)[None,:] - emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + emars[None,:,:], axis = 1) + + acc = tl.maximum(acc, nmars) - if use_fp16 == 1: - # Simulated matmul kernel + float16 - epars = (epars * (2**4)).to(tl.float16) - emars_sub = emars_sub.to(tl.float16) - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) else: - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) - acc = tl.where(emars_max > acc, - tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, - tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc - ) + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 0)[None,:] + emars_sub = tl.where(emars_max != -float("inf"), tl.exp(emars - emars_max), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 0)[None,:] + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max != -float("inf"), tl.exp((emars - emars_max) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + if use_fp16 == 1: + # Simulated matmul kernel + float16 + epars = (epars * (2**4)).to(tl.float16) + emars_sub = emars_sub.to(tl.float16) + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1).to(tl.float32) / (2**4) + else: + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * emars_sub[None,:,:], axis = 1) + + acc = tl.where(emars_max > acc, + tl.log(nmars + tl.exp(acc - emars_max)) + emars_max, + tl.log(tl.exp(emars_max - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -568,6 +625,10 @@ def _fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[:,None] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -579,7 +640,8 @@ def 
_fw_triton_block_sparse_csmm1_kernel(node_mars, element_mars, params, nids, def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, cids_start, cids_increment, pids_start, pids_increment, local_ids, batch_size: tl.constexpr, partial_eval: tl.constexpr, BLOCK_B: tl.constexpr, TILE_SIZE_K: tl.constexpr, K_NUM_TILES: tl.constexpr, - TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr): + TILE_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, use_fp16: tl.constexpr, + propagation_alg_id: tl.constexpr, alpha = 0.0): pid_b = tl.program_id(0) # ID of size-`BLOCK_B` batches pid_m = tl.program_id(1) # ID of size-`TILE_SIZE_M` nodes @@ -629,16 +691,39 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, epars = tl.load(epars_ptr) emars = tl.load(emars_ptr, mask = mask_batch[:,None]) - emars_max = tl.max(emars, axis = 1) - emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + if propagation_alg_id == 1: + # MPE propagation method + lpars = tl.log(epars) + nmars = tl.max(lpars[:,:,None] + tl.trans(emars)[None,:,:], axis = 1) - # Simulated matmul kernel + float32 - nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + acc = tl.maximum(acc, nmars) - acc = tl.where(emars_max[None,:] > acc, - tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], - tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc - ) + else: + + if propagation_alg_id == 0: + # LL propagation method + emars_max = tl.max(emars, axis = 1) + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp(emars - emars_max[:,None]), 0.0) + + if propagation_alg_id == 2: + # GeneralLL propagation method + + emars_max = tl.max(emars, axis = 1) + # Compute p_i^{alpha} for every i + emars_sub = tl.where(emars_max[:,None] != -float("inf"), tl.exp((emars - emars_max[:,None]) * alpha), 0.0) + # Compute w_i^{alpha} for every i + epars = tl.exp(tl.log(epars) * alpha) + + # Also scale `emars_max` + emars_max *= alpha + + # Simulated matmul kernel + float32 + nmars = tl.sum(epars[:,:,None] * tl.trans(emars_sub)[None,:,:], axis = 1) + + acc = tl.where(emars_max[None,:] > acc, + tl.log(nmars + tl.exp(acc - emars_max[None,:])) + emars_max[None,:], + tl.log(tl.exp(emars_max[None,:] - acc) * nmars + 1.0) + acc + ) # Increment `epars_ptr` pids_inc = tl.load(pids_inc_ptr) @@ -650,6 +735,10 @@ def _fw_triton_block_sparse_csmm2_kernel(node_mars, element_mars, params, nids, emars_ptr += cids_inc[None,:] * batch_size cids_inc_ptr += TILE_SIZE_K + if propagation_alg_id == 2: + # Compute p_i^{1/alpha} + acc *= (1.0 / alpha) + # Write back off_nids = tl.load(nids + nblock_id) offs_nmars = (off_nids + offs_node[:,None]) * batch_size + offs_batch[None,:] @@ -659,7 +748,7 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten params: torch.Tensor, nids: torch.Tensor, cids: torch.Tensor, pids: torch.Tensor, local_ids: Optional[torch.Tensor] = None, partition_id: int = -1, force_use_fp16: bool = False, - force_use_fp32: bool = False) -> None: + force_use_fp32: bool = False, propagation_alg: str = "LL", **kwargs) -> None: """ Forward pass of sum layers with the block-sparse processing kernel. 
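All three block-sparse kernels above implement `GeneralLL` with the same max-shift trick used for `LL`, only tempered by `alpha`: the per-batch maximum of the child log-marginals is subtracted before exponentiation, the shifted marginals and the edge parameters are both raised to the power `alpha`, and `emars_max` is rescaled by `alpha` so it can be folded back into the log-space accumulator. A single-tile PyTorch sketch of these numerics (for exposition only; the kernels additionally accumulate `acc` across `K_NUM_TILES` tiles in log-space):

    import torch

    def tempered_tile_reference(epars, emars, alpha):
        # epars: [TILE_SIZE_M, TILE_SIZE_K] edge parameters
        # emars: [TILE_SIZE_K, BLOCK_B] child log-marginals
        emars_max = emars.max(dim = 0, keepdim = True).values    # [1, BLOCK_B]
        emars_sub = torch.where(emars_max != -float("inf"),
                                torch.exp((emars - emars_max) * alpha),
                                torch.zeros_like(emars))         # shifted p_i^alpha
        nmars = torch.matmul(epars.pow(alpha), emars_sub)        # sum_i w_i^alpha * p_i^alpha (shifted)
        # Undo the (alpha-scaled) shift, then divide by alpha, matching the
        # kernels' `emars_max *= alpha` and final `acc *= (1.0 / alpha)`
        return (nmars.log() + emars_max * alpha) / alpha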
@@ -680,6 +769,10 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten batch_size = node_mars.size(1) BATCH_SIZE_NP2 = triton.next_power_of_2(batch_size) + # Propagation algorithm + propagation_alg_id = self.propagation_alg_mapping[propagation_alg] + propagation_alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs) + # Heuristic to set `TILE_SIZE_M`, `TILE_SIZE_K`, and `BLOCK_B` base_size = min(self.block_size, num_edges, BATCH_SIZE_NP2, 128) if base_size >= 64: @@ -751,7 +844,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) elif TILE_SIZE_M >= 8 and TILE_SIZE_K >= 8 and BLOCK_B >= 8: @@ -772,8 +867,11 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) + else: self._fw_triton_block_sparse_csmm2_kernel[grid]( node_mars, @@ -792,7 +890,9 @@ def _forward_block_sparse(self, node_mars: torch.Tensor, element_mars: torch.Ten K_NUM_TILES = K_NUM_TILES, TILE_SIZE_M = TILE_SIZE_M, BLOCK_SIZE_M = BLOCK_SIZE_M, - use_fp16 = use_fp16 + use_fp16 = use_fp16, + propagation_alg_id = propagation_alg_id, + **propagation_alg_kwargs ) return None diff --git a/src/pyjuice/model/tensorcircuit.py b/src/pyjuice/model/tensorcircuit.py index 2fa826e3..183a6a8f 100644 --- a/src/pyjuice/model/tensorcircuit.py +++ b/src/pyjuice/model/tensorcircuit.py @@ -101,6 +101,10 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5, # CudaGraph options self._recorded_cuda_graphs = dict() + # Mode for forward and backward pass + self.default_propagation_alg = "LL" # Could be "LL", "MPE", or "GeneralLL" + self.propagation_alg_kwargs = dict() + def to(self, device): super(TensorCircuit, self).to(device) @@ -115,10 +119,26 @@ def to(self, device): self.par_update_kwargs = par_update_to_device(self.par_update_kwargs, device) return self + + def set_propagation_alg(self, propagation_alg: str, **kwargs): + if propagation_alg == "LL": + self.default_propagation_alg = "LL" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "MPE": + self.default_propagation_alg = "MPE" + self.propagation_alg_kwargs.clear() + elif propagation_alg == "GeneralLL": + assert "alpha" in kwargs, "Argument `alpha` should be provided for the `GeneralLL` propagation algorithm." + self.default_propagation_alg = "GeneralLL" + self.propagation_alg_kwargs.clear() + self.propagation_alg_kwargs["alpha"] = kwargs["alpha"] + else: + raise NotImplementedError(f"Unknown propagation algorithm {propagation_alg}.") def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Callable]] = None, cache: Optional[dict] = None, return_cache: bool = False, record_cudagraph: bool = False, - apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, **kwargs): + apply_cudagraph: bool = True, force_use_fp16: bool = False, force_use_fp32: bool = False, + propagation_alg: Optional[str] = None, **kwargs): """ Forward evaluation of the PC. 
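The resulting user-facing behavior: `set_propagation_alg` changes the default used by subsequent forward calls, while the `propagation_alg` keyword overrides it per call. A usage sketch, where `pc` and `data` are placeholders for a compiled `TensorCircuit` and a `[batch_size, num_vars]` input tensor:

    # Make tempered log-likelihood the default for later calls
    pc.set_propagation_alg("GeneralLL", alpha = 2.0)
    lls = pc(data)

    # Per-call overrides that leave the stored default untouched
    mpe_lls = pc(data, propagation_alg = "MPE")
    lls = pc(data, propagation_alg = "GeneralLL", alpha = 1.5)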
@@ -135,6 +155,11 @@ def forward(self, inputs: torch.Tensor, input_layer_fn: Optional[Union[str,Calla assert inputs.dim() == 2 and inputs.size(1) == self.num_vars inputs = inputs.permute(1, 0) + + # Set propagation algorithm + if propagation_alg is None: + propagation_alg = self.default_propagation_alg + kwargs.update(self.propagation_alg_kwargs) ## Initialize buffers for forward pass ## @@ -173,8 +198,9 @@ def _run_inner_layers(): elif layer_group.is_sum(): # Sum layer layer_group(self.node_mars, self.element_mars, self.params, - force_use_fp16 = force_use_fp16, - force_use_fp32 = force_use_fp32) + force_use_fp16 = force_use_fp16, + force_use_fp32 = force_use_fp32, + propagation_alg = propagation_alg, **kwargs) else: raise ValueError(f"Unknown layer type {type(layer)}.") diff --git a/tests/layer/propagation_algs_test.py b/tests/layer/propagation_algs_test.py new file mode 100644 index 00000000..a4113f93 --- /dev/null +++ b/tests/layer/propagation_algs_test.py @@ -0,0 +1,124 @@ +import pyjuice as juice +import torch +import numpy as np +import time +import random + +import pyjuice.nodes.distributions as dists +from pyjuice.utils import BitSet +from pyjuice.nodes import multiply, summate, inputs +from pyjuice.model import TensorCircuit + +from pyjuice.layer import InputLayer, ProdLayer, SumLayer + +import pytest + + +def general_ll_prop_test(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [4, 8, 16]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + alphas = [1.2, 2.0, 3.0] + + for alpha in alphas: + layer(node_mars, element_mars, params, propagation_alg = "GeneralLL", alpha = alpha) + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None]**alpha * cmars**alpha).sum(dim = 0).log() * (1.0 / alpha) + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + +def mpe_prop_test(): + + device = torch.device("cuda:0") + + batch_size = 16 + + for block_size in [4, 8, 16]: + + with juice.set_block_size(block_size): + + ni0 = inputs(0, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni1 = inputs(1, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni2 = inputs(2, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + ni3 = 
inputs(3, num_node_blocks = 2, dist = dists.Categorical(num_cats = 2)) + + np0 = multiply(ni0, ni1) + np1 = multiply(ni2, ni3) + np2 = multiply(ni1, ni2) + + ns0 = summate(np0, num_node_blocks = 2) + ns1 = summate(np1, num_node_blocks = 2) + ns2 = summate(np2, num_node_blocks = 2) + + input_layer = InputLayer([ni0, ni1, ni2, ni3], cum_nodes = block_size) + + prod_layer = ProdLayer([np0, np1, np2]) + + layer = SumLayer([ns0, ns1, ns2], global_nid_start = block_size, + global_pid_start = block_size ** 2, + global_pfid_start = 0, node2tiednodes = dict()) + + layer.to(device) + + ## Forward pass ## + + element_mars = torch.rand([block_size + 3 * 2 * 2 * block_size, batch_size]).log().to(device) + element_mars[:block_size,:] = -float("inf") + node_mars = torch.zeros([block_size + block_size * 2 * 3, batch_size]).to(device) + + params = torch.rand([block_size ** 2 + 3 * 4 * block_size * block_size]).to(device) + + layer(node_mars, element_mars, params, propagation_alg = "MPE") + + for i in range(block_size): + for j in range(6): + cmars = element_mars[layer.partitioned_cids[0][j,:]].exp() + epars = params[layer.partitioned_pids[0][j,:]+i] + scaled_lls = (epars[:,None] * cmars).max(dim = 0).values.log() + assert torch.all(torch.abs(node_mars[(j+1)*block_size+i,:] - scaled_lls) < 1e-3) + + +if __name__ == "__main__": + general_ll_prop_test() + mpe_prop_test()
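A standalone numerical sanity check (not part of the test suite) of how `GeneralLL` interpolates between the other two algorithms: at `alpha = 1` it equals `LL`, and for large `alpha` it approaches `MPE`:

    import torch

    w = torch.tensor([0.3, 0.7], dtype = torch.float64)
    p = torch.tensor([0.2, 0.05], dtype = torch.float64)

    def general_ll(alpha):
        # (1 / alpha) * log(sum_i (w_i * p_i)^alpha)
        return ((w * p).pow(alpha).sum().log() / alpha).item()

    print(general_ll(1.0))             # -2.3539...
    print((w * p).sum().log().item())  # -2.3539... (the "LL" value)
    print(general_ll(50.0))            # -2.8134...
    print((w * p).max().log().item())  # -2.8134... (the "MPE" value)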