Skip to content

Commit

Permalink
large HMM test
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Oct 14, 2024
1 parent 3686af4 commit 9baeefb
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/structures/large_hmm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pyjuice as juice
import pyjuice.nodes.distributions as dists
import torch
import time

import pytest


@pytest.mark.slow
def test_large_hmm():

device = torch.device("cuda:0")

seq_length = 16
vocab_size = 1024
num_latents = 65536

root_ns = juice.structures.HMM(
seq_length = seq_length,
num_latents = num_latents,
num_emits = vocab_size,
homogeneous = True
)

pc = juice.compile(root_ns)
pc.print_statistics()

pc.to(device)

data = torch.randint(0, vocab_size, (64, seq_length)).to(device)

lls = pc(data, propagation_alg = "LL")
pc.backward(data, flows_memory = 1.0, allow_modify_flows = False,
propagation_alg = "LL", logspace_flows = True)

import pdb; pdb.set_trace()


if __name__ == "__main__":
test_large_hmm()

0 comments on commit 9baeefb

Please sign in to comment.