Commit
homogeneous PD
liuanji committed Mar 17, 2024
1 parent 7bd6f74 commit 83694a6
Showing 2 changed files with 100 additions and 14 deletions.
66 changes: 53 additions & 13 deletions src/pyjuice/structures/pd.py
@@ -24,7 +24,7 @@ def PD(data_shape: Tuple, num_latents: int,
input_dist: Optional[Distribution] = None,
input_node_type: Type[Distribution] = Categorical,
input_node_params: Dict = {"num_cats": 256},
use_linear_mixing: bool = False,
tie_homogeneous_params: bool = False,
block_size: Optional[int] = None):
"""
Generate PCs with the PD structure (https://arxiv.org/pdf/1202.3732.pdf).
@@ -53,6 +53,9 @@ def PD(data_shape: Tuple, num_latents: int,
:param input_dist: input distribution
:type input_dist: Distribution
:param tie_homogeneous_params: whether to tie parameters of sum/input nodes with compatible structures
:type tie_homogeneous_params: bool
:param block_size: block size
:type block_size: int
"""
@@ -72,6 +75,9 @@ def PD(data_shape: Tuple, num_latents: int,

num_axes = len(data_shape)

# A dictionary of source nodes
source_ns_dict = dict()

# Construct split points
if split_intervals is not None:
if isinstance(split_intervals, int):
@@ -117,18 +123,41 @@ def updated_hypercube(hypercube, axis, s = None, e = None):
hypercube = (tuple(hypercube[0]), tuple(hypercube[1]))
return hypercube

def hypercube2shape(hypercube):
return tuple(hypercube[1][i] - hypercube[0][i] for i in range(len(hypercube[0])))

def get_signature(hypercube, *ch_ns):
return (hypercube2shape(hypercube), tuple(len(ns.scope) for ns in ch_ns))

def create_input_ns(hypercube):
scope = hypercube2scope(hypercube)
if input_layer_fn is not None:
return input_layer_fn(scope, num_latents, block_size)
else:
input_nodes = []
for var in scope:
ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params))
if not tie_homogeneous_params:
ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params))
else:
if "input" in source_ns_dict:
ns = source_ns_dict["input"].duplicate(var, tie_params = True)
else:
ns = inputs(var, num_node_blocks = num_node_blocks, dist = input_node_type(**input_node_params))
source_ns_dict["input"] = ns
input_nodes.append(ns)

edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1)
return summate(multiply(*input_nodes), num_node_blocks = num_node_blocks, edge_ids = edge_ids)
pns = multiply(*input_nodes)
if not tie_homogeneous_params:
return summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids)
else:
signature = get_signature(hypercube, pns)
if signature in source_ns_dict:
return source_ns_dict[signature].duplicate(pns, tie_params = True)
else:
ns = summate(pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids)
source_ns_dict[signature] = ns
return ns

def recursive_construct(hypercube, depth = 1):
if hypercube in hypercube2ns:
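
To make the notion of "compatible structures" concrete: the signature above pairs the shape of a region (hypercube) with the scope sizes of its children, and two regions with equal signatures share one set of sum-node parameters. A self-contained sketch of that computation (DummyNs is a hypothetical stand-in for a pyjuice node exposing a scope attribute):

    def hypercube2shape(hypercube):
        return tuple(hypercube[1][i] - hypercube[0][i] for i in range(len(hypercube[0])))

    def get_signature(hypercube, *ch_ns):
        return (hypercube2shape(hypercube), tuple(len(ns.scope) for ns in ch_ns))

    class DummyNs:
        def __init__(self, scope):
            self.scope = scope  # set of variable indices

    # Two different 4x4 patches of a 28x28 grid, each with one product child
    # covering 16 variables, yield the same signature ...
    patch_a = ((0, 0), (4, 4))
    patch_b = ((8, 12), (12, 16))
    child = DummyNs(scope = set(range(16)))
    assert get_signature(patch_a, child) == get_signature(patch_b, child) == ((4, 4), (16,))
    # ... so with tie_homogeneous_params=True the second patch reuses the first
    # patch's sum node via duplicate(..., tie_params = True).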
@@ -161,24 +190,35 @@ def recursive_construct(hypercube, depth = 1):
ns = create_input_ns(hypercube)
elif hypercube == root_hypercube:
ns = summate(*pns, num_node_blocks = 1, block_size = 1)
elif not use_linear_mixing:
else:
if len(pns) <= max_prod_block_conns:
ns = summate(*pns, num_node_blocks = num_node_blocks)
if not tie_homogeneous_params:
ns = summate(*pns, num_node_blocks = num_node_blocks)
else:
signature = get_signature(hypercube, *pns)
if signature in source_ns_dict:
ns = source_ns_dict[signature].duplicate(*pns, tie_params = True)
else:
ns = summate(*pns, num_node_blocks = num_node_blocks)
source_ns_dict[signature] = ns
else:
block_ids = torch.topk(torch.rand([num_node_blocks, len(pns)]), k = max_prod_block_conns, dim = 1).indices
par_ids = torch.arange(0, num_node_blocks)[:,None,None].repeat(1, max_prod_block_conns, num_node_blocks)
chs_ids = block_ids[:,:,None] * num_node_blocks + torch.arange(0, num_node_blocks)[None,None,:]
edge_ids = torch.stack((par_ids.reshape(-1), chs_ids.reshape(-1)), dim = 0)
ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids)
else:
# Linear mixing as implemented in EiNet's Mixing layer
if len(pns) <= max_prod_block_conns:
ns = summate(*pns, num_node_blocks = num_node_blocks)
else:
ch_ns = [multiply(summate(pn, num_node_blocks = num_node_blocks)) for pn in pns]
ns = summate(*ch_ns, num_node_blocks = num_node_blocks, edge_ids = torch.arange(0, num_node_blocks)[None,:].repeat(2, 1))

if not tie_homogeneous_params:
ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids)
else:
signature = get_signature(hypercube, *pns)
if signature in source_ns_dict:
ns = source_ns_dict[signature].duplicate(*pns, tie_params = True)
else:
ns = summate(*pns, num_node_blocks = num_node_blocks, edge_ids = edge_ids)
source_ns_dict[signature] = ns

hypercube2ns[hypercube] = ns

return ns

with set_block_size(block_size = block_size):
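Both tying branches above follow the same pattern: the first sum node built for a given signature becomes the "source" and is cached in source_ns_dict, and every later node with the same signature is created via duplicate(..., tie_params = True) so it shares the source's parameters. A self-contained toy model of that pattern (TinyNode and its methods are hypothetical stand-ins, not pyjuice API):

    import numpy as np

    class TinyNode:
        """Hypothetical stand-in for a sum node holding a parameter vector."""
        def __init__(self, num_params):
            self.params = np.random.rand(num_params)  # fresh parameters

        def duplicate(self, tie_params = False):
            dup = TinyNode.__new__(TinyNode)
            # tie_params=True shares the underlying storage; False copies it.
            dup.params = self.params if tie_params else self.params.copy()
            return dup

    source_ns_dict = dict()

    def build_sum_node(signature, num_params = 8):
        # First node for a signature allocates parameters and becomes the source;
        # later nodes with the same signature are tied duplicates of it.
        if signature in source_ns_dict:
            return source_ns_dict[signature].duplicate(tie_params = True)
        ns = TinyNode(num_params)
        source_ns_dict[signature] = ns
        return ns

    a = build_sum_node(((4, 4), (16,)))
    b = build_sum_node(((4, 4), (16,)))     # same signature -> shared parameters
    c = build_sum_node(((8, 8), (32, 32)))  # new signature  -> new parameters
    assert a.params is b.params and a.params is not c.params

In the diff itself the same dictionary additionally caches a single source input node under the key "input", so with tying enabled all leaf distributions share one parameter set as well.
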
48 changes: 47 additions & 1 deletion tests/structures/pd_test.py
@@ -135,6 +135,52 @@ def test_pd():
assert test_ll > -765.0


def test_homogeneous_pd():

device = torch.device("cuda:0")

train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)
test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True)

train_data = train_dataset.data.reshape(60000, 28*28)
test_data = test_dataset.data.reshape(10000, 28*28)

num_features = train_data.size(1)

train_loader = DataLoader(
dataset = TensorDataset(train_data),
batch_size = 512,
shuffle = True,
drop_last = True
)
test_loader = DataLoader(
dataset = TensorDataset(test_data),
batch_size = 512,
shuffle = False,
drop_last = True
)

ns = juice.structures.PD(
data_shape = (28, 28),
num_latents = 256,
split_intervals = (4, 4),
structure_type = "sum_dominated",
tie_homogeneous_params = True
)
pc = juice.TensorCircuit(ns)

pc.to(device)

optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001)

mini_batch_em_epoch(10, pc, optimizer, None, train_loader, test_loader, device)

test_ll = evaluate(pc, test_loader)

assert test_ll > -780.0


if __name__ == "__main__":
torch.manual_seed(2391)
# torch.manual_seed(2391)
test_pd()
test_homogeneous_pd()
