add runtests for logspace flows
liuanji committed Mar 17, 2024
1 parent d1e7c02 commit 7bd6f74
Showing 2 changed files with 239 additions and 6 deletions.
75 changes: 69 additions & 6 deletions tests/structures/hclt_test.py
@@ -19,7 +19,7 @@ def evaluate(pc, loader):
return lls_total


-def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device):
+def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = False):
for epoch in range(num_epochs):
t0 = time.time()
train_ll = 0.0
@@ -29,7 +29,10 @@ def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test
optimizer.zero_grad()

lls = pc(x)
-            lls.mean().backward()
+            if not logspace_flows:
+                lls.mean().backward()
+            else:
+                pc.backward(x.permute(1, 0), allow_modify_flows = False, logspace_flows = True)

train_ll += lls.mean().detach().cpu().numpy().item()

@@ -125,6 +128,65 @@ def test_hclt():
assert test_ll > -785


def test_hclt_logspace_flows():

device = torch.device("cuda:0")

train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)
test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True)

train_data = train_dataset.data.reshape(60000, 28*28)
test_data = test_dataset.data.reshape(10000, 28*28)

num_features = train_data.size(1)

train_loader = DataLoader(
dataset = TensorDataset(train_data),
batch_size = 512,
shuffle = True,
drop_last = True
)
test_loader = DataLoader(
dataset = TensorDataset(test_data),
batch_size = 512,
shuffle = False,
drop_last = True
)

ns = juice.structures.HCLT(
train_data.float().to(device),
num_bins = 32,
sigma = 0.5 / 32,
num_latents = 128,
chunk_size = 32
)
ns.init_parameters(perturbation = 2.0)
pc = juice.TensorCircuit(ns)

pc.to(device)

optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1)
scheduler = juice.optim.CircuitScheduler(
optimizer,
method = "multi_linear",
lrs = [0.9, 0.1, 0.05],
milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350]
)

for batch in train_loader:
x = batch[0].to(device)

lls = pc(x, record_cudagraph = True)
lls.mean().backward()
break

mini_batch_em_epoch(5, pc, optimizer, scheduler, train_loader, test_loader, device, logspace_flows = True)

test_ll = evaluate(pc, test_loader)

assert test_ll > -785


@pytest.mark.slow
def test_small_hclt_full():

@@ -303,7 +365,8 @@ def test_hclt_logistic():

if __name__ == "__main__":
# torch.manual_seed(3289)
-    test_hclt()
-    test_small_hclt_full()
-    test_large_hclt_full()
-    test_hclt_logistic()
+    # test_hclt()
+    test_hclt_logspace_flows()
+    # test_small_hclt_full()
+    # test_large_hclt_full()
+    # test_hclt_logistic()
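
The substantive change in this file is the backward call: with logspace_flows enabled, the loop skips autograd on the mean log-likelihood and instead invokes the circuit's own backward pass with flows accumulated in log space. A minimal sketch of one such training step, assuming pc, optimizer, and a batch x are set up as in test_hclt_logspace_flows above, and assuming juice.optim.CircuitOptimizer exposes the usual step() interface:

    optimizer.zero_grad()

    lls = pc(x)  # forward pass: per-sample log-likelihoods

    # With log-space flows, this call replaces lls.mean().backward()
    pc.backward(x.permute(1, 0), allow_modify_flows = False, logspace_flows = True)

    optimizer.step()  # EM-style parameter update from the accumulated flows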
170 changes: 170 additions & 0 deletions tests/structures/logspace_flows_test.py
@@ -0,0 +1,170 @@
import torch
import torchvision
import time
from torch.utils.data import TensorDataset, DataLoader

import pyjuice as juice
import pyjuice.distributions as dists


def logsubexp(x, y):
"""
Compute log(exp(x) - exp(y)) in a numerically stable way.
"""
x, y = torch.maximum(x, y), torch.minimum(x, y)

# Compute the maximum value between x and y element-wise
max_val = torch.max(x, y)

# Factor out max_val so both exponentials are <= 1:
# log(exp(x) - exp(y)) = max_val + log(exp(x - max_val) - exp(y - max_val))
result = max_val + torch.log(torch.exp(x - max_val) - torch.exp(y - max_val))

return result
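
# Illustrative check of the identity above (example values, not exercised by the test):
# with x = log 3 and y = log 1, max_val = log 3 and the result is
# log(3) + log(1 - 1/3) = log(2).
# logsubexp(torch.tensor(3.0).log(), torch.tensor(1.0).log())  # -> tensor(0.6931)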


def test_logspace_hclt_backward():

device = torch.device("cuda:0")

train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)

train_data = train_dataset.data.reshape(60000, 28*28)[:5000,:]

num_features = train_data.size(1)
num_latents = 128

root_ns = juice.structures.HCLT(
train_data.float().to(device),
num_bins = 32,
sigma = 0.5 / 32,
num_latents = num_latents,
chunk_size = 32
)
root_ns.init_parameters()

pc = juice.TensorCircuit(root_ns)

pc.to(device)

block_size = root_ns.chs[0].block_size
num_blocks = num_latents // block_size

batch_data = train_data[:512,:].contiguous().to(device)
data_cpu = batch_data.long()
batch_size = batch_data.size(0)

pc.init_param_flows(flows_memory = 0.0)

lls = pc(batch_data)
pc.backward(batch_data.permute(1, 0), allow_modify_flows = False, logspace_flows = True)

pc.update_param_flows()

node_mars = pc.node_mars
node_flows = pc.node_flows

temp_node_mars = pc.node_mars.clone()
temp_node_flows = pc.node_flows.clone()
temp_element_mars = pc.element_mars.clone()
temp_element_flows = pc.element_flows.clone()
temp_params = pc.params
temp_param_flows = pc.param_flows.clone()

ns2flows = dict()
ns2flows[root_ns] = torch.ones([1, batch_size], device = device)

ch2par = dict()
for ns in root_ns:
for cs in ns.chs:
if cs not in ch2par:
ch2par[cs] = set()
ch2par[cs].add(ns)

visited = set()

with torch.no_grad():
for ns in root_ns(reverse = True):
visited.add(ns)
if ns == root_ns:

sid, eid = ns._output_ind_range
assert torch.all(torch.abs(node_flows[sid:eid,:] - 0.0) < 1e-4)

nflows = ns2flows[ns]
nmars = node_mars[sid:eid,:]

for i, cs in enumerate(ns.chs):
params = ns._params.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device)
params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents)

param_flows = ns._param_flows.reshape(1, num_blocks * ns.num_chs, 1, block_size).permute(0, 2, 1, 3).to(device)
param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(1, num_latents)

if cs.is_prod():
emars = torch.zeros([num_latents, batch_size], device = device)
for cns in cs.chs:
sid, eid = cns._output_ind_range
emars += node_mars[sid:eid,:]
else:
raise ValueError()

eflows = nflows.log() + params.log().permute(1, 0) + emars - nmars
pflows = eflows.exp().sum(dim = 1)

assert torch.all(torch.abs(pflows - param_flows[0,:]) < 1e-4 * batch_size)

ns2flows[cs] = eflows

elif ns.is_prod():
nflows = ns2flows[ns]

for cs in ns.chs:
if cs not in ns2flows:
ns2flows[cs] = torch.zeros([num_latents, batch_size], device = device) - float("inf")
ns2flows[cs] = torch.logaddexp(ns2flows[cs], nflows)

elif ns.is_sum():

for par_cs in ch2par[ns]:
assert par_cs in visited

nflows = ns2flows[ns]

sid, eid = ns._output_ind_range

assert torch.all(logsubexp(nflows, node_flows[sid:eid,:]).exp() < 1e-3)
assert (logsubexp(nflows, node_flows[sid:eid,:]).exp() > 1e-5).float().mean() < 0.2

nflows = node_flows[sid:eid,:]

nmars = node_mars[sid:eid,:]

ch_eflows = []

for i, cs in enumerate(ns.chs):
params = ns._params.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device)
params = params[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents)

param_flows = ns._param_flows.reshape(num_blocks, num_blocks * ns.num_chs, block_size, block_size).permute(0, 2, 1, 3).to(device)
param_flows = param_flows[:,:,i*num_blocks:(i+1)*num_blocks,:].reshape(num_latents, num_latents)

if cs.is_prod():
emars = torch.zeros([num_latents, batch_size], device = device)
for cns in cs.chs:
sid, eid = cns._output_ind_range
emars += node_mars[sid:eid,:]
else:
raise ValueError()

eflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 1)
pflows = (nflows[None,:,:] + params.permute(1, 0).log()[:,:,None] + emars[:,None,:] - nmars[None,:,:]).logsumexp(dim = 2).permute(1, 0).exp()

ch_eflows.append(eflows)

assert torch.all(torch.abs(pflows - param_flows) < 1e-4 * batch_size)

ns2flows[cs] = eflows


if __name__ == "__main__":
test_logspace_hclt_backward()
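
In equation form, writing $m_n(b)$ for the log-marginal of node $n$ on sample $b$ (node_mars) and $F_n(b)$ for its log-flow (node_flows), the quantities the test reconstructs for a child product node $c$ of sum nodes $n$ with edge parameters $\theta_{n,c}$ are

$$F_c(b) = \log\sum_n \exp\bigl(F_n(b) + \log\theta_{n,c} + m_c(b) - m_n(b)\bigr), \qquad f_{n,c} = \sum_b \exp\bigl(F_n(b) + \log\theta_{n,c} + m_c(b) - m_n(b)\bigr),$$

which is exactly what the logsumexp expressions for eflows and pflows compute; product nodes simply logaddexp the flows arriving from their parents.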
