Skip to content

Commit

Permalink
update huge layer runtests
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Nov 9, 2024
1 parent 4232c50 commit 4502f7b
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions tests/model/huge_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,24 @@ def test_huge_model():

device = torch.device("cuda:0")

n_blocks = 100000
block_size = 16
for n_blocks in [100000, 200000, 500000]:
for block_size in [4, 8, 16]:

ns = juice.inputs(var = 0, num_node_blocks = n_blocks, block_size = 8, dist = dists.Categorical(num_cats = 2))
ms = juice.multiply(ns)
ns = juice.summate(ms, num_node_blocks = n_blocks, block_size = block_size,
edge_ids = torch.arange(0, n_blocks)[None,:].repeat(2, 1))
ns = juice.inputs(var = 0, num_node_blocks = n_blocks, block_size = 8, dist = dists.Categorical(num_cats = 2))
ms = juice.multiply(ns)
ns = juice.summate(ms, num_node_blocks = n_blocks, block_size = block_size,
edge_ids = torch.arange(0, n_blocks)[None,:].repeat(2, 1))

pc = juice.compile(ns)
pc.to(device)
pc = juice.compile(ns)
pc.to(device)

x = torch.zeros((16, 1), dtype = torch.long).to(device)

lls = pc(x, propagation_alg = "LL")
pc.backward(x, flows_memory = 1.0, allow_modify_flows = False,
propagation_alg = "LL", logspace_flows = True)
x = torch.zeros((16, 1), dtype = torch.long).to(device)

lls = pc(x, propagation_alg = "LL")
pc.backward(x, flows_memory = 1.0, allow_modify_flows = False,
propagation_alg = "LL", logspace_flows = True)

assert (lls < 0.0).all()


if __name__ == "__main__":
Expand Down

0 comments on commit 4502f7b

Please sign in to comment.