From ab8a5376392e83bd9bd17b06edb012aa1da9669a Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Fri, 19 Apr 2024 20:29:23 +0800 Subject: [PATCH] add literal distribution --- src/pyjuice/layer/input_layer.py | 20 ++++-- src/pyjuice/nodes/distributions/__init__.py | 1 + src/pyjuice/nodes/distributions/literal.py | 80 +++++++++++++++++++++ tests/nodes/input_dists_test.py | 29 +++++++- 4 files changed, 122 insertions(+), 8 deletions(-) create mode 100644 src/pyjuice/nodes/distributions/literal.py diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index fae2c44c..c1926171 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -120,7 +120,7 @@ def __init__(self, nodes: Sequence[InputNodes], cum_nodes: int = 0, pc_num_vars: source_nids = torch.empty([cum_source_ns], dtype = torch.long) # Parameters of this layer - params = torch.empty([self.num_parameters], dtype = torch.float32) + params = torch.empty([max(self.num_parameters, 1)], dtype = torch.float32) n_start = 0 source_n_start = 0 @@ -132,11 +132,17 @@ def __init__(self, nodes: Sequence[InputNodes], cum_nodes: int = 0, pc_num_vars: vids[n_start:n_end,:] = torch.tensor(node_vars[ns_id]).view(1, -1) # `s_pids` and `s_pfids` - pid_offsets = torch.arange(0, ns.num_nodes * ns.dist.num_parameters(), ns.dist.num_parameters()) - s_pids[n_start:n_end] = ns._param_range[0] + pid_offsets + if ns.dist.num_parameters() > 0: + pid_offsets = torch.arange(0, ns.num_nodes * ns.dist.num_parameters(), ns.dist.num_parameters()) + s_pids[n_start:n_end] = ns._param_range[0] + pid_offsets + else: + s_pids[n_start:n_end] = 0 - pfid_offsets = torch.arange(0, ns.num_nodes * ns.dist.num_param_flows(), ns.dist.num_param_flows()) - s_pfids[n_start:n_end] = ns._param_flow_range[0] + pfid_offsets + if ns.dist.num_param_flows() > 0: + pfid_offsets = torch.arange(0, ns.num_nodes * ns.dist.num_param_flows(), ns.dist.num_param_flows()) + s_pfids[n_start:n_end] = ns._param_flow_range[0] + pfid_offsets + else: + s_pfids[n_start:n_end] = 0 # `source_nids` if not ns.is_tied(): @@ -202,9 +208,9 @@ def init_param_flows(self, flows_memory: float = 1.0): or (self.param_flows.dim() == 1 and batch_size > 1) \ or (self.param_flows.dim() == 2 and batch_size != self.param_flows.size(1)): if batch_size == 1: - shape = [self.num_param_flows] + shape = [max(self.num_param_flows, 1)] else: - shape = [self.num_param_flows, batch_size] + shape = [max(self.num_param_flows, 1), batch_size] self.param_flows = torch.zeros(shape, device = self.device) else: assert self.param_flows.size(0) == self.num_param_flows diff --git a/src/pyjuice/nodes/distributions/__init__.py b/src/pyjuice/nodes/distributions/__init__.py index c70f2e73..d4dd55ba 100644 --- a/src/pyjuice/nodes/distributions/__init__.py +++ b/src/pyjuice/nodes/distributions/__init__.py @@ -1,5 +1,6 @@ from .distributions import Distribution from .categorical import Categorical +from .literal import Literal from .bernoulli import Bernoulli from .gaussian import Gaussian from .discrete_logistic import DiscreteLogistic diff --git a/src/pyjuice/nodes/distributions/literal.py b/src/pyjuice/nodes/distributions/literal.py new file mode 100644 index 00000000..49add4aa --- /dev/null +++ b/src/pyjuice/nodes/distributions/literal.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from typing import Optional, Any, Union + +from .distributions import Distribution + + +class Literal(Distribution): + """ + A class representing Literal (indicator) distributions. + """ + def __init__(self, lit: Union[bool,int], p: float = 1.0): + super(Literal, self).__init__() + + self.lit = int(lit) # Convert True/False to 1/0 + self.p = p + + def get_signature(self): + """ + Get the signature of the current distribution. + """ + return "Literal" + + def get_metadata(self): + """ + Get the metadata of the current distribution. + """ + return [self.lit, self.p] + + def num_parameters(self): + """ + The number of parameters per node. + """ + return 0 + + def num_param_flows(self): + """ + The number of parameter flows per node. + """ + return 0 + + def init_parameters(self, num_nodes: int, perturbation: float = 2.0, params: Optional[Any] = None, **kwargs): + """ + Initialize parameters for `num_nodes` nodes. + Returned parameters should be flattened into a vector. + """ + + return torch.zeros([0]) + + @staticmethod + def fw_mar_fn(local_offsets, data, params_ptr, s_pids, metadata_ptr, s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE): + s_mids = tl.load(s_mids_ptr + local_offsets, mask = mask, other = 0) + lit = tl.load(metadata_ptr + s_mids, mask = mask, other = 0).to(tl.int64) + prob = tl.load(metadata_ptr + s_mids + 1, mask = mask, other = 0).to(tl.int64) + + probs = tl.where(data == lit, prob, 1.0 - prob) + log_probs = tl.log(probs) + + return log_probs + + @staticmethod + def bk_flow_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr, + s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE): + pass + + @staticmethod + def sample_fn(samples_ptr, local_offsets, batch_offsets, vids, s_pids, params_ptr, metadata_ptr, s_mids_ptr, mask, batch_size, BLOCK_SIZE, seed): + pass + + @staticmethod + def em_fn(local_offsets, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr, s_mids_ptr, mask, + step_size, pseudocount, BLOCK_SIZE): + pass + + def _get_constructor(self): + return Literal, {"lit": self.lit, "p": self.p} \ No newline at end of file diff --git a/tests/nodes/input_dists_test.py b/tests/nodes/input_dists_test.py index bf60f0c5..238867d4 100644 --- a/tests/nodes/input_dists_test.py +++ b/tests/nodes/input_dists_test.py @@ -733,8 +733,34 @@ def test_masked_categorical_nodes_rev_range(): assert torch.all(torch.abs(updated_params - pc.input_layer_group[0].params.reshape(8, 8)) < 1e-4) +def test_literal_nodes(): + + n0_pos = inputs(0, num_nodes = 1, dist = dists.Literal(lit = True)) + n0_neg = inputs(0, num_nodes = 1, dist = dists.Literal(lit = False)) + n1_pos = inputs(1, num_nodes = 1, dist = dists.Literal(lit = True)) + n1_neg = inputs(1, num_nodes = 1, dist = dists.Literal(lit = False)) + + pn0 = multiply(n0_pos, n1_neg) + pn1 = multiply(n0_neg, n1_pos) + + ns = summate(pn0, pn1, num_node_blocks = 1, block_size = 1) + ns.set_params(torch.tensor([[0.5, 0.5]])) + + device = torch.device("cuda:0") + + pc = TensorCircuit(ns) + pc.to(device) + + data = torch.tensor([[False, True], [True, False], [True, True]], device = device) + + lls = pc(data) + + probs = lls.exp().view(-1) + assert ((probs.cpu() - torch.tensor([0.5, 0.5, 0.0])).abs() < 1e-4).all() + + if __name__ == "__main__": - # torch.manual_seed(235) + torch.manual_seed(235) test_categorical_nodes() test_bernoulli_nodes() test_gaussian_nodes() @@ -743,3 +769,4 @@ def test_masked_categorical_nodes_rev_range(): test_masked_categorical_nodes_range() test_masked_categorical_nodes_full_mask() test_masked_categorical_nodes_rev_range() + test_literal_nodes()