From 4502f7b2040c71f659ef1f8097d49b483f42600c Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Sat, 9 Nov 2024 23:24:37 +0800 Subject: [PATCH] update huge layer runtests --- tests/model/huge_model_test.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/model/huge_model_test.py b/tests/model/huge_model_test.py index 4f3d824..a00a307 100644 --- a/tests/model/huge_model_test.py +++ b/tests/model/huge_model_test.py @@ -13,22 +13,24 @@ def test_huge_model(): device = torch.device("cuda:0") - n_blocks = 100000 - block_size = 16 + for n_blocks in [100000, 200000, 500000]: + for block_size in [4, 8, 16]: - ns = juice.inputs(var = 0, num_node_blocks = n_blocks, block_size = 8, dist = dists.Categorical(num_cats = 2)) - ms = juice.multiply(ns) - ns = juice.summate(ms, num_node_blocks = n_blocks, block_size = block_size, - edge_ids = torch.arange(0, n_blocks)[None,:].repeat(2, 1)) + ns = juice.inputs(var = 0, num_node_blocks = n_blocks, block_size = 8, dist = dists.Categorical(num_cats = 2)) + ms = juice.multiply(ns) + ns = juice.summate(ms, num_node_blocks = n_blocks, block_size = block_size, + edge_ids = torch.arange(0, n_blocks)[None,:].repeat(2, 1)) - pc = juice.compile(ns) - pc.to(device) + pc = juice.compile(ns) + pc.to(device) - x = torch.zeros((16, 1), dtype = torch.long).to(device) - - lls = pc(x, propagation_alg = "LL") - pc.backward(x, flows_memory = 1.0, allow_modify_flows = False, - propagation_alg = "LL", logspace_flows = True) + x = torch.zeros((16, 1), dtype = torch.long).to(device) + + lls = pc(x, propagation_alg = "LL") + pc.backward(x, flows_memory = 1.0, allow_modify_flows = False, + propagation_alg = "LL", logspace_flows = True) + + assert (lls < 0.0).all() if __name__ == "__main__":