From 83694a644f2937cb961ffcdbfb489036c99eb199 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 18 Mar 2024 01:06:50 +0800 Subject: [PATCH] homogeneous PD --- src/pyjuice/structures/pd.py | 66 +++++++++++++++++++++++++++++------- tests/structures/pd_test.py | 48 +++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 14 deletions(-) diff --git a/src/pyjuice/structures/pd.py b/src/pyjuice/structures/pd.py index a14861ec..05ee7f6e 100644 --- a/src/pyjuice/structures/pd.py +++ b/src/pyjuice/structures/pd.py @@ -24,7 +24,7 @@ def PD(data_shape: Tuple, num_latents: int, input_dist: Optional[Distribution] = None, input_node_type: Type[Distribution] = Categorical, input_node_params: Dict = {"num_cats": 256}, - use_linear_mixing: bool = False, + tie_homogeneous_params: bool = False, block_size: Optional[int] = None): """ Generate PCs with the PD structure (https://arxiv.org/pdf/1202.3732.pdf). @@ -53,6 +53,9 @@ def PD(data_shape: Tuple, num_latents: int, :param input_dist: input distribution :type input_dist: Distribution + :param tie_homogeneous_params: whether to tie parameters of sum/input nodes with compatible structures + :type tie_homogeneous_params: bool + :param block_size: block size :type block_size: int """ @@ -72,6 +75,9 @@ def PD(data_shape: Tuple, num_latents: int, num_axes = len(data_shape) + # A dictionary of source nodes + source_ns_dict = dict() + # Construct split points if split_intervals is not None: if isinstance(split_intervals, int): @@ -117,6 +123,12 @@ def updated_hypercube(hypercube, axis, s = None, e = None): hypercube = (tuple(hypercube[0]), tuple(hypercube[1])) return hypercube + def hypercube2shape(hypercube): + return tuple(hypercube[1][i] - hypercube[0][i] for i in range(len(hypercube[0]))) + + def get_signature(hypercube, *ch_ns): + return (hypercube2shape(hypercube), tuple(len(ns.scope) for ns in ch_ns)) + def create_input_ns(hypercube): scope = hypercube2scope(hypercube) if input_layer_fn is not None: @@ -124,11 +136,28 @@ def create_input_ns(hypercube): else: input_nodes = [] for var in scope: - ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + if not tie_homogeneous_params: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + else: + if "input" in source_ns_dict: + ns = source_ns_dict["input"].duplicate(var, tie_params = True) + else: + ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params)) + source_ns_dict["input"] = ns input_nodes.append(ns) edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1) - return summate(multiply(*input_nodes), num_node_blocks = num_node_blocks, edge_ids = edge_ids) + pns = multiply(*input_nodes) + if not tie_homogeneous_params: + return summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, pns) + if signature in source_ns_dict: + return source_ns_dict[signature].duplicate(pns, tie_params = True) + else: + ns = summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns + return ns def recursive_construct(hypercube, depth = 1): if hypercube in hypercube2ns: @@ -161,24 +190,35 @@ def recursive_construct(hypercube, depth = 1): ns = create_input_ns(hypercube) elif hypercube == root_hypercube: ns = summate(*pns, num_node_blocks = 1, block_size = 1) - elif not use_linear_mixing: + else: if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks) + source_ns_dict[signature] = ns else: block_ids = torch.topk(torch.rand([num_node_blocks, len(pns)]), k = max_prod_block_conns, dim = 1).indices par_ids = torch.arange(0, num_node_blocks)[:,None,None].repeat(1, max_prod_block_conns, num_node_blocks) chs_ids = block_ids[:,:,None] * num_node_blocks + torch.arange(0, num_node_blocks)[None,None,:] edge_ids = torch.stack((par_ids.reshape(-1), chs_ids.reshape(-1)), dim = 0) - ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) - else: - # Linear mixing as implemented in EiNet's Mixing layer - if len(pns) <= max_prod_block_conns: - ns = summate(*pns, num_node_blocks = num_node_blocks) - else: - ch_ns = [multiply(summate(pn, num_node_blocks = num_node_blocks)) for pn in pns] - ns = summate(*ch_ns, num_node_blocks = num_node_blocks, edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1)) + + if not tie_homogeneous_params: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + else: + signature = get_signature(hypercube, *pns) + if signature in source_ns_dict: + ns = source_ns_dict[signature].duplicate(*pns, tie_params = True) + else: + ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids) + source_ns_dict[signature] = ns hypercube2ns[hypercube] = ns + return ns with set_block_size(block_size = block_size): diff --git a/tests/structures/pd_test.py b/tests/structures/pd_test.py index 6fb83592..ac1b8b6e 100644 --- a/tests/structures/pd_test.py +++ b/tests/structures/pd_test.py @@ -135,6 +135,52 @@ def test_pd(): assert test_ll > -765.0 +def test_homogeneous_pd(): + + device = torch.device("cuda:0") + + train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True) + test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True) + + train_data = train_dataset.data.reshape(60000, 28*28) + test_data = test_dataset.data.reshape(10000, 28*28) + + num_features = train_data.size(1) + + train_loader = DataLoader( + dataset = TensorDataset(train_data), + batch_size = 512, + shuffle = True, + drop_last = True + ) + test_loader = DataLoader( + dataset = TensorDataset(test_data), + batch_size = 512, + shuffle = False, + drop_last = True + ) + + ns = juice.structures.PD( + data_shape = (28, 28), + num_latents = 256, + split_intervals = (4, 4), + structure_type = "sum_dominated", + tie_homogeneous_params = True + ) + pc = juice.TensorCircuit(ns) + + pc.to(device) + + optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001) + + mini_batch_em_epoch(10, pc, optimizer, None, train_loader, test_loader, device) + + test_ll = evaluate(pc, test_loader) + + assert test_ll > -780.0 + + if __name__ == "__main__": - torch.manual_seed(2391) + # torch.manual_seed(2391) test_pd() + test_homogeneous_pd()